ai-content-maker/.venv/Lib/site-packages/transformers/models/realm/retrieval_realm.py

165 lines
6.2 KiB
Python

# coding=utf-8
# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""REALM Retriever model implementation."""
import os
from typing import Optional, Union
import numpy as np
from huggingface_hub import hf_hub_download
from ... import AutoTokenizer
from ...utils import logging
# Name of the numpy file holding the evidence-block records, used both for a local model directory
# (`from_pretrained` / `save_pretrained`) and when downloading from a Hub repo.
_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"
logger = logging.get_logger(__name__)
def convert_tfrecord_to_np(block_records_path: str, num_block_records: int) -> np.ndarray:
    """Read the first `num_block_records` records of a TFRecord file into a single numpy array.

    TensorFlow is imported inside the function, so it is only required when this conversion is actually used.

    Args:
        block_records_path (`str`):
            Path to the TFRecord file holding the evidence blocks.
        num_block_records (`int`):
            Number of records to read. They are read as one batch, so exactly this many records are returned
            (any further records in the file are ignored).

    Returns:
        `np.ndarray`: the batch of records as produced by `Dataset.as_numpy_iterator()`.

    Raises:
        ValueError: if the file contains fewer than `num_block_records` records.
    """
    import tensorflow.compat.v1 as tf

    # A large (512 MiB) read buffer, since the whole record set is read in one go.
    blocks_dataset = tf.data.TFRecordDataset(block_records_path, buffer_size=512 * 1024 * 1024)
    # With drop_remainder=True a file holding fewer than `num_block_records` records produces no batch at all.
    blocks_dataset = blocks_dataset.batch(num_block_records, drop_remainder=True)
    np_record = next(blocks_dataset.take(1).as_numpy_iterator(), None)
    if np_record is None:
        # Previously this surfaced as a bare StopIteration with no explanation.
        raise ValueError(
            f"Expected at least {num_block_records} records in {block_records_path}, but the file contains fewer."
        )
    return np_record
class ScaNNSearcher:
    """Dot-product nearest-neighbour searcher backed by ScaNN.

    Note that ScaNNSearcher cannot currently be used within the model. In future versions, it might however be included.
    """

    def __init__(
        self,
        db,
        num_neighbors,
        dimensions_per_block=2,
        num_leaves=1000,
        num_leaves_to_search=100,
        training_sample_size=100000,
    ):
        """Build a ScaNN index over `db` that returns `num_neighbors` results per query.

        The index uses tree partitioning (`num_leaves` leaves, `num_leaves_to_search` probed per query, trained on
        `training_sample_size` samples) followed by `score_ah` scoring with `dimensions_per_block`.
        """
        # Imported here so scann is only needed when a searcher is actually built.
        from scann.scann_ops.py.scann_ops_pybind import builder as Builder

        index_builder = (
            Builder(db=db, num_neighbors=num_neighbors, distance_measure="dot_product")
            .tree(
                num_leaves=num_leaves,
                num_leaves_to_search=num_leaves_to_search,
                training_sample_size=training_sample_size,
            )
            .score_ah(dimensions_per_block=dimensions_per_block)
        )
        self.searcher = index_builder.build()

    def search_batched(self, question_projection):
        """Return the neighbour ids for every query row as an int64 array.

        `question_projection` is detached and moved to CPU first (presumably a torch tensor — confirm at call sites).
        """
        host_queries = question_projection.detach().cpu()
        neighbor_ids, _unused = self.searcher.search_batched(host_queries)
        return neighbor_ids.astype("int64")
class RealmRetriever:
    """The retriever of REALM outputting the retrieved evidence blocks and, when answers are given, whether each block
    contains an answer as well as the answer positions.

    Parameters:
        block_records (`np.ndarray`):
            A numpy array which contains evidence texts.
        tokenizer ([`RealmTokenizer`]):
            The tokenizer to encode retrieved texts.
    """

    def __init__(self, block_records, tokenizer):
        super().__init__()
        self.block_records = block_records
        self.tokenizer = tokenizer

    def __call__(self, retrieved_block_ids, question_input_ids, answer_ids, max_length=None, return_tensors="pt"):
        """Pair the question with every retrieved evidence block and tokenize the pairs.

        Args:
            retrieved_block_ids: indices into `self.block_records` of the blocks to fetch.
            question_input_ids: token ids of the question. Only `question_input_ids[0]` is decoded, so a single
                question is assumed (NOTE(review): confirm callers never pass a batch larger than one).
            answer_ids: sequence of answer token-id sequences, or `None` to skip answer matching.
            max_length: optional truncation length forwarded to the tokenizer.
            return_tensors: tensor type passed to `convert_to_tensors`.

        Returns:
            `(has_answers, start_pos, end_pos, concat_inputs)`; the first three are `None` when `answer_ids` is
            `None`.
        """
        retrieved_blocks = np.take(self.block_records, indices=retrieved_block_ids, axis=0)
        question = self.tokenizer.decode(question_input_ids[0], skip_special_tokens=True)

        # The question is repeated once per block so that each block is encoded as a (question, block) pair.
        text = []
        text_pair = []
        for retrieved_block in retrieved_blocks:
            text.append(question)
            # Records are bytes (hence `.decode()`); convert to text before tokenizing.
            text_pair.append(retrieved_block.decode())

        concat_inputs = self.tokenizer(
            text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length
        )
        concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)

        if answer_ids is not None:
            return self.block_has_answer(concat_inputs, answer_ids) + (concat_inputs_tensors,)
        else:
            return (None, None, None, concat_inputs_tensors)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *init_inputs, **kwargs):
        """Load the block records and tokenizer from a local directory or a Hugging Face Hub repo id.

        `**kwargs` are forwarded both to `hf_hub_download` (when not loading from a local directory) and to
        `AutoTokenizer.from_pretrained`.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            block_records_path = os.path.join(pretrained_model_name_or_path, _REALM_BLOCK_RECORDS_FILENAME)
        else:
            block_records_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=_REALM_BLOCK_RECORDS_FILENAME, **kwargs
            )
        # SECURITY(review): `allow_pickle=True` lets np.load unpickle arbitrary objects, which can execute code.
        # The file may have been downloaded from a remote Hub repo, so only load repositories you trust.
        block_records = np.load(block_records_path, allow_pickle=True)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        return cls(block_records, tokenizer)

    def save_pretrained(self, save_directory):
        """Write the block records (as a .npy file) and the tokenizer into `save_directory`."""
        # save block records
        np.save(os.path.join(save_directory, _REALM_BLOCK_RECORDS_FILENAME), self.block_records)
        # save tokenizer
        self.tokenizer.save_pretrained(save_directory)

    @staticmethod
    def _as_list(values):
        """Return `values` as a plain Python list; accepts lists, tuples and array/tensor objects with `tolist()`."""
        return values.tolist() if hasattr(values, "tolist") else list(values)

    def block_has_answer(self, concat_inputs, answer_ids):
        """Check which encoded blocks contain one of the answers, and where.

        Only the evidence segment of each row is searched, i.e. the tokens strictly between the first and second
        `[SEP]` token. A match must end before the second `[SEP]`.

        Args:
            concat_inputs: encoded (question, block) pairs exposing `input_ids`, one row per block. Rows may be plain
                lists or array/tensor rows (anything with `tolist()`).
            answer_ids: iterable of answer token-id sequences. Empty answers are ignored.

        Returns:
            `(has_answers, start_pos, end_pos)`. `has_answers[i]` tells whether row `i` contains any answer;
            `start_pos[i]` / `end_pos[i]` hold the inclusive start/end token index of every match. They are
            right-padded with `-1` so every row has the same length.
        """
        sep_token_id = self.tokenizer.sep_token_id
        # Skip empty answers: they have no tokens to locate (indexing `answer[0]` used to raise IndexError).
        answers = [answer for answer in map(self._as_list, answer_ids) if answer]

        has_answers = []
        start_pos = []
        end_pos = []
        for input_id in concat_inputs.input_ids:
            input_id_list = self._as_list(input_id)
            # Check answers between two [SEP] tokens
            first_sep_idx = input_id_list.index(sep_token_id)
            second_sep_idx = input_id_list.index(sep_token_id, first_sep_idx + 1)

            starts = []
            ends = []
            for answer in answers:
                answer_len = len(answer)
                # Upper bound keeps the whole match inside the evidence segment (never spilling into [SEP]/padding).
                for idx in range(first_sep_idx + 1, second_sep_idx - answer_len + 1):
                    # Cheap first-token check before comparing the full slice.
                    if input_id_list[idx] == answer[0] and input_id_list[idx : idx + answer_len] == answer:
                        starts.append(idx)
                        ends.append(idx + answer_len - 1)
            start_pos.append(starts)
            end_pos.append(ends)
            has_answers.append(bool(starts))

        # Pad -1 to max_answers so that all rows have the same number of positions.
        max_answers = max((len(starts) for starts in start_pos), default=0)
        for starts, ends in zip(start_pos, end_pos):
            padding = [-1] * (max_answers - len(starts))
            starts.extend(padding)
            ends.extend(padding)
        return has_answers, start_pos, end_pos