#!/usr/bin/env python3
"""
Grapheme to phoneme prediction using python CRF suite.

Training requires a pre-aligned corpus in Phonetisaurus format.
https://github.com/AdolfVonKleist/Phonetisaurus

The format of this corpus is:

    t}t e}ˈɛ s}s t}t

Each line contains a single word, with graphemes and phonemes separated by "}".
Multiple graphemes are separated by "|":

    s|h}ʃ o|w}ˈoʊ

The empty phoneme is "_":

    w}w h}_ y}ˈaɪ

Example:

.. code-block:: sh

    python3 -m gruut.g2p train --corpus g2p.corpus --output model.crf

Pre-trained models have the following settings:

* c1 = 0
* c2 = 1
* max-iterations = 100
"""
import argparse
import base64
import itertools
import logging
import os
import sys
import time
import typing
import unicodedata
from pathlib import Path

import pycrfsuite

_LOGGER = logging.getLogger("gruut.g2p")

# -----------------------------------------------------------------------------

FEATURES_TYPE = typing.Dict[str, typing.Union[str, bool, int, float]]

EPS_PHONEME = "_"
PHONEME_JOIN = "|"


class GraphemesToPhonemes:
    """Grapheme to phoneme CRF tagger"""

    def __init__(
        self,
        crf_tagger: typing.Union[str, Path, pycrfsuite.Tagger],
        eps_phoneme: str = EPS_PHONEME,
        phoneme_join: str = PHONEME_JOIN,
    ):
        if isinstance(crf_tagger, pycrfsuite.Tagger):
            self.crf_tagger = crf_tagger
        else:
            # Load model
            self.crf_tagger = pycrfsuite.Tagger()
            self.crf_tagger.open(str(crf_tagger))

        # Empty phoneme (dropped)
        self.eps_phoneme = eps_phoneme

        # String used to join multiple predicted phonemes
        self.phoneme_join = phoneme_join

    def __call__(self, word: str, normalize: bool = True) -> typing.Sequence[str]:
        """Guess phonemes for word"""
        features = GraphemesToPhonemes.word2features(word, normalize=normalize)
        coded_phonemes = self.crf_tagger.tag(features)

        phonemes: typing.List[str] = []
        for coded_ps in coded_phonemes:
            decoded_ps = GraphemesToPhonemes.decode_string(coded_ps)
            for p in decoded_ps.split(self.phoneme_join):
                if p != self.eps_phoneme:
                    phonemes.append(p)

        return phonemes
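
    # Usage sketch (not from the original source; the model path is
    # hypothetical). Given a trained CRF model file, prediction looks like:
    #
    #     g2p = GraphemesToPhonemes("model.crf")
    #     g2p("test")  # e.g. ["t", "ˈɛ", "s", "t"] for an English model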

    # -------------------------------------------------------------------------

    @staticmethod
    def word2features(
        word: typing.Union[str, typing.List[str]], normalize: bool = True, **kwargs
    ):
        """Create feature dicts for all graphemes in a word"""
        if normalize and isinstance(word, str):
            # Combine characters
            # See: https://docs.python.org/3/library/unicodedata.html#unicodedata.normalize
            word = unicodedata.normalize("NFC", word)

        return [
            GraphemesToPhonemes.grapheme2features(word, i, **kwargs)
            for i in range(len(word))
        ]

    @staticmethod
    def grapheme2features(
        word: typing.Union[str, typing.Sequence[str]],
        i: int,
        add_begin: bool = True,
        add_end: bool = True,
        chars_backward: int = 3,
        chars_forward: int = 3,
        bias: float = 1.0,
        encode: bool = True,
    ) -> FEATURES_TYPE:
        """Create feature dict for single grapheme"""
        g = word[i]
        num_g = len(word)

        features: FEATURES_TYPE = {
            "bias": bias,
            "grapheme": GraphemesToPhonemes.encode_string(g) if encode else g,
        }

        if (i == 0) and add_begin:
            features["begin"] = True

        for j in range(1, chars_backward + 1):
            if i >= j:
                g_prev = word[i - j]
                features[f"grapheme-{j}"] = (
                    GraphemesToPhonemes.encode_string(g_prev) if encode else g_prev
                )

        for j in range(1, chars_forward + 1):
            if i < (num_g - j):
                g_next = word[i + j]
                features[f"grapheme+{j}"] = (
                    GraphemesToPhonemes.encode_string(g_next) if encode else g_next
                )

        if (i == (num_g - 1)) and add_end:
            features["end"] = True

        return features
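
    # Illustration (added comment, not in the original source): with
    # encode=False, the first grapheme of the word "to" yields:
    #
    #     GraphemesToPhonemes.grapheme2features("to", 0, encode=False)
    #     # -> {"bias": 1.0, "grapheme": "t", "begin": True, "grapheme+1": "o"}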

    @staticmethod
    def encode_string(s: str) -> str:
        """Encodes string in a form that crfsuite will accept (ASCII) and can be decoded"""
        return base64.b64encode(s.encode()).decode("ascii")

    @staticmethod
    def decode_string(s: str) -> str:
        """Decodes a string encoded by encode_string"""
        return base64.b64decode(s.encode("ascii")).decode()
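
    # Round-trip example (added for clarity): non-ASCII labels such as "ʃ" are
    # base64-encoded before being handed to crfsuite and decoded afterwards:
    #
    #     GraphemesToPhonemes.encode_string("ʃ")     # -> "yoM="
    #     GraphemesToPhonemes.decode_string("yoM=")  # -> "ʃ"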


# -----------------------------------------------------------------------------


def train(
    corpus_path: typing.Union[str, Path],
    output_path: typing.Union[str, Path],
    group_separator: str = "}",
    item_separator: str = "|",
    phoneme_join: str = PHONEME_JOIN,
    eps_phoneme: str = EPS_PHONEME,
    remove_phonemes: typing.Optional[typing.Iterable[str]] = None,
    c1: float = 0.0,
    c2: float = 1.0,
    max_iterations: int = 100,
):
    """Train a new G2P model"""
    corpus_path = Path(corpus_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    remove_phonemes = set(remove_phonemes or [])

    trainer = pycrfsuite.Trainer(verbose=False)

    with open(corpus_path, "r", encoding="utf-8") as corpus:
        for i, line in enumerate(corpus):
            line = line.strip()
            if not line:
                continue

            # Parse into graphemes and phonemes
            skip_line = False
            parts = line.split()
            aligned_g = []
            aligned_p = []

            for part in parts:
                # Graphemes/phonemes are separated by }
                gs_str, ps_str = part.split(group_separator, maxsplit=1)

                # Multiple graphemes and phonemes are separated by |
                gs = gs_str.split(item_separator)
                ps = [
                    p for p in ps_str.split(item_separator) if p not in remove_phonemes
                ]
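
                # Worked example (added comment): the part "s|h}ʃ" gives
                # gs = ["s", "h"] and ps = ["ʃ"]; the loop below aligns these
                # as aligned_g = ["s", "h"], aligned_p = ["ʃ", "_"].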
                # Align graphemes and phonemes: the first grapheme carries all
                # of the part's phonemes, and any remaining graphemes get the
                # empty phoneme.
                for g1, p1 in itertools.zip_longest(gs, [ps], fillvalue=None):
                    if g1 is None:
                        skip_line = True
                        break

                    aligned_g.append(g1)

                    if p1:
                        aligned_p.append(phoneme_join.join(p1))
                    else:
                        aligned_p.append(eps_phoneme)

                if skip_line:
                    break

            if skip_line:
                _LOGGER.warning(
                    "Failed to align line %s: %s (graphemes=%s, phonemes=%s)",
                    i + 1,
                    line,
                    gs,
                    ps,
                )
                continue

            # Add example to trainer
            try:
                encoded_p = [GraphemesToPhonemes.encode_string(p) for p in aligned_p]
                trainer.append(GraphemesToPhonemes.word2features(aligned_g), encoded_p)
            except Exception as e:
                _LOGGER.exception("graphemes=%s phonemes=%s", aligned_g, aligned_p)
                raise e

    trainer.set_params(
        {
            "c1": c1,  # coefficient for L1 penalty
            "c2": c2,  # coefficient for L2 penalty
            "max_iterations": max_iterations,  # stop earlier
            # include transitions that are possible, but not observed
            "feature.possible_transitions": True,
        }
    )

    _LOGGER.debug(trainer.get_params())

    # Begin training
    _LOGGER.info("Training")
    start_time = time.perf_counter()
    trainer.train(str(output_path))
    end_time = time.perf_counter()

    _LOGGER.info("Training completed in %s second(s)", end_time - start_time)
    _LOGGER.info(trainer.logparser.last_iteration)
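

# Usage sketch (not from the original source; file names are hypothetical):
#
#     train("g2p.corpus", "model.crf", max_iterations=100)
#
# reads an aligned corpus and writes the trained CRF model to "model.crf".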


def do_train(args):
    """CLI method for train"""
    train(
        corpus_path=args.corpus,
        output_path=args.output,
        group_separator=args.group_separator,
        item_separator=args.item_separator,
        remove_phonemes=args.remove_phonemes,
        c1=args.c1,
        c2=args.c2,
        max_iterations=args.max_iterations,
    )


def do_predict(args):
    """CLI method for predict"""
    tagger = GraphemesToPhonemes(args.model)

    if args.texts:
        lines = args.texts
    else:
        lines = sys.stdin
        if os.isatty(sys.stdin.fileno()):
            print("Reading words from stdin...", file=sys.stderr)

    for line in lines:
        line = line.strip()
        if not line:
            continue

        word = line
        phonemes = tagger(word)
        print(word, *phonemes)
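

# Example invocation (hypothetical model path). Each output line is the word
# followed by its predicted phonemes:
#
#     python3 -m gruut.g2p predict --model model.crf test
#     # test t ˈɛ s t   (for a model trained on a corpus like the one above)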


def do_test(args):
    """CLI method for test"""
    try:
        from rapidfuzz.string_metric import levenshtein
    except ImportError as e:
        _LOGGER.critical("rapidfuzz library is needed for levenshtein distance")
        _LOGGER.critical("pip install 'rapidfuzz>=1.4.1'")
        raise e

    tagger = GraphemesToPhonemes(args.model)

    # Load lexicon
    if args.texts:
        lines = args.texts
    else:
        lines = sys.stdin
        if os.isatty(sys.stdin.fileno()):
            print("Reading lexicon lines from stdin...", file=sys.stderr)

    lexicon = {}
    for line in lines:
        line = line.strip()
        if (not line) or (" " not in line):
            continue

        word, actual_phonemes = line.split(maxsplit=1)
        lexicon[word] = actual_phonemes

    # Predict phonemes
    predicted_phonemes = {}
    start_time = time.perf_counter()

    for word in lexicon:
        phonemes = tagger(word)
        predicted_phonemes[word] = " ".join(phonemes)

    end_time = time.perf_counter()

    # Calculate PER (approximated here as character-level Levenshtein distance
    # over the space-joined phoneme strings)
    num_errors = 0
    num_missing = 0
    num_phonemes = 0

    for word, actual_phonemes in lexicon.items():
        expected_phonemes = predicted_phonemes.get(word, "")
        if expected_phonemes:
            distance = levenshtein(expected_phonemes, actual_phonemes)
            num_errors += distance
            num_phonemes += len(actual_phonemes)
        else:
            num_missing += 1
            _LOGGER.warning("No pronunciation for %s", word)

    assert num_phonemes > 0, "No phonemes were read"

    # Calculate results
    per = round(num_errors / num_phonemes, 2)
    wps = round(len(predicted_phonemes) / (end_time - start_time), 2)

    print("PER:", per, "Errors:", num_errors, "words/sec:", wps)

    if num_missing > 0:
        print("Total missing:", num_missing)


# -----------------------------------------------------------------------------


def main():
    """Main entry point"""
    parser = argparse.ArgumentParser(prog="g2p.py")

    # Create subparsers for each sub-command
    sub_parsers = parser.add_subparsers()
    sub_parsers.required = True
    sub_parsers.dest = "command"

    # -----
    # Train
    # -----
    train_parser = sub_parsers.add_parser(
        "train", help="Train a new G2P model from a pre-aligned Phonetisaurus corpus"
    )
    train_parser.add_argument(
        "--corpus", required=True, help="Path to aligned Phonetisaurus g2p corpus"
    )
    train_parser.add_argument(
        "--output", required=True, help="Path to output tagger model"
    )
    train_parser.add_argument("--c1", type=float, default=0.0, help="L1 penalty")
    train_parser.add_argument("--c2", type=float, default=1.0, help="L2 penalty")
    train_parser.add_argument(
        "--max-iterations",
        type=int,
        default=100,
        help="Maximum number of training iterations (default: 100)",
    )
    train_parser.add_argument(
        "--group-separator",
        default="}",
        help="Separator between graphemes and phonemes",
    )
    train_parser.add_argument(
        "--item-separator", default="|", help="Separator between items in a group"
    )
    train_parser.add_argument(
        "--remove-phonemes", nargs="*", help="Remove phonemes from examples"
    )
    train_parser.set_defaults(func=do_train)

    # -------
    # Predict
    # -------
    predict_parser = sub_parsers.add_parser(
        "predict", help="Predict phonemes for word(s)"
    )
    predict_parser.add_argument(
        "--model", required=True, help="Path to G2P tagger model"
    )
    predict_parser.add_argument("texts", nargs="*", help="Words")
    predict_parser.set_defaults(func=do_predict)

    # ----
    # Test
    # ----
    test_parser = sub_parsers.add_parser("test", help="Test G2P model on a lexicon")
    test_parser.add_argument("--model", required=True, help="Path to G2P tagger model")
    test_parser.add_argument(
        "texts", nargs="*", help="Lines with '<word> <phoneme> <phoneme> ...'"
    )
    test_parser.set_defaults(func=do_test)

    # ----------------
    # Shared arguments
    # ----------------
    for sub_parser in [train_parser, predict_parser, test_parser]:
        sub_parser.add_argument(
            "--debug", action="store_true", help="Print DEBUG messages to console"
        )

    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    _LOGGER.debug(args)

    args.func(args)


# -----------------------------------------------------------------------------

if __name__ == "__main__":
    main()