ai-content-maker/.venv/Lib/site-packages/gruut/pos.py

455 lines
13 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
#!/usr/bin/env python3
"""
Part of speech tagging using python CRF suite.
Credit to: https://towardsdatascience.com/pos-tagging-using-crfs-ea430c5fb78b
Training requires the conllu package:

.. code-block:: sh

    pip install conllu

Training data comes from Universal Dependencies (https://universaldependencies.org/)

Example:

.. code-block:: sh

    python3 -m gruut.pos train --conllu train.conllu --output model.crf --label xpos

Pre-trained models have the following settings:

* c1 = 0.25
* c2 = 0.3
* max-iterations = 100

English model is trained with "xpos" label.
French model is trained with "upos" label.
"""
import argparse
import base64
import logging
import os
import string
import sys
import time
import typing
from pathlib import Path
import jsonlines
import pycrfsuite
# Module logger (fixed name rather than derived from __name__)
_LOGGER = logging.getLogger("gruut.pos")
# -----------------------------------------------------------------------------
# Type of a feature dictionary: feature name -> value, where the value is a
# string, a boolean flag, a number, or a sequence of strings.
FEATURES_TYPE = typing.Dict[str, typing.Union[str, bool, int, float, typing.Sequence[str]]]
class PartOfSpeechTagger:
    """Tag each word of a tokenized sentence with a part of speech.

    Wraps a ``pycrfsuite.Tagger`` and turns words into the CRF feature
    dictionaries produced by :meth:`sent2features`.
    """

    def __init__(
        self, crf_tagger: typing.Union[str, Path, pycrfsuite.Tagger], **kwargs
    ):
        """Create a tagger.

        Args:
            crf_tagger: an existing ``pycrfsuite.Tagger`` (used as-is), or a
                path to a model file that is opened with a new
                ``pycrfsuite.Tagger``.
            **kwargs: accepted but ignored.
        """
        if not isinstance(crf_tagger, pycrfsuite.Tagger):
            # A path was given: load the model from disk
            model_path = str(crf_tagger)
            crf_tagger = pycrfsuite.Tagger()
            crf_tagger.open(model_path)

        self.crf_tagger = crf_tagger

    def __call__(self, words: typing.Sequence[str]) -> typing.Sequence[str]:
        """Return one predicted tag per word in *words*."""
        return self.crf_tagger.tag(PartOfSpeechTagger.sent2features(words))

    @staticmethod
    def local_features(
        word: str,
        prefix: str = "",
        bias: float = 1.0,
        add_punctuation: bool = True,
        add_digit: bool = True,
        add_length: bool = True,
        chars_front: int = 3,
        chars_back: int = 3,
        encode: bool = True,
    ) -> FEATURES_TYPE:
        """Build the features that describe *word* on its own.

        Every feature name starts with *prefix*, so features of neighbouring
        words can be merged into one dictionary without key collisions.

        Args:
            word: word to describe
            prefix: string prepended to every feature name
            bias: value of the constant ``bias`` feature
            add_punctuation: add ``word.ispunctuation``
            add_digit: add ``word.isdigit()``
            add_length: add ``len(word)``
            chars_front: longest leading-substring feature (chunks start at 2)
            chars_back: longest trailing-substring feature (chunks start at 2)
            encode: base64-encode the ``word`` value with :meth:`encode_string`

        Returns:
            Feature dictionary for *word*.
        """
        word_value = PartOfSpeechTagger.encode_string(word) if encode else word
        features: FEATURES_TYPE = {f"{prefix}bias": bias, f"{prefix}word": word_value}

        if add_length:
            features[f"{prefix}len(word)"] = len(word)

        if add_punctuation:
            # NOTE: this is a substring test, so runs of consecutive characters
            # from string.punctuation (and "") also count. Kept as-is because
            # changing it would alter the features pre-trained models expect.
            features[f"{prefix}word.ispunctuation"] = word in string.punctuation

        if add_digit:
            features[f"{prefix}word.isdigit()"] = word.isdigit()

        # Leading chunks word[:2] .. word[:chars_front] (not base64-encoded)
        features.update(
            {f"{prefix}word[:{size}]": word[:size] for size in range(2, chars_front + 1)}
        )

        # Trailing chunks word[-2:] .. word[-chars_back:] (not base64-encoded)
        features.update(
            {f"{prefix}word[-{size}:]": word[-size:] for size in range(2, chars_back + 1)}
        )

        return features

    @staticmethod
    def word2features(
        sentence: typing.Sequence[str],
        i: int,
        add_bos: bool = True,
        add_eos: bool = True,
        words_backward: int = 2,
        words_forward: int = 2,
        **kwargs,
    ) -> FEATURES_TYPE:
        """Build features for ``sentence[i]`` plus its surrounding context.

        Args:
            sentence: words of the sentence
            i: index of the word to describe
            add_bos: set ``BOS`` on the first word
            add_eos: set ``EOS`` on the last word
            words_backward: number of preceding words to include (prefixed ``-N:``)
            words_forward: number of following words to include (prefixed ``+N:``)
            **kwargs: forwarded to :meth:`local_features`

        Returns:
            Feature dictionary for the word at position *i*.
        """
        features = PartOfSpeechTagger.local_features(sentence[i], **kwargs)
        last_index = len(sentence) - 1

        if add_bos and i == 0:
            features["BOS"] = True

        if add_eos and i == last_index:
            features["EOS"] = True

        # Preceding words, nearest first; never reaches past the sentence start
        for distance in range(1, min(words_backward, i) + 1):
            features.update(
                PartOfSpeechTagger.local_features(
                    sentence[i - distance], prefix=f"-{distance}:", **kwargs
                )
            )

        # Following words, nearest first; never reaches past the sentence end
        for distance in range(1, min(words_forward, last_index - i) + 1):
            features.update(
                PartOfSpeechTagger.local_features(
                    sentence[i + distance], prefix=f"+{distance}:", **kwargs
                )
            )

        return features

    @staticmethod
    def sent2features(
        sentence: typing.Sequence[str], **kwargs
    ) -> typing.List[FEATURES_TYPE]:
        """Return one feature dictionary per word of *sentence*, in order."""
        return [
            PartOfSpeechTagger.word2features(sentence, position, **kwargs)
            for position, _word in enumerate(sentence)
        ]

    @staticmethod
    def encode_string(s: str) -> str:
        """Base64-encode the UTF-8 bytes of *s* into an ASCII-only string.

        The result can be reversed with :meth:`decode_string`.
        """
        utf8_bytes = s.encode()
        return base64.b64encode(utf8_bytes).decode("ascii")

    @staticmethod
    def decode_string(s: str) -> str:
        """Invert :meth:`encode_string`."""
        encoded_bytes = s.encode("ascii")
        return base64.b64decode(encoded_bytes).decode()
# -----------------------------------------------------------------------------
def train_model(
    conllu_path: typing.Union[str, Path],
    output_path: typing.Union[str, Path],
    label: str = "xpos",
    c1: float = 0.25,
    c2: float = 0.3,
    max_iterations: int = 100,
):
    """Train a new CRF part-of-speech model from CONLLU data.

    Each sentence becomes one training sequence. The token ``form`` values
    are turned into features with ``PartOfSpeechTagger.sent2features``, and
    the token *label* field is the target tag. A sentence in which any token's
    *label* field is missing (``None``) is skipped with a warning.

    Args:
        conllu_path: CONLLU file with training data
        output_path: where the trained model is written (parent directories
            are created)
        label: token field to predict (e.g. "xpos" or "upos")
        c1: coefficient for the L1 penalty
        c2: coefficient for the L2 penalty
        max_iterations: maximum number of training iterations

    Raises:
        ImportError: if the ``conllu`` package is not installed
    """
    try:
        import conllu
    except ImportError:
        _LOGGER.critical("conllu package is required for training")
        _LOGGER.critical("pip install 'conllu>=4.4'")
        raise

    conllu_path = Path(conllu_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    _LOGGER.debug("Loading train file (%s)", conllu_path)
    with open(conllu_path, "r", encoding="utf-8") as conllu_file:
        train_sents = conllu.parse(conllu_file.read())

    _LOGGER.debug("Training model for %s max iteration(s)", max_iterations)
    trainer = pycrfsuite.Trainer(verbose=False)

    _LOGGER.debug("Getting features for %s training sentence(s)", len(train_sents))
    num_skipped = 0
    for sent in train_sents:
        labels = [token.get(label) for token in sent]
        if None in labels:
            # Report the first unlabeled token, then drop the whole sentence
            missing_token = sent[labels.index(None)]
            _LOGGER.warning("Example has empty label for %s: %s", missing_token, sent)
            num_skipped += 1
            continue

        # Only extract features for sentences that are actually trained on
        words = [token["form"] for token in sent]
        trainer.append(PartOfSpeechTagger.sent2features(words), labels)

    if num_skipped > 0:
        _LOGGER.warning("Skipped %s sentence(s) with missing labels", num_skipped)

    trainer.set_params(
        {
            "c1": c1,  # coefficient for L1 penalty
            "c2": c2,  # coefficient for L2 penalty
            "max_iterations": max_iterations,  # stop earlier
            # include transitions that are possible, but not observed
            "feature.possible_transitions": True,
        }
    )
    _LOGGER.debug(trainer.get_params())

    # Begin training
    _LOGGER.info("Training")
    start_time = time.perf_counter()
    trainer.train(str(output_path))
    end_time = time.perf_counter()

    _LOGGER.info("Training completed in %s second(s)", end_time - start_time)
    _LOGGER.info(trainer.logparser.last_iteration)
def do_train(args):
    """Run ``train_model`` with the options parsed for the ``train`` sub-command."""
    options = {
        "conllu_path": args.conllu,
        "output_path": args.output,
        "label": args.label,
        "c1": args.c1,
        "c2": args.c2,
        "max_iterations": args.max_iterations,
    }
    train_model(**options)
# -----------------------------------------------------------------------------
def do_print_labels(args):
    """Print the sorted distinct values of field ``args.label`` in ``args.conllu``.

    Tokens that lack the field are ignored.

    Raises:
        ImportError: if the ``conllu`` package is not installed
    """
    try:
        import conllu
    except ImportError as e:
        _LOGGER.critical("conllu package is required for training")
        _LOGGER.critical("pip install 'conllu>=4.4'")
        raise e

    with open(args.conllu, "r", encoding="utf-8") as conllu_file:
        sentences = conllu.parse(conllu_file.read())

    unique_labels = {
        token.get(args.label) for sentence in sentences for token in sentence
    }
    # Tokens without the requested field contribute None; don't report it
    unique_labels.discard(None)

    print(sorted(unique_labels))
# -----------------------------------------------------------------------------
def do_predict(args):
    """Tag sentences and write one JSON line per sentence to stdout.

    Sentences come from ``args.texts`` when given, otherwise one per line
    from stdin. Each sentence is split on whitespace, and blank lines are
    skipped. Each output line is the list of (word, tag) pairs for one
    sentence.
    """
    tagger = PartOfSpeechTagger(args.model)

    if args.texts:
        lines = args.texts
    else:
        lines = sys.stdin
        if os.isatty(sys.stdin.fileno()):
            print("Reading sentences from stdin...", file=sys.stderr)

    # Close the writer deterministically. Per the jsonlines docs, closing a
    # Writer built around an existing file object does not close that object
    # (sys.stdout stays open).
    with jsonlines.Writer(sys.stdout, flush=True) as writer:
        for line in lines:
            line = line.strip()
            if not line:
                continue

            words = line.split()
            writer.write(list(zip(words, tagger(words))))
def do_test(args):
    """Evaluate a POS model against gold labels in a CONLLU file.

    For every sentence, the model's tags for the token ``form`` values are
    compared position by position with the gold ``args.label`` values. The
    function prints two figures, both as percentages:

    * the fraction of words tagged incorrectly
    * the fraction of sentences with at least one error

    Raises:
        ImportError: if the ``conllu`` package is not installed
    """
    try:
        import conllu
    except ImportError:
        _LOGGER.critical("conllu package is required for training")
        _LOGGER.critical("pip install 'conllu>=4.4'")
        raise

    tagger = PartOfSpeechTagger(args.model)

    _LOGGER.debug("Testing file (%s)", args.conllu)
    num_sentences = 0
    num_words = 0
    sents_with_errors = 0
    total_errors = 0

    with open(args.conllu, "r", encoding="utf-8") as conllu_file:
        for sent in conllu.parse(conllu_file.read()):
            words = [token["form"] for token in sent]
            gold_labels = [token.get(args.label) for token in sent]
            predicted_labels = tagger(words)

            sentence_errors = 0
            for gold, predicted in zip(gold_labels, predicted_labels):
                num_words += 1
                if gold != predicted:
                    sentence_errors += 1

            total_errors += sentence_errors
            if sentence_errors > 0:
                sents_with_errors += 1

            num_sentences += 1

    if (num_sentences < 1) or (num_words < 1):
        _LOGGER.warning("No words to evaluate in %s", args.conllu)
        return

    # The "%" format type multiplies the ratio by 100 and appends "%"
    print(
        "{0} out of {1} word(s) had an incorrect tag ({2:.2%})".format(
            total_errors, num_words, total_errors / num_words
        )
    )
    print(
        "{0} out of {1} sentence(s) had at least one error ({2:.2%})".format(
            sents_with_errors, num_sentences, sents_with_errors / num_sentences
        )
    )
# -----------------------------------------------------------------------------
def main():
    """Command-line entry point.

    The function:

    * builds an ``argparse`` parser with the ``train``, ``test``,
      ``print-labels`` and ``predict`` sub-commands (each also accepts
      ``--debug``)
    * configures logging: DEBUG with ``--debug``, INFO otherwise
    * calls the handler registered for the chosen sub-command
    """
    parser = argparse.ArgumentParser(prog="pos.py")

    # Create subparsers for each sub-command; one of them is required
    sub_parsers = parser.add_subparsers()
    sub_parsers.required = True
    sub_parsers.dest = "command"

    # -----
    # Train
    # -----
    train_parser = sub_parsers.add_parser(
        "train", help="Train a new POS model from CONLLU file"
    )
    train_parser.add_argument(
        "--conllu", required=True, help="CONLLU file with training data"
    )
    train_parser.add_argument(
        "--output", required=True, help="Path to write output model"
    )
    train_parser.add_argument(
        "--label", default="xpos", help="Field to predict in training data"
    )
    train_parser.add_argument("--c1", type=float, default=0.25, help="L1 penalty")
    train_parser.add_argument("--c2", type=float, default=0.3, help="L2 penalty")
    train_parser.add_argument(
        "--max-iterations",
        type=int,
        default=100,
        help="Maximum number of iterations to train for",
    )
    train_parser.set_defaults(func=do_train)

    # ----
    # Test
    # ----
    test_parser = sub_parsers.add_parser(
        "test", help="Test a POS model on a CONLLU file"
    )
    test_parser.add_argument("--model", required=True, help="Path to POS tagger model")
    test_parser.add_argument(
        "--conllu", required=True, help="CONLLU file with testing data"
    )
    test_parser.add_argument(
        "--label", default="xpos", help="Field with expected labels in testing data"
    )
    test_parser.set_defaults(func=do_test)

    # ------------
    # Print Labels
    # ------------
    print_labels_parser = sub_parsers.add_parser(
        "print-labels", help="Print set of unique labels from a CONLLU file"
    )
    print_labels_parser.add_argument(
        "--conllu", required=True, help="CONLLU file with training data"
    )
    print_labels_parser.add_argument(
        "--label", default="xpos", help="Field to collect labels from"
    )
    print_labels_parser.set_defaults(func=do_print_labels)

    # -------
    # Predict
    # -------
    predict_parser = sub_parsers.add_parser(
        "predict", help="Predict POS tags for sentence(s)"
    )
    predict_parser.add_argument(
        "--model", required=True, help="Path to POS tagger model"
    )
    predict_parser.add_argument("texts", nargs="*", help="Sentences")
    predict_parser.set_defaults(func=do_predict)

    # ----------------
    # Shared arguments
    # ----------------
    for sub_parser in [train_parser, predict_parser, test_parser, print_labels_parser]:
        sub_parser.add_argument(
            "--debug", action="store_true", help="Print DEBUG messages to console"
        )

    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    _LOGGER.debug(args)

    # Dispatch to the handler chosen via set_defaults(func=...)
    args.func(args)


if __name__ == "__main__":
    main()