315 lines
13 KiB
Python
315 lines
13 KiB
Python
# coding=utf-8
|
|
# Copyright 2021 VinAI Research and the HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License
|
|
""" Tokenization classes for BARTpho-syllable model."""
|
|
|
|
|
|
import os
|
|
from shutil import copyfile
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import sentencepiece as spm
|
|
|
|
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
|
from ...utils import logging
|
|
|
|
|
|
# Module-level logger, obtained through the package's own `logging` helper imported above.
logger = logging.get_logger(__name__)
|
|
|
|
# Marker character that `convert_tokens_to_string` replaces with a plain space.
SPIECE_UNDERLINE = "▁"

# File names used when saving the vocabulary: the SentencePiece model and the
# monolingual (reduced) vocabulary list.
VOCAB_FILES_NAMES = dict(
    vocab_file="sentencepiece.bpe.model",
    monolingual_vocab_file="dict.txt",
)
|
|
|
|
|
|
class BartphoTokenizer(PreTrainedTokenizer):
    """
    Adapted from [`XLMRobertaTokenizer`]. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file. This vocabulary is the pre-trained SentencePiece model available from the
            multilingual XLM-RoBERTa, also used in mBART, consisting of 250K types.
        monolingual_vocab_file (`str`):
            Path to the monolingual vocabulary file. This monolingual vocabulary consists of Vietnamese-specialized
            types extracted from the multilingual vocabulary vocab_file of 250K types. The file is read line by
            line; the first whitespace-separated field of each line is the token. Blank lines are ignored and a
            token that is already registered keeps its first id.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        monolingual_vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # Mask token behave like a normal word, i.e. include the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        self.vocab_file = vocab_file
        self.monolingual_vocab_file = monolingual_vocab_file
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(str(vocab_file))

        # Load the reduced vocab

        # Keep order of special tokens for backward compatibility
        self.fairseq_tokens_to_ids = {}
        cnt = 0
        for token in [bos_token, pad_token, eos_token, unk_token, sep_token, cls_token]:
            if str(token) not in self.fairseq_tokens_to_ids:
                self.fairseq_tokens_to_ids[str(token)] = cnt
                cnt += 1
        with open(monolingual_vocab_file, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Blank or whitespace-only line (e.g. a trailing newline): nothing to register.
                    continue
                token = fields[0]
                if token in self.fairseq_tokens_to_ids:
                    # Keep the first id. Re-assigning `len(...)` to an existing key does not grow the
                    # mapping, so the next new token would receive the same id and the inverse
                    # id -> token table would end up with a hole.
                    continue
                self.fairseq_tokens_to_ids[token] = len(self.fairseq_tokens_to_ids)
        if str(mask_token) not in self.fairseq_tokens_to_ids:
            self.fairseq_tokens_to_ids[str(mask_token)] = len(self.fairseq_tokens_to_ids)

        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

        # The vocabulary tables above are built before the parent initializer is called.
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    def __getstate__(self):
        """
        Return a copy of the instance dict suitable for pickling.

        The live `sp_model` object is replaced by `None` and its serialized model proto is stored under
        `sp_model_proto` so that `__setstate__` can rebuild the processor.
        """
        state = self.__dict__.copy()
        state["sp_model"] = None
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    def __setstate__(self, d):
        """
        Restore the instance from the dict produced by `__getstate__`.

        Args:
            d (`dict`): The state dictionary; it must contain `sp_model_proto`.
        """
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # Rebuild the SentencePiece processor from the serialized proto stored in the state.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BARTPho sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """

        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Mirrors the layout produced by `build_inputs_with_special_tokens`.
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. BARTPho does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.

        """

        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    @property
    def vocab_size(self):
        """`int`: Number of entries in the reduced id-to-token table (`fairseq_ids_to_tokens`)."""
        return len(self.fairseq_ids_to_tokens)

    def get_vocab(self):
        """
        Return the vocabulary as a token-to-id dictionary.

        The mapping covers ids `0 .. vocab_size - 1` and is then updated with `added_tokens_encoder`.

        Returns:
            `Dict[str, int]`: The vocabulary.
        """
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        """Split `text` into SentencePiece pieces, returned as strings."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        else:
            # Tokens outside the reduced vocabulary map to the unknown token id.
            return self.unk_token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.fairseq_ids_to_tokens[index]

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the SentencePiece model and the monolingual vocabulary into `save_directory`.

        For each of the two files: if the source path recorded at construction time exists and differs from the
        destination, it is copied; if the source path does not exist, the file is regenerated from the in-memory
        state (serialized SentencePiece proto, or the non-special tokens of the reduced vocabulary, one per line).

        Args:
            save_directory (`str`):
                Existing directory in which to write the files.
            filename_prefix (`str`, *optional*):
                If given, `"<prefix>-"` is prepended to both file names.

        Returns:
            `Tuple[str]`: Paths of the SentencePiece model file and of the monolingual vocabulary file. If
            `save_directory` is not a directory, an error is logged and `None` is returned.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        out_monolingual_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["monolingual_vocab_file"],
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Original model file is gone: dump the model held in memory instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(
            out_monolingual_vocab_file
        ) and os.path.isfile(self.monolingual_vocab_file):
            copyfile(self.monolingual_vocab_file, out_monolingual_vocab_file)
        elif not os.path.isfile(self.monolingual_vocab_file):
            # Original list is gone: rewrite it from the in-memory table. Special tokens are left out
            # because `__init__` registers them itself before reading this file.
            with open(out_monolingual_vocab_file, "w", encoding="utf-8") as fp:
                for token in self.fairseq_tokens_to_ids:
                    if token not in self.all_special_tokens:
                        fp.write(f"{str(token)} \n")

        return out_vocab_file, out_monolingual_vocab_file
|