#!/usr/bin/env python3 """ Part of speech tagging using python CRF suite. Credit to: https://towardsdatascience.com/pos-tagging-using-crfs-ea430c5fb78b Training requires conllu package: .. code-block:: sh pip install conllu Training data comes from Univeral Dependencies (https://universaldependencies.org/) Example: .. code-block:: sh python3 -m gruut.pos train --conllu train.conllu --output model.crf --label xpos Pre-trained models have the following settings: * c1 = 0.25 * c2 = 0.3 * max-iterations = 100 English model is trained with "xpos" label. French model is trained with "upos" label. """ import argparse import base64 import logging import os import string import sys import time import typing from pathlib import Path import jsonlines import pycrfsuite _LOGGER = logging.getLogger("gruut.pos") # ----------------------------------------------------------------------------- FEATURES_TYPE = typing.Dict[ str, typing.Union[str, bool, int, float, typing.Sequence[str]] ] class PartOfSpeechTagger: """Part of speech tagger using a pre-trained CRF model""" def __init__( self, crf_tagger: typing.Union[str, Path, pycrfsuite.Tagger], **kwargs ): if isinstance(crf_tagger, pycrfsuite.Tagger): self.crf_tagger = crf_tagger else: # Load model self.crf_tagger = pycrfsuite.Tagger() self.crf_tagger.open(str(crf_tagger)) def __call__(self, words: typing.Sequence[str]) -> typing.Sequence[str]: """Returns POS tag for each word""" features = PartOfSpeechTagger.sent2features(words) return self.crf_tagger.tag(features) @staticmethod def local_features( word: str, prefix: str = "", bias: float = 1.0, add_punctuation: bool = True, add_digit: bool = True, add_length: bool = True, chars_front: int = 3, chars_back: int = 3, encode: bool = True, ) -> FEATURES_TYPE: """Get features for a single word""" features: FEATURES_TYPE = { f"{prefix}bias": bias, f"{prefix}word": PartOfSpeechTagger.encode_string(word) if encode else word, } if add_length: features[f"{prefix}len(word)"] = len(word) if add_punctuation: features[f"{prefix}word.ispunctuation"] = word in string.punctuation if add_digit: features[f"{prefix}word.isdigit()"] = word.isdigit() # Chunks from front for i in range(2, chars_front + 1): features[f"{prefix}word[:{i}]"] = word[:i] # Chunks from pack for i in range(2, chars_back + 1): features[f"{prefix}word[-{i}:]"] = word[-i:] return features @staticmethod def word2features( sentence: typing.Sequence[str], i: int, add_bos: bool = True, add_eos: bool = True, words_backward: int = 2, words_forward: int = 2, **kwargs, ) -> FEATURES_TYPE: """Get features for a word and surrounding context""" word = sentence[i] num_words = len(sentence) features = PartOfSpeechTagger.local_features(word, **kwargs) if (i == 0) and add_bos: features["BOS"] = True if (i == (num_words - 1)) and add_eos: features["EOS"] = True for j in range(1, words_backward + 1): if i >= j: word_prev = sentence[i - j] features.update( PartOfSpeechTagger.local_features( word_prev, prefix=f"-{j}:", **kwargs ) ) for j in range(1, words_forward + 1): if i < (num_words - j): word_next = sentence[i + j] features.update( PartOfSpeechTagger.local_features( word_next, prefix=f"+{j}:", **kwargs ) ) return features @staticmethod def sent2features( sentence: typing.Sequence[str], **kwargs ) -> typing.List[FEATURES_TYPE]: """Get features for all words in a sentence""" return [ PartOfSpeechTagger.word2features(sentence, i, **kwargs) for i in range(len(sentence)) ] @staticmethod def encode_string(s: str) -> str: """Encodes string in a form that crfsuite will accept (ASCII) and can be decoded""" return base64.b64encode(s.encode()).decode("ascii") @staticmethod def decode_string(s: str) -> str: """Decodes a string encoded by encode_string""" return base64.b64decode(s.encode("ascii")).decode() # ----------------------------------------------------------------------------- def train_model( conllu_path: typing.Union[str, Path], output_path: typing.Union[str, Path], label: str = "xpos", c1: float = 0.25, c2: float = 0.3, max_iterations: int = 100, ): """Train a new model from CONLLU data""" try: import conllu except ImportError as e: _LOGGER.critical("conllu package is required for training") _LOGGER.critical("pip install 'conllu>=4.4'") raise e conllu_path = Path(conllu_path) output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) _LOGGER.debug("Loading train file (%s)", conllu_path) with open(conllu_path, "r", encoding="utf-8") as conllu_file: train_sents = conllu.parse(conllu_file.read()) _LOGGER.debug("Training model for %s max iteration(s)", max_iterations) trainer = pycrfsuite.Trainer(verbose=False) _LOGGER.debug("Getting features for %s training sentence(s)", len(train_sents)) for sent in train_sents: words = [token["form"] for token in sent] features = PartOfSpeechTagger.sent2features(words) labels = [] skip_sent = False for token in sent: token_label = token.get(label) if token_label is None: _LOGGER.warning("Example has empty label for %s: %s", token, sent) skip_sent = True break labels.append(token_label) if skip_sent: continue trainer.append(features, labels) trainer.set_params( { "c1": c1, # coefficient for L1 penalty "c2": c2, # coefficient for L2 penalty "max_iterations": max_iterations, # stop earlier # include transitions that are possible, but not observed "feature.possible_transitions": True, } ) _LOGGER.debug(trainer.get_params()) # Begin training _LOGGER.info("Training") start_time = time.perf_counter() trainer.train(str(output_path)) end_time = time.perf_counter() _LOGGER.info("Training completed in %s second(s)", end_time - start_time) _LOGGER.info(trainer.logparser.last_iteration) def do_train(args): """CLI method for train_model""" train_model( conllu_path=args.conllu, output_path=args.output, label=args.label, c1=args.c1, c2=args.c2, max_iterations=args.max_iterations, ) # ----------------------------------------------------------------------------- def do_print_labels(args): """Print label set from a CONLLU file""" try: import conllu except ImportError as e: _LOGGER.critical("conllu package is required for training") _LOGGER.critical("pip install 'conllu>=4.4'") raise e labels = set() with open(args.conllu, "r", encoding="utf-8") as conllu_file: for sent in conllu.parse(conllu_file.read()): for token in sent: token_label = token.get(args.label) if token_label is not None: labels.add(token_label) print(sorted(labels)) # ----------------------------------------------------------------------------- def do_predict(args): """CLI method for predict""" tagger = PartOfSpeechTagger(args.model) if args.texts: lines = args.texts else: lines = sys.stdin if os.isatty(sys.stdin.fileno()): print("Reading sentences from stdin...", file=sys.stderr) writer = jsonlines.Writer(sys.stdout, flush=True) for line in lines: line = line.strip() if not line: continue words = line.split() words_and_tags = list(zip(words, tagger(words))) writer.write(words_and_tags) def do_test(args): """CLI method for testing""" try: import conllu except ImportError as e: _LOGGER.critical("conllu package is required for training") _LOGGER.critical("pip install 'conllu>=4.4'") raise e tagger = PartOfSpeechTagger(args.model) _LOGGER.debug("Testing file (%s)", args.conllu) num_sentences = 0 num_words = 0 sents_with_errors = 0 total_errors = 0 with open(args.conllu, "r", encoding="utf-8") as conllu_file: for sent in conllu.parse(conllu_file.read()): words = [token["form"] for token in sent] actual_labels = [token.get(args.label) for token in sent] expected_labels = tagger(words) had_error = False for actual, expected in zip(actual_labels, expected_labels): if actual != expected: total_errors += 1 had_error = True num_words += 1 if had_error: sents_with_errors += 1 num_sentences += 1 if (num_sentences < 1) or (num_words < 1): return print( "{0} out of {1} word(s) had an incorrect tag ({2:0.2f}%)".format( total_errors, num_words, total_errors / num_words ) ) print( "{0} out of {1} sentence(s) had at least one error ({2:0.2f}%)".format( sents_with_errors, num_sentences, sents_with_errors / num_sentences ) ) # ----------------------------------------------------------------------------- def main(): """Main entry point""" parser = argparse.ArgumentParser(prog="pos.py") # Create subparsers for each sub-command sub_parsers = parser.add_subparsers() sub_parsers.required = True sub_parsers.dest = "command" # ----- # Train # ----- train_parser = sub_parsers.add_parser( "train", help="Train a new POS model from CONLLU file" ) train_parser.add_argument( "--conllu", required=True, help="CONLLU file with training data" ) train_parser.add_argument( "--output", required=True, help="Path to write output model" ) train_parser.add_argument( "--label", default="xpos", help="Field to predict in training data" ) train_parser.add_argument("--c1", type=float, default=0.25, help="L1 penalty") train_parser.add_argument("--c2", type=float, default=0.3, help="L2 penalty") train_parser.add_argument( "--max-iterations", type=int, default=100, help="Maximum number of iterations to train for", ) train_parser.set_defaults(func=do_train) # ---- # Test # ---- test_parser = sub_parsers.add_parser( "test", help="Test a POS model on a CONLLU file" ) test_parser.add_argument("--model", required=True, help="Path to POS tagger model") test_parser.add_argument( "--conllu", required=True, help="CONLLU file with testing data" ) test_parser.add_argument( "--label", default="xpos", help="Field to predict in training data" ) test_parser.set_defaults(func=do_test) # ------------ # Print Labels # ------------ print_labels_parser = sub_parsers.add_parser( "print-labels", help="Print set of unique labels from a CONLLU file" ) print_labels_parser.add_argument( "--conllu", required=True, help="CONLLU file with training data" ) print_labels_parser.add_argument( "--label", default="xpos", help="Field to predict in training data" ) print_labels_parser.set_defaults(func=do_print_labels) # ------- # Predict # ------- predict_parser = sub_parsers.add_parser( "predict", help="Predict POS tags for sentence(s)" ) predict_parser.add_argument( "--model", required=True, help="Path to POS tagger model" ) predict_parser.add_argument("texts", nargs="*", help="Sentences") predict_parser.set_defaults(func=do_predict) # ---------------- # Shared arguments # ---------------- for sub_parser in [train_parser, predict_parser, test_parser, print_labels_parser]: sub_parser.add_argument( "--debug", action="store_true", help="Print DEBUG messages to console" ) args = parser.parse_args() if args.debug: logging.basicConfig(level=logging.DEBUG) else: logging.basicConfig(level=logging.INFO) _LOGGER.debug(args) args.func(args) if __name__ == "__main__": main()