# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Allegro.pl, Facebook Inc. and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import re
import unicodedata
from typing import List, Optional, Tuple

from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}

# Suffix appended to the last symbol of every word so that BPE merges can tell
# a word-final symbol from a word-internal one.
_END_OF_WORD = "</w>"

# Single-character punctuation/digit normalisation. The source characters are
# written as escapes so that they survive Unicode normalisation of this file.
_UNICODE_PUNCT_TABLE = str.maketrans(
    {
        "\uff0c": ",",  # fullwidth comma
        "\u3001": ",",  # ideographic comma
        "\u201d": '"',  # right double quotation mark
        "\u201c": '"',  # left double quotation mark
        "\u2236": ":",  # ratio
        "\uff1a": ":",  # fullwidth colon
        "\uff1f": "?",  # fullwidth question mark
        "\u300a": '"',  # left double angle bracket
        "\u300b": '"',  # right double angle bracket
        "\uff09": ")",  # fullwidth right parenthesis
        "\uff01": "!",  # fullwidth exclamation mark
        "\uff08": "(",  # fullwidth left parenthesis
        "\uff1b": ";",  # fullwidth semicolon
        "\u300d": '"',  # right corner bracket
        "\u300c": '"',  # left corner bracket
        "\uff5e": "~",  # fullwidth tilde
        "\u2019": "'",  # right single quotation mark
        "\u2026": "...",  # horizontal ellipsis
        "\u2501": "-",  # box drawings heavy horizontal
        "\u3008": "<",  # left angle bracket
        "\u3009": ">",  # right angle bracket
        "\u3010": "[",  # left black lenticular bracket
        "\u3011": "]",  # right black lenticular bracket
        "\uff05": "%",  # fullwidth percent sign
        # fullwidth digits 0-9
        **{chr(0xFF10 + digit): str(digit) for digit in range(10)},
    }
)

# Ideographic full stop / fullwidth full stop, together with any trailing whitespace.
_UNICODE_FULL_STOP_RE = re.compile("[\u3002\uff0e]\\s*")


# Adapted from transformers.models.xlm.tokenization_xlm.get_pairs
def get_pairs(word):
    """
    Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
    strings)

    An empty or single-symbol word has no adjacent pairs and yields an empty set.
    """
    return set(zip(word, word[1:]))


# Adapted from transformers.models.xlm.tokenization_xlm.replace_unicode_punct
def replace_unicode_punct(text):
    """
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl

    Maps fullwidth/CJK punctuation and fullwidth digits to their ASCII counterparts, and rewrites CJK full stops
    (plus any whitespace that follows them) as ". ".
    """
    text = text.translate(_UNICODE_PUNCT_TABLE)
    # The table never produces or consumes whitespace or CJK full stops, so the
    # order of the two passes does not affect the result.
    return _UNICODE_FULL_STOP_RE.sub(". ", text)


# Adapted from transformers.models.xlm.tokenization_xlm.remove_non_printing_char
def remove_non_printing_char(text):
    """
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl

    Drops every character whose Unicode general category starts with "C".
    """
    return "".join(char for char in text if not unicodedata.category(char).startswith("C"))


# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents
        self.do_split_on_punc = do_split_on_punc

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # prevents treating the same character with different unicode codepoints as different characters
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    # When lowercasing, accents are stripped unless explicitly disabled.
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # NFD splits accented characters into base character + combining marks ("Mn"),
        # which are then dropped.
        text = unicodedata.normalize("NFD", text)
        return "".join(char for char in text if unicodedata.category(char) != "Mn")

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        start_new_word = True
        output = []
        for char in text:
            if _is_punctuation(char):
                # Every punctuation character becomes its own token.
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (0x4E00 <= cp <= 0x9FFF)
            or (0x3400 <= cp <= 0x4DBF)
            or (0x20000 <= cp <= 0x2A6DF)
            or (0x2A700 <= cp <= 0x2B73F)
            or (0x2B740 <= cp <= 0x2B81F)
            or (0x2B820 <= cp <= 0x2CEAF)
            or (0xF900 <= cp <= 0xFAFF)
            or (0x2F800 <= cp <= 0x2FA1F)
        ):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            # Drop NUL, the Unicode replacement character and control characters.
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class HerbertTokenizer(PreTrainedTokenizer):
    """
    Construct a BPE tokenizer for HerBERT.

    Peculiarities:

    - uses BERT's pre-tokenizer: BasicTokenizer splits tokens on spaces, and also on punctuation. Each occurrence of a
      punctuation character will be treated separately.

    - Such pretokenized input is BPE subtokenized

    This tokenizer inherits from [`PreTrainedTokenizer`]; most of its methods are adapted from `XLMTokenizer`. Users
    should refer to the superclass for more information regarding methods.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        merges_file,
        tokenizer_file=None,
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        sep_token="</s>",
        bos_token="<s>",
        do_lowercase_and_remove_accent=False,
        additional_special_tokens=[
            "<special0>",
            "<special1>",
            "<special2>",
            "<special3>",
            "<special4>",
            "<special5>",
            "<special6>",
            "<special7>",
            "<special8>",
            "<special9>",
        ],
        lang2id=None,
        id2lang=None,
        **kwargs,
    ):
        """
        Load the vocabulary and BPE merges and set up the pre-tokenizer.

        Args:
            vocab_file: Path to a JSON file mapping token strings to ids.
            merges_file: Path to a text file with one BPE merge (two symbols) per line.
            tokenizer_file: Accepted for interface compatibility; not forwarded to the base class.
            cls_token, unk_token, pad_token, mask_token, sep_token, bos_token: Special token strings.
            do_lowercase_and_remove_accent: Stored and exposed through `do_lower_case`.
            additional_special_tokens: Extra special tokens forwarded to the base class.
            lang2id, id2lang: Optional language/id mappings; when both are given they must have equal length.

        Raises:
            ImportError: If `sacremoses` is not installed.
        """
        try:
            import sacremoses
        except ImportError as err:
            raise ImportError(
                "You need to install sacremoses to use HerbertTokenizer. "
                "See https://pypi.org/project/sacremoses/ for installation."
            ) from err

        # The signature default is a shared list object; hand the base class a private copy
        # so that nothing downstream can mutate the default for later instances.
        if additional_special_tokens is not None:
            additional_special_tokens = list(additional_special_tokens)

        self.sm = sacremoses

        # cache of sm.MosesPunctNormalizer instance
        self.cache_moses_punct_normalizer = {}
        # cache of sm.MosesTokenizer instance
        self.cache_moses_tokenizer = {}
        self.lang_with_custom_tokenizer = {"zh", "th", "ja"}
        # True for current supported model (v1.2.0), False for XLM-17 & 100
        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
        self.lang2id = lang2id
        self.id2lang = id2lang
        if lang2id is not None and id2lang is not None:
            assert len(lang2id) == len(id2lang)

        self.ja_word_tokenizer = None
        self.zh_word_tokenizer = None

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            # The final element after splitting on "\n" is discarded.
            merges = merges_handle.read().split("\n")[:-1]
        merges = [tuple(merge.split()[:2]) for merge in merges]
        # Earlier merges get a lower rank and are therefore applied first in `bpe`.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            additional_special_tokens=additional_special_tokens,
            lang2id=lang2id,
            id2lang=id2lang,
            do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
            tokenizer_file=None,
            **kwargs,
        )

        self.bert_pre_tokenizer = BasicTokenizer(
            do_lower_case=False,
            never_split=self.all_special_tokens,
            tokenize_chinese_chars=False,
            strip_accents=False,
        )

    @property
    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.do_lower_case
    def do_lower_case(self):
        return self.do_lowercase_and_remove_accent

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_punct_norm
    def moses_punct_norm(self, text, lang):
        if lang not in self.cache_moses_punct_normalizer:
            punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
            self.cache_moses_punct_normalizer[lang] = punct_normalizer
        else:
            punct_normalizer = self.cache_moses_punct_normalizer[lang]
        return punct_normalizer.normalize(text)

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_tokenize
    def moses_tokenize(self, text, lang):
        if lang not in self.cache_moses_tokenizer:
            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = moses_tokenizer
        else:
            moses_tokenizer = self.cache_moses_tokenizer[lang]
        return moses_tokenizer.tokenize(text, return_str=False, escape=False)

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_pipeline
    def moses_pipeline(self, text, lang):
        text = replace_unicode_punct(text)
        text = self.moses_punct_norm(text, lang)
        text = remove_non_printing_char(text)
        return text

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.ja_tokenize
    def ja_tokenize(self, text):
        if self.ja_word_tokenizer is None:
            try:
                import Mykytea

                self.ja_word_tokenizer = Mykytea.Mykytea(
                    f"-model {os.path.expanduser('~')}/local/share/kytea/model.bin"
                )
            except (AttributeError, ImportError):
                logger.error(
                    "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper"
                    " (https://github.com/chezou/Mykytea-python) with the following steps"
                )
                logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
                logger.error("2. autoreconf -i")
                logger.error("3. ./configure --prefix=$HOME/local")
                logger.error("4. make && make install")
                logger.error("5. pip install kytea")
                raise
        return list(self.ja_word_tokenizer.getWS(text))

    @property
    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.vocab_size
    def vocab_size(self):
        return len(self.encoder)

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_vocab
    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    # Adapted from transformers.models.xlm.tokenization_xlm.XLMTokenizer.bpe
    def bpe(self, token):
        """
        Apply the learned BPE merges to a single non-empty pre-token.

        Returns the resulting sub-word symbols joined by single spaces; the last symbol carries the end-of-word
        suffix. Results are memoised in `self.cache`.
        """
        if token in self.cache:
            return self.cache[token]

        # One symbol per character, with the end-of-word suffix attached to the last one.
        word = tuple(token[:-1]) + (token[-1] + _END_OF_WORD,)
        pairs = get_pairs(word)

        if not pairs:
            return token + _END_OF_WORD

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        if word == "\n  " + _END_OF_WORD:
            word = "\n" + _END_OF_WORD
        self.cache[token] = word
        return word

    def _tokenize(self, text):
        """Pre-tokenize `text` with the BERT basic tokenizer, then BPE-split every non-empty pre-token."""
        pre_tokens = self.bert_pre_tokenizer.tokenize(text)

        split_tokens = []
        for token in pre_tokens:
            if token:
                split_tokens.extend(self.bpe(token).split(" "))

        return split_tokens

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    # Adapted from transformers.models.xlm.tokenization_xlm.XLMTokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # Each end-of-word suffix marks a word boundary and becomes a space.
        out_string = "".join(tokens).replace(_END_OF_WORD, " ").strip()
        return out_string

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. An XLM sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.

        """
        bos = [self.bos_token_id]
        sep = [self.sep_token_id]

        if token_ids_1 is None:
            return bos + token_ids_0 + sep
        return bos + token_ids_0 + sep + token_ids_1 + sep

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLM sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Write the vocabulary (JSON) and the BPE merges (one per line, in rank order) into `save_directory`.

        Returns the two written paths `(vocab_file, merge_file)`. If `save_directory` is not an existing directory,
        an error is logged and `None` is returned.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file

    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__getstate__
    def __getstate__(self):
        state = self.__dict__.copy()
        # The sacremoses module object is dropped from the state; it is re-imported in __setstate__.
        state["sm"] = None
        return state

    # Adapted from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__setstate__
    def __setstate__(self, d):
        self.__dict__ = d

        try:
            import sacremoses
        except ImportError as err:
            raise ImportError(
                "You need to install sacremoses to use HerbertTokenizer. "
                "See https://pypi.org/project/sacremoses/ for installation."
            ) from err

        self.sm = sacremoses