56 lines
2.5 KiB
Python
56 lines
2.5 KiB
Python
import pytest
|
|
from spacy.lang.en import English
|
|
from spacy.training import Example
|
|
from thinc.api import Config
|
|
|
|
# INI-style settings for the legacy HashEmbedCNN architecture. The CNN text
# classifier cases in this file pass the parsed [model] section as "tok2vec".
default_tok2vec_config = """
[model]
@architectures = "spacy-legacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""

# Parse the text once, then keep only the [model] section.
_parsed_tok2vec_config = Config().from_str(default_tok2vec_config)
DEFAULT_TOK2VEC_MODEL = _parsed_tok2vec_config["model"]
|
|
|
|
|
|
# (text, annotations) pairs for the single-label case: in each pair exactly
# one of the two categories is scored 1.0 and the other 0.0.
TRAIN_DATA_SINGLE_LABEL = [
    (
        "I'm so happy.",
        {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}},
    ),
    (
        "I'm so angry",
        {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}},
    ),
]
|
|
|
|
# (text, annotations) pairs for the multi-label case: each category is scored
# on its own, so one text can have several categories at 1.0.
TRAIN_DATA_MULTI_LABEL = [
    (
        "I'm angry and confused",
        {"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0}},
    ),
    (
        "I'm confused but happy",
        {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
    ),
]
|
|
|
|
# fmt: off
@pytest.mark.parametrize(
    "name,train_data,textcat_config",
    [
        # Legacy CNN classifier, using the parsed HashEmbedCNN config as its tok2vec layer.
        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy-legacy.TextCatCNN.v1", "exclusive_classes": True, "tok2vec": DEFAULT_TOK2VEC_MODEL}),
        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy-legacy.TextCatCNN.v1", "exclusive_classes": False, "tok2vec": DEFAULT_TOK2VEC_MODEL}),
        # Legacy bag-of-words classifier.
        # NOTE(review): these two rows pair "textcat" with exclusive_classes=False and
        # "textcat_multilabel" with exclusive_classes=True, the reverse of the CNN and
        # Ensemble rows. Values are kept as found; confirm this is intentional.
        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy-legacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 2, "no_output_layer": True}),
        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy-legacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 3, "no_output_layer": True}),
        # Legacy ensemble classifier.
        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy-legacy.TextCatEnsemble.v1", "width": 32, "embed_size": 16, "exclusive_classes": True, "ngram_size": 2, "window_size": 4, "conv_depth": 1, "pretrained_vectors": False, "dropout": 0.1}),
        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy-legacy.TextCatEnsemble.v1", "width": 32, "embed_size": 16, "exclusive_classes": False, "ngram_size": 3, "window_size": 6, "conv_depth": 2, "pretrained_vectors": False, "dropout": 0.2}),
    ],
)
# fmt: on
def test_textcat(name, train_data, textcat_config):
    """Build a text classifier from a legacy model config and train it briefly.

    Args:
        name: Factory name of the pipe to add ("textcat" or
            "textcat_multilabel").
        train_data: List of ``(text, annotations)`` pairs, where
            ``annotations["cats"]`` maps each label to its score.
        textcat_config: Model config dict; it is passed to the pipe as its
            ``"model"`` setting.

    The test adds the pipe to a blank English pipeline, registers every label
    found in the training data, initializes the pipeline and runs five update
    steps. After each step it checks that a loss was recorded for the
    component.
    """
    pipe_config = {"model": textcat_config}
    nlp = English()
    textcat = nlp.add_pipe(name, config=pipe_config)

    train_examples = []
    for text, annotations in train_data:
        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
        # Only the label names are needed here, so iterate the mapping's keys.
        for label in annotations["cats"]:
            textcat.add_label(label)

    optimizer = nlp.initialize()
    for _ in range(5):
        # A fresh dict per step, so the check below covers this step only.
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
        # The pipe is added without an explicit name, so it is registered
        # under its factory name, which is also the key for its loss.
        assert name in losses
|