#!/usr/bin/env python3
"""
Grapheme to phoneme prediction using the python-crfsuite library.

Training requires a pre-aligned corpus in Phonetisaurus format:

https://github.com/AdolfVonKleist/Phonetisaurus

The format of this corpus is:

    t}t e}ˈɛ s}s t}t

Each line contains a single word, with graphemes and phonemes separated by "}".
Multiple graphemes (and phonemes) are separated by "|":

    s|h}ʃ o|w}ˈoʊ

The empty phoneme is "_":

    w}w h}_ y}ˈaɪ

Training example:

.. code-block:: sh

    python3 -m gruut.g2p train --corpus g2p.corpus --output model.crf

Pre-trained models have the following settings:

* c1 = 0
* c2 = 1
* max-iterations = 100
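
Prediction example (a sketch: "model.crf" is a placeholder path, and the
predicted phonemes depend on the trained model):

.. code-block:: python

    from gruut.g2p import GraphemesToPhonemes

    g2p = GraphemesToPhonemes("model.crf")
    g2p("test")  # e.g. ['t', 'ˈɛ', 's', 't']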
"""

import argparse
import base64
import itertools
import logging
import os
import sys
import time
import typing
import unicodedata
from pathlib import Path

import pycrfsuite

_LOGGER = logging.getLogger("gruut.g2p")

# -----------------------------------------------------------------------------

FEATURES_TYPE = typing.Dict[str, typing.Union[str, bool, int, float]]
EPS_PHONEME = "_"
PHONEME_JOIN = "|"


class GraphemesToPhonemes:
    """Grapheme to phoneme CRF tagger"""

    def __init__(
        self,
        crf_tagger: typing.Union[str, Path, pycrfsuite.Tagger],
        eps_phoneme: str = EPS_PHONEME,
        phoneme_join: str = PHONEME_JOIN,
    ):
        if isinstance(crf_tagger, pycrfsuite.Tagger):
            self.crf_tagger = crf_tagger
        else:
            # Load model from a file path
            self.crf_tagger = pycrfsuite.Tagger()
            self.crf_tagger.open(str(crf_tagger))

        # Empty phoneme (dropped from predictions)
        self.eps_phoneme = eps_phoneme

        # String used to join multiple predicted phonemes
        self.phoneme_join = phoneme_join

    def __call__(self, word: str, normalize: bool = True) -> typing.Sequence[str]:
        """Guess phonemes for a word"""
        features = GraphemesToPhonemes.word2features(word, normalize=normalize)
        coded_phonemes = self.crf_tagger.tag(features)
        phonemes: typing.List[str] = []

        for coded_ps in coded_phonemes:
            # Labels are base64-encoded for crfsuite; decode and split
            decoded_ps = GraphemesToPhonemes.decode_string(coded_ps)
            for p in decoded_ps.split(self.phoneme_join):
                if p != self.eps_phoneme:
                    phonemes.append(p)

        return phonemes

    # -------------------------------------------------------------------------

    @staticmethod
    def word2features(
        word: typing.Union[str, typing.List[str]], normalize: bool = True, **kwargs
    ):
        """Create feature dicts for all graphemes in a word"""
        if normalize and isinstance(word, str):
            # Combine characters (e.g., "e" + U+0301 becomes a single "é")
            # See: https://docs.python.org/3/library/unicodedata.html#unicodedata.normalize
            word = unicodedata.normalize("NFC", word)

        return [
            GraphemesToPhonemes.grapheme2features(word, i, **kwargs)
            for i in range(len(word))
        ]

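    # An illustrative sketch of the feature layout (encode=False for
    # readability): grapheme2features("show", 0, encode=False) returns
    #
    #     {"bias": 1.0, "grapheme": "s", "begin": True,
    #      "grapheme+1": "h", "grapheme+2": "o", "grapheme+3": "w"}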
    @staticmethod
    def grapheme2features(
        word: typing.Union[str, typing.Sequence[str]],
        i: int,
        add_begin: bool = True,
        add_end: bool = True,
        chars_backward: int = 3,
        chars_forward: int = 3,
        bias: float = 1.0,
        encode: bool = True,
    ) -> FEATURES_TYPE:
        """Create feature dict for a single grapheme"""
        g = word[i]
        num_g = len(word)

        features: FEATURES_TYPE = {
            "bias": bias,
            "grapheme": GraphemesToPhonemes.encode_string(g) if encode else g,
        }

        if (i == 0) and add_begin:
            features["begin"] = True

        # Graphemes to the left of position i
        for j in range(1, chars_backward + 1):
            if i >= j:
                g_prev = word[i - j]
                features[f"grapheme-{j}"] = (
                    GraphemesToPhonemes.encode_string(g_prev) if encode else g_prev
                )

        # Graphemes to the right of position i
        for j in range(1, chars_forward + 1):
            if i < (num_g - j):
                g_next = word[i + j]
                features[f"grapheme+{j}"] = (
                    GraphemesToPhonemes.encode_string(g_next) if encode else g_next
                )

        if (i == (num_g - 1)) and add_end:
            features["end"] = True

        return features

    @staticmethod
    def encode_string(s: str) -> str:
        """Encode a string in a form that crfsuite will accept (ASCII) and that can be decoded"""
        return base64.b64encode(s.encode()).decode("ascii")

    @staticmethod
    def decode_string(s: str) -> str:
        """Decode a string encoded by encode_string"""
        return base64.b64decode(s.encode("ascii")).decode()
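
    # Round-trip example (values computed with Python's base64 module):
    #
    #     GraphemesToPhonemes.encode_string("ʃ")     # -> "yoM="
    #     GraphemesToPhonemes.decode_string("yoM=")  # -> "ʃ"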


# -----------------------------------------------------------------------------


def train(
    corpus_path: typing.Union[str, Path],
    output_path: typing.Union[str, Path],
    group_separator: str = "}",
    item_separator: str = "|",
    phoneme_join: str = PHONEME_JOIN,
    eps_phoneme: str = EPS_PHONEME,
    remove_phonemes: typing.Optional[typing.Iterable[str]] = None,
    c1: float = 0.0,
    c2: float = 1.0,
    max_iterations: int = 100,
):
    """Train a new G2P model"""
    corpus_path = Path(corpus_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    remove_phonemes = set(remove_phonemes or [])

    trainer = pycrfsuite.Trainer(verbose=False)

    with open(corpus_path, "r", encoding="utf-8") as corpus:
        for i, line in enumerate(corpus):
            line = line.strip()
            if not line:
                continue

            # Parse into graphemes and phonemes
            skip_line = False
            parts = line.split()
            aligned_g = []
            aligned_p = []

            for part in parts:
                # Graphemes and phonemes are separated by }
                gs_str, ps_str = part.split(group_separator, maxsplit=1)

                # Multiple graphemes and phonemes are separated by |
                gs = gs_str.split(item_separator)
                ps = [
                    p for p in ps_str.split(item_separator) if p not in remove_phonemes
                ]

                # Align graphemes and phonemes: all phonemes in a group attach
                # to its first grapheme; the rest receive the empty phoneme.
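                # Example (derived from the corpus format above): for the
                # group "s|h}ʃ", gs = ["s", "h"] and ps = ["ʃ"], and
                # zip_longest(gs, [ps]) pairs the first grapheme with the
                # whole phoneme list and later graphemes with None, so
                # "s" -> "ʃ" and "h" -> "_".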
                for g1, p1 in itertools.zip_longest(gs, [ps], fillvalue=None):
                    if g1 is None:
                        skip_line = True
                        break

                    aligned_g.append(g1)

                    if p1:
                        aligned_p.append(phoneme_join.join(p1))
                    else:
                        aligned_p.append(eps_phoneme)

                if skip_line:
                    break

            if skip_line:
                _LOGGER.warning(
                    "Failed to align line %s: %s (graphemes=%s, phonemes=%s)",
                    i + 1,
                    line,
                    gs,
                    ps,
                )
                continue

            # Add example to trainer
            try:
                encoded_p = [GraphemesToPhonemes.encode_string(p) for p in aligned_p]
                trainer.append(GraphemesToPhonemes.word2features(aligned_g), encoded_p)
            except Exception as e:
                _LOGGER.exception("graphemes=%s phonemes=%s", aligned_g, aligned_p)
                raise e

    trainer.set_params(
        {
            "c1": c1,  # coefficient for L1 penalty
            "c2": c2,  # coefficient for L2 penalty
            "max_iterations": max_iterations,  # stop earlier
            # include transitions that are possible, but not observed
            "feature.possible_transitions": True,
        }
    )

    _LOGGER.debug(trainer.get_params())

    # Begin training
    _LOGGER.info("Training")

    start_time = time.perf_counter()
    trainer.train(str(output_path))
    end_time = time.perf_counter()

    _LOGGER.info("Training completed in %s second(s)", end_time - start_time)
    _LOGGER.info(trainer.logparser.last_iteration)


def do_train(args):
    """CLI method for train"""
    train(
        corpus_path=args.corpus,
        output_path=args.output,
        group_separator=args.group_separator,
        item_separator=args.item_separator,
        remove_phonemes=args.remove_phonemes,
        c1=args.c1,
        c2=args.c2,
        max_iterations=args.max_iterations,
    )


def do_predict(args):
    """CLI method for predict"""
    tagger = GraphemesToPhonemes(args.model)

    if args.texts:
        lines = args.texts
    else:
        lines = sys.stdin

        if os.isatty(sys.stdin.fileno()):
            print("Reading words from stdin...", file=sys.stderr)

    for line in lines:
        line = line.strip()
        if not line:
            continue

        word = line
        phonemes = tagger(word)

        # Word followed by its predicted phonemes, e.g. "test t ˈɛ s t"
        print(word, *phonemes)


def do_test(args):
    """CLI method for test"""
    try:
        from rapidfuzz.string_metric import levenshtein
    except ImportError as e:
        _LOGGER.critical("rapidfuzz library is needed for levenshtein distance")
        _LOGGER.critical("pip install 'rapidfuzz>=1.4.1'")
        raise e

    tagger = GraphemesToPhonemes(args.model)

    # Load lexicon
    if args.texts:
        lines = args.texts
    else:
        lines = sys.stdin

        if os.isatty(sys.stdin.fileno()):
            print("Reading lexicon lines from stdin...", file=sys.stderr)

    lexicon = {}
    for line in lines:
        line = line.strip()
        if (not line) or (" " not in line):
            continue

        word, actual_phonemes = line.split(maxsplit=1)
        lexicon[word] = actual_phonemes

    # Predict phonemes for every word in the lexicon
    predicted_phonemes = {}

    start_time = time.perf_counter()

    for word in lexicon:
        phonemes = tagger(word)
        predicted_phonemes[word] = " ".join(phonemes)

    end_time = time.perf_counter()

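    # A worked example with illustrative numbers: num_errors = 25 and
    # num_phonemes = 500 give PER = round(25 / 500, 2) = 0.05. Note that
    # levenshtein() runs over the phoneme *strings*, so it counts character
    # edits (including spaces) rather than whole-phoneme substitutions.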
    # Calculate phoneme error rate (PER)
    num_errors = 0
    num_missing = 0
    num_phonemes = 0

    for word, actual_phonemes in lexicon.items():
        # NOTE: "expected_phonemes" holds the model's prediction here;
        # "actual_phonemes" is the lexicon reference it is compared against.
        expected_phonemes = predicted_phonemes.get(word, "")

        if expected_phonemes:
            distance = levenshtein(expected_phonemes, actual_phonemes)
            num_errors += distance
            num_phonemes += len(actual_phonemes)
        else:
            num_missing += 1
            _LOGGER.warning("No pronunciation for %s", word)

    assert num_phonemes > 0, "No phonemes were read"

    # Report results
    per = round(num_errors / num_phonemes, 2)
    wps = round(len(predicted_phonemes) / (end_time - start_time), 2)
    print("PER:", per, "Errors:", num_errors, "words/sec:", wps)

    if num_missing > 0:
        print("Total missing:", num_missing)


# -----------------------------------------------------------------------------


def main():
    """Main entry point"""
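    # Example invocations (a sketch; model and corpus paths are placeholders):
    #
    #     python3 -m gruut.g2p train --corpus g2p.corpus --output model.crf
    #     python3 -m gruut.g2p predict --model model.crf test
    #     echo 'test t ˈɛ s t' | python3 -m gruut.g2p test --model model.crf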
    parser = argparse.ArgumentParser(prog="g2p.py")

    # Create subparsers for each sub-command
    sub_parsers = parser.add_subparsers()
    sub_parsers.required = True
    sub_parsers.dest = "command"

    # -----
    # Train
    # -----
    train_parser = sub_parsers.add_parser(
        "train", help="Train a new G2P model from a pre-aligned Phonetisaurus corpus"
    )
    train_parser.add_argument(
        "--corpus", required=True, help="Path to aligned Phonetisaurus g2p corpus"
    )
    train_parser.add_argument(
        "--output", required=True, help="Path to output tagger model"
    )
    train_parser.add_argument("--c1", type=float, default=0.0, help="L1 penalty")
    train_parser.add_argument("--c2", type=float, default=1.0, help="L2 penalty")
    train_parser.add_argument(
        "--max-iterations",
        type=int,
        default=100,
        help="Maximum number of training iterations (default: 100)",
    )
    train_parser.add_argument(
        "--group-separator",
        default="}",
        help="Separator between graphemes and phonemes",
    )
    train_parser.add_argument(
        "--item-separator", default="|", help="Separator between items in a group"
    )
    train_parser.add_argument(
        "--remove-phonemes", nargs="*", help="Remove phonemes from examples"
    )
    train_parser.set_defaults(func=do_train)

    # -------
    # Predict
    # -------
    predict_parser = sub_parsers.add_parser(
        "predict", help="Predict phonemes for word(s)"
    )
    predict_parser.add_argument(
        "--model", required=True, help="Path to G2P tagger model"
    )
    predict_parser.add_argument("texts", nargs="*", help="Words")
    predict_parser.set_defaults(func=do_predict)

    # ----
    # Test
    # ----
    test_parser = sub_parsers.add_parser("test", help="Test G2P model on a lexicon")
    test_parser.add_argument("--model", required=True, help="Path to G2P tagger model")
    test_parser.add_argument(
        "texts", nargs="*", help="Lines with '<word> <phoneme> <phoneme> ...'"
    )
    test_parser.set_defaults(func=do_test)

    # ----------------
    # Shared arguments
    # ----------------
    for sub_parser in [train_parser, predict_parser, test_parser]:
        sub_parser.add_argument(
            "--debug", action="store_true", help="Print DEBUG messages to console"
        )

    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    _LOGGER.debug(args)

    args.func(args)


# -----------------------------------------------------------------------------

if __name__ == "__main__":
    main()