150 lines
5.3 KiB
Python
150 lines
5.3 KiB
Python
|
import pickle
|
|||
|
import re
|
|||
|
|
|||
|
import pytest
|
|||
|
|
|||
|
from spacy.attrs import ENT_IOB, ENT_TYPE
|
|||
|
from spacy.lang.en import English
|
|||
|
from spacy.tokenizer import Tokenizer
|
|||
|
from spacy.tokens import Doc
|
|||
|
from spacy.util import (
|
|||
|
compile_infix_regex,
|
|||
|
compile_prefix_regex,
|
|||
|
compile_suffix_regex,
|
|||
|
get_lang_class,
|
|||
|
load_model,
|
|||
|
)
|
|||
|
|
|||
|
from ..util import assert_packed_msg_equal, make_tempdir
|
|||
|
|
|||
|
|
|||
|
def load_tokenizer(b):
    """Return a brand-new English tokenizer whose state is restored from ``b``.

    ``b`` is expected to be the output of ``Tokenizer.to_bytes()``.
    """
    english_cls = get_lang_class("en")
    fresh_tokenizer = english_cls().tokenizer
    fresh_tokenizer.from_bytes(b)
    return fresh_tokenizer
|
|||
|
|
|||
|
|
|||
|
@pytest.mark.issue(2833)
def test_issue2833(en_vocab):
    """Pickling a Token or a Span must raise NotImplementedError."""
    doc = Doc(en_vocab, words=["Hello", "world"])
    # doc[0] is a Token, doc[0:2] is a Span; neither is picklable.
    for unpicklable in (doc[0], doc[0:2]):
        with pytest.raises(NotImplementedError):
            pickle.dumps(unpicklable)
|
|||
|
|
|||
|
|
|||
|
@pytest.mark.issue(3012)
def test_issue3012(en_vocab):
    """Calling from_array with only entity columns must not wipe the existing
    tag/POS annotation, and that annotation must survive a bytes round-trip."""

    def snapshot(d):
        # (text, coarse POS, fine-grained tag, entity label) of the "10" token
        tok = d[2]
        return (tok.text, tok.pos_, tok.tag_, tok.ent_type_)

    doc = Doc(
        en_vocab,
        words=["This", "is", "10", "%", "."],
        tags=["DT", "VBZ", "CD", "NN", "."],
        pos=["DET", "VERB", "NUM", "NOUN", "PUNCT"],
        ents=["O", "O", "B-PERCENT", "I-PERCENT", "O"],
    )
    assert doc.has_annotation("TAG")
    expected = ("10", "NUM", "CD", "PERCENT")
    assert snapshot(doc) == expected

    # Write back only the entity columns; tags and POS must be left alone.
    ent_attrs = [ENT_IOB, ENT_TYPE]
    doc.from_array(ent_attrs, doc.to_array(ent_attrs))
    assert snapshot(doc) == expected

    # Serializing then deserializing
    restored = Doc(en_vocab).from_bytes(doc.to_bytes())
    assert snapshot(restored) == expected
|
|||
|
|
|||
|
|
|||
|
@pytest.mark.issue(4190)
def test_issue4190():
    """A customized tokenizer (built with faster_heuristics=False) must give the
    same tokens after the pipeline is saved to disk and loaded again."""

    def customize_tokenizer(nlp):
        defaults = nlp.Defaults
        # Drop every two-character exception whose second char is a period
        # (e.g. 'h.'), so such strings go through the affix rules instead.
        kept_exceptions = {}
        for key, value in dict(defaults.tokenizer_exceptions).items():
            if len(key) == 2 and key[1] == ".":
                continue
            kept_exceptions[key] = value
        nlp.tokenizer = Tokenizer(
            nlp.vocab,
            kept_exceptions,
            prefix_search=compile_prefix_regex(defaults.prefixes).search,
            suffix_search=compile_suffix_regex(defaults.suffixes).search,
            infix_finditer=compile_infix_regex(defaults.infixes).finditer,
            # Evaluated before the reassignment above takes effect, so this
            # is the token_match of the original tokenizer.
            token_match=nlp.tokenizer.token_match,
            faster_heuristics=False,
        )

    text = "Test c."
    # Load default language and tokenize once with the stock tokenizer.
    nlp_1 = English()
    nlp_1(text)
    # Swap in the modified tokenizer.
    customize_tokenizer(nlp_1)
    tokens_before_save = [tok.text for tok in nlp_1(text)]
    # Save and reload the whole pipeline.
    with make_tempdir() as model_dir:
        nlp_1.to_disk(model_dir)
        nlp_2 = load_model(model_dir)
    # The reloaded pipeline should carry the modified tokenizer.
    tokens_after_load = [tok.text for tok in nlp_2(text)]
    assert tokens_before_save == tokens_after_load
    assert nlp_2.tokenizer.faster_heuristics is False
|
|||
|
|
|||
|
|
|||
|
def test_serialize_custom_tokenizer(en_vocab, en_tokenizer):
    """Custom tokenizers that leave some callbacks undefined, or whose
    properties are empty, must survive serialization and deserialization
    (see #2494, #4991)."""
    # A tokenizer that defines only suffix_search can be written and read back.
    partial = Tokenizer(en_vocab, suffix_search=en_tokenizer.suffix_search)
    partial_bytes = partial.to_bytes()
    Tokenizer(en_vocab).from_bytes(partial_bytes)

    # test that empty/unset values are set correctly on deserialization
    callback_attrs = ("token_match", "url_match", "prefix_search", "infix_finditer")
    target = get_lang_class("en")().tokenizer
    target.token_match = re.compile("test").match
    assert target.rules != {}
    for attr in callback_attrs:
        assert getattr(target, attr) is not None
    target.from_bytes(partial_bytes)
    assert target.rules == {}
    for attr in callback_attrs:
        assert getattr(target, attr) is None

    # Rules that were explicitly emptied stay empty after a round-trip.
    with_rules = Tokenizer(en_vocab, rules={"ABC.": [{"ORTH": "ABC"}, {"ORTH": "."}]})
    with_rules.rules = {}
    reloaded = Tokenizer(en_vocab).from_bytes(with_rules.to_bytes())
    assert reloaded.rules == {}
|
|||
|
|
|||
|
|
|||
|
@pytest.mark.parametrize("text", ["I💜you", "they’re", "“hello”"])
def test_serialize_tokenizer_roundtrip_bytes(en_tokenizer, text):
    """A tokenizer reloaded from its own bytes serializes identically and
    splits ``text`` into the same tokens as the original."""
    original_bytes = en_tokenizer.to_bytes()
    reloaded = load_tokenizer(original_bytes)
    reloaded_bytes = reloaded.to_bytes()
    # Structured comparison first (better failure output), then exact bytes.
    assert_packed_msg_equal(reloaded_bytes, original_bytes)
    assert reloaded_bytes == original_bytes
    original_tokens = [tok.text for tok in en_tokenizer(text)]
    reloaded_tokens = [tok.text for tok in reloaded(text)]
    assert original_tokens == reloaded_tokens
|
|||
|
|
|||
|
|
|||
|
def test_serialize_tokenizer_roundtrip_disk(en_tokenizer):
    """A tokenizer written with to_disk and read into a *separate* tokenizer
    with from_disk must serialize to the same bytes as the original."""
    tokenizer = en_tokenizer
    with make_tempdir() as d:
        file_path = d / "tokenizer"
        tokenizer.to_disk(file_path)
        # Load into a fresh tokenizer. Loading back into en_tokenizer would
        # make both sides of the comparison the same object, so the assert
        # would pass trivially. It would also mutate the shared fixture.
        tokenizer_d = get_lang_class("en")().tokenizer
        tokenizer_d.from_disk(file_path)
    assert tokenizer_d is not tokenizer
    assert tokenizer.to_bytes() == tokenizer_d.to_bytes()
|