165 lines
6.2 KiB
Python
165 lines
6.2 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""REALM Retriever model implementation."""
|
||
|
|
||
|
import os
|
||
|
from typing import Optional, Union
|
||
|
|
||
|
import numpy as np
|
||
|
from huggingface_hub import hf_hub_download
|
||
|
|
||
|
from ... import AutoTokenizer
|
||
|
from ...utils import logging
|
||
|
|
||
|
|
||
|
_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"
|
||
|
|
||
|
|
||
|
logger = logging.get_logger(__name__)
|
||
|
|
||
|
|
||
|
def convert_tfrecord_to_np(block_records_path: str, num_block_records: int) -> np.ndarray:
|
||
|
import tensorflow.compat.v1 as tf
|
||
|
|
||
|
blocks_dataset = tf.data.TFRecordDataset(block_records_path, buffer_size=512 * 1024 * 1024)
|
||
|
blocks_dataset = blocks_dataset.batch(num_block_records, drop_remainder=True)
|
||
|
np_record = next(blocks_dataset.take(1).as_numpy_iterator())
|
||
|
|
||
|
return np_record
|
||
|
|
||
|
|
||
|
class ScaNNSearcher:
|
||
|
"""Note that ScaNNSearcher cannot currently be used within the model. In future versions, it might however be included."""
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
db,
|
||
|
num_neighbors,
|
||
|
dimensions_per_block=2,
|
||
|
num_leaves=1000,
|
||
|
num_leaves_to_search=100,
|
||
|
training_sample_size=100000,
|
||
|
):
|
||
|
"""Build scann searcher."""
|
||
|
|
||
|
from scann.scann_ops.py.scann_ops_pybind import builder as Builder
|
||
|
|
||
|
builder = Builder(db=db, num_neighbors=num_neighbors, distance_measure="dot_product")
|
||
|
builder = builder.tree(
|
||
|
num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=training_sample_size
|
||
|
)
|
||
|
builder = builder.score_ah(dimensions_per_block=dimensions_per_block)
|
||
|
|
||
|
self.searcher = builder.build()
|
||
|
|
||
|
def search_batched(self, question_projection):
|
||
|
retrieved_block_ids, _ = self.searcher.search_batched(question_projection.detach().cpu())
|
||
|
return retrieved_block_ids.astype("int64")
|
||
|
|
||
|
|
||
|
class RealmRetriever:
|
||
|
"""The retriever of REALM outputting the retrieved evidence block and whether the block has answers as well as answer
|
||
|
positions."
|
||
|
|
||
|
Parameters:
|
||
|
block_records (`np.ndarray`):
|
||
|
A numpy array which cantains evidence texts.
|
||
|
tokenizer ([`RealmTokenizer`]):
|
||
|
The tokenizer to encode retrieved texts.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, block_records, tokenizer):
|
||
|
super().__init__()
|
||
|
self.block_records = block_records
|
||
|
self.tokenizer = tokenizer
|
||
|
|
||
|
def __call__(self, retrieved_block_ids, question_input_ids, answer_ids, max_length=None, return_tensors="pt"):
|
||
|
retrieved_blocks = np.take(self.block_records, indices=retrieved_block_ids, axis=0)
|
||
|
|
||
|
question = self.tokenizer.decode(question_input_ids[0], skip_special_tokens=True)
|
||
|
|
||
|
text = []
|
||
|
text_pair = []
|
||
|
for retrieved_block in retrieved_blocks:
|
||
|
text.append(question)
|
||
|
text_pair.append(retrieved_block.decode())
|
||
|
|
||
|
concat_inputs = self.tokenizer(
|
||
|
text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length
|
||
|
)
|
||
|
concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)
|
||
|
|
||
|
if answer_ids is not None:
|
||
|
return self.block_has_answer(concat_inputs, answer_ids) + (concat_inputs_tensors,)
|
||
|
else:
|
||
|
return (None, None, None, concat_inputs_tensors)
|
||
|
|
||
|
@classmethod
|
||
|
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *init_inputs, **kwargs):
|
||
|
if os.path.isdir(pretrained_model_name_or_path):
|
||
|
block_records_path = os.path.join(pretrained_model_name_or_path, _REALM_BLOCK_RECORDS_FILENAME)
|
||
|
else:
|
||
|
block_records_path = hf_hub_download(
|
||
|
repo_id=pretrained_model_name_or_path, filename=_REALM_BLOCK_RECORDS_FILENAME, **kwargs
|
||
|
)
|
||
|
block_records = np.load(block_records_path, allow_pickle=True)
|
||
|
|
||
|
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
|
||
|
|
||
|
return cls(block_records, tokenizer)
|
||
|
|
||
|
def save_pretrained(self, save_directory):
|
||
|
# save block records
|
||
|
np.save(os.path.join(save_directory, _REALM_BLOCK_RECORDS_FILENAME), self.block_records)
|
||
|
# save tokenizer
|
||
|
self.tokenizer.save_pretrained(save_directory)
|
||
|
|
||
|
def block_has_answer(self, concat_inputs, answer_ids):
|
||
|
"""check if retrieved_blocks has answers."""
|
||
|
has_answers = []
|
||
|
start_pos = []
|
||
|
end_pos = []
|
||
|
max_answers = 0
|
||
|
|
||
|
for input_id in concat_inputs.input_ids:
|
||
|
input_id_list = input_id.tolist()
|
||
|
# Check answers between two [SEP] tokens
|
||
|
first_sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
|
||
|
second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1 :].index(self.tokenizer.sep_token_id)
|
||
|
|
||
|
start_pos.append([])
|
||
|
end_pos.append([])
|
||
|
for answer in answer_ids:
|
||
|
for idx in range(first_sep_idx + 1, second_sep_idx):
|
||
|
if answer[0] == input_id_list[idx]:
|
||
|
if input_id_list[idx : idx + len(answer)] == answer:
|
||
|
start_pos[-1].append(idx)
|
||
|
end_pos[-1].append(idx + len(answer) - 1)
|
||
|
|
||
|
if len(start_pos[-1]) == 0:
|
||
|
has_answers.append(False)
|
||
|
else:
|
||
|
has_answers.append(True)
|
||
|
if len(start_pos[-1]) > max_answers:
|
||
|
max_answers = len(start_pos[-1])
|
||
|
|
||
|
# Pad -1 to max_answers
|
||
|
for start_pos_, end_pos_ in zip(start_pos, end_pos):
|
||
|
if len(start_pos_) < max_answers:
|
||
|
padded = [-1] * (max_answers - len(start_pos_))
|
||
|
start_pos_ += padded
|
||
|
end_pos_ += padded
|
||
|
return has_answers, start_pos, end_pos
|