#!/usr/bin/env python3 """Guess word pronunciations using a Phonetisaurus FST See bin/fst2npz.py to convert an FST to a numpy graph. """ import argparse import logging import os import sys import time import typing from collections import defaultdict from pathlib import Path import numpy as np _LOGGER = logging.getLogger("g2p_phonetisaurus") NUMPY_GRAPH = typing.Dict[str, np.ndarray] # ----------------------------------------------------------------------------- def main(): """Main entry point""" parser = argparse.ArgumentParser(prog="g2p_phonetisaurus") # Create subparsers for each sub-command sub_parsers = parser.add_subparsers() sub_parsers.required = True sub_parsers.dest = "command" # ------- # Predict # ------- predict_parser = sub_parsers.add_parser( "predict", help="Predict phonemes for word(s)" ) predict_parser.add_argument( "--graph", required=True, help="Path to graph npz file from fst2npy.py" ) predict_parser.add_argument( "words", nargs="*", help="Words to guess pronunciations for" ) predict_parser.add_argument( "--max-guesses", default=1, type=int, help="Maximum number of guesses per word (default: 1)", ) predict_parser.add_argument( "--beam", default=500, type=int, help="Initial width of search beam (default: 500)", ) predict_parser.add_argument( "--min-beam", default=100, type=int, help="Minimum width of search beam (default: 100)", ) predict_parser.add_argument( "--beam-scale", default=0.6, type=float, help="Scalar multiplied by beam after each step (default: 0.6)", ) predict_parser.add_argument( "--grapheme-separator", default="", help="Separator between input graphemes (default: none)", ) predict_parser.add_argument( "--phoneme-separator", default=" ", help="Separator between output phonemes (default: space)", ) predict_parser.add_argument( "--preload-graph", action="store_true", help="Preload graph into memory before starting", ) predict_parser.set_defaults(func=do_predict) # ---- # Test # ---- test_parser = sub_parsers.add_parser("test", help="Test G2P model on a lexicon") test_parser.add_argument( "--graph", required=True, help="Path to graph npz file from fst2npy.py" ) test_parser.add_argument( "texts", nargs="*", help="Lines with ' ...'" ) test_parser.add_argument( "--beam", default=500, type=int, help="Initial width of search beam (default: 500)", ) test_parser.add_argument( "--min-beam", default=100, type=int, help="Minimum width of search beam (default: 100)", ) test_parser.add_argument( "--beam-scale", default=0.6, type=float, help="Scalar multiplied by beam after each step (default: 0.6)", ) test_parser.add_argument( "--preload-graph", action="store_true", help="Preload graph into memory before starting", ) test_parser.set_defaults(func=do_test) # ---------------- # Shared arguments # ---------------- for sub_parser in [predict_parser, test_parser]: sub_parser.add_argument( "--debug", action="store_true", help="Print DEBUG messages to console" ) args = parser.parse_args() if args.debug: logging.basicConfig(level=logging.DEBUG) else: logging.basicConfig(level=logging.INFO) _LOGGER.debug(args) args.func(args) # ----------------------------------------------------------------------------- def do_predict(args): """Predict phonemes for words""" args.graph = Path(args.graph) _LOGGER.debug("Loading graph from %s", args.graph) phon_graph = PhonetisaurusGraph.load(args.graph, preload=args.preload_graph) if args.words: # Arguments words = args.words _LOGGER.info("Guessing pronunciations for %s word(s)", len(words)) else: # Standard input words = sys.stdin if os.isatty(sys.stdin.fileno()): print("Reading words from stdin...", file=sys.stderr) # Guess pronunciations for word, graphemes, phonemes in phon_graph.g2p( words, grapheme_separator=args.grapheme_separator, max_guesses=args.max_guesses, beam=args.beam, min_beam=args.min_beam, beam_scale=args.beam_scale, ): if not phonemes: _LOGGER.warning("No pronunciation for %s (%s)", word, graphemes) continue print(word, args.phoneme_separator.join(phonemes)) # ----------------------------------------------------------------------------- def do_test(args): """Test performance relative a known lexicon""" try: from rapidfuzz.string_metric import levenshtein except ImportError as e: _LOGGER.critical("rapidfuzz library is needed for levenshtein distance") _LOGGER.critical("pip install 'rapidfuzz>=1.4.1'") raise e args.graph = Path(args.graph) _LOGGER.debug("Loading graph from %s", args.graph) phon_graph = PhonetisaurusGraph.load(args.graph, preload=args.preload_graph) if args.texts: lines = args.texts else: lines = sys.stdin if os.isatty(sys.stdin.fileno()): print("Reading lexicon lines from stdin...", file=sys.stderr) # Load lexicon lexicon = {} for line in lines: line = line.strip() if (not line) or (" " not in line): continue word, actual_phonemes = line.split(maxsplit=1) lexicon[word] = actual_phonemes # Predict phonemes predicted_phonemes = {} start_time = time.perf_counter() for word in lexicon: for _, _, guessed_phonemes in phon_graph.g2p( [word], beam=args.beam, min_beam=args.min_beam, beam_scale=args.beam_scale, max_guesses=1, ): predicted_phonemes[word] = " ".join(guessed_phonemes) # Only one guess break end_time = time.perf_counter() # Calculate PER num_errors = 0 num_missing = 0 num_phonemes = 0 for word, actual_phonemes in lexicon.items(): expected_phonemes = predicted_phonemes.get(word, "") if expected_phonemes: distance = levenshtein(expected_phonemes, actual_phonemes) num_errors += distance num_phonemes += len(actual_phonemes) else: num_missing += 1 _LOGGER.warning("No pronunciation for %s", word) assert num_phonemes > 0, "No phonemes were read" # Calculate results per = round(num_errors / num_phonemes, 2) wps = round(len(predicted_phonemes) / (end_time - start_time), 2) print("PER:", per, "Errors:", num_errors, "words/sec:", wps) if num_missing > 0: print("Total missing:", num_missing) # ----------------------------------------------------------------------------- _NOT_FINAL = object() class PhonetisaurusGraph: """Graph of numpy arrays that represents a Phonetisaurus FST Also contains shared cache of edges and final state probabilities. These caches are necessary to ensure that the .npz file stays small and fast to load. """ def __init__(self, graph: NUMPY_GRAPH, preload: bool = False): self.graph = graph self.start_node = int(self.graph["start_node"].item()) # edge_index -> (from_node, to_node, ilabel, olabel) self.edges = self.graph["edges"] self.edge_probs = self.graph["edge_probs"] # int -> [str] self.symbols = [] for symbol_str in self.graph["symbols"]: symbol_list = symbol_str.replace("_", "").split("|") self.symbols.append((len(symbol_list), symbol_list)) # nodes that are accepting states self.final_nodes = self.graph["final_nodes"] # node -> probability self.final_probs = self.graph["final_probs"] # Cache self.preloaded = preload self.out_edges: typing.Dict[int, typing.List[int]] = defaultdict(list) self.final_node_probs: typing.Dict[int, typing.Any] = {} if preload: # Load out edges for edge_idx, (from_node, *_) in enumerate(self.edges): self.out_edges[from_node].append(edge_idx) # Load final probabilities self.final_node_probs.update(zip(self.final_nodes, self.final_probs)) @staticmethod def load(graph_path: typing.Union[str, Path], **kwargs) -> "PhonetisaurusGraph": """Load .npz file with numpy graph""" np_graph = np.load(graph_path, allow_pickle=True) return PhonetisaurusGraph(np_graph, **kwargs) def g2p( self, words: typing.Iterable[typing.Union[str, typing.Sequence[str]]], **kwargs ) -> typing.Iterable[ typing.Tuple[ typing.Union[str, typing.Sequence[str]], typing.Sequence[str], typing.Sequence[str], ], ]: """Guess phonemes for words""" for word in words: for graphemes, phonemes in self.g2p_one(word, **kwargs): yield word, graphemes, phonemes def g2p_one( self, word: typing.Union[str, typing.Sequence[str]], eps: str = "", beam: int = 5000, min_beam: int = 100, beam_scale: float = 0.6, grapheme_separator: str = "", max_guesses: int = 1, ) -> typing.Iterable[typing.Tuple[typing.Sequence[str], typing.Sequence[str]]]: """Guess phonemes for word""" current_beam = beam graphemes: typing.Sequence[str] = [] if isinstance(word, str): word = word.strip() if grapheme_separator: graphemes = word.split(grapheme_separator) else: graphemes = list(word) else: graphemes = word if not graphemes: return graphemes, [] # (prob, node, graphemes, phonemes, final, beam) q: typing.List[ typing.Tuple[ float, typing.Optional[int], typing.Sequence[str], typing.List[str], bool, ] ] = [(0.0, self.start_node, graphemes, [], False)] q_next: typing.List[ typing.Tuple[ float, typing.Optional[int], typing.Sequence[str], typing.List[str], bool, ] ] = [] # (prob, phonemes) best_heap: typing.List[typing.Tuple[float, typing.Sequence[str]]] = [] # Avoid duplicate guesses guessed_phonemes: typing.Set[typing.Tuple[str, ...]] = set() while q: done_with_word = False q_next = [] for prob, node, next_graphemes, output, is_final in q: if is_final: # Complete guess phonemes = tuple(output) if phonemes not in guessed_phonemes: best_heap.append((prob, phonemes)) guessed_phonemes.add(phonemes) if len(best_heap) >= max_guesses: done_with_word = True break continue assert node is not None if not next_graphemes: if self.preloaded: final_prob = self.final_node_probs.get(node, _NOT_FINAL) else: final_prob = self.final_node_probs.get(node) if final_prob is None: final_idx = int(np.searchsorted(self.final_nodes, node)) if self.final_nodes[final_idx] == node: # Cache final_prob = float(self.final_probs[final_idx]) self.final_node_probs[node] = final_prob else: # Not a final state final_prob = _NOT_FINAL self.final_node_probs[node] = final_prob if final_prob != _NOT_FINAL: final_prob = typing.cast(float, final_prob) q_next.append((prob + final_prob, None, [], output, True)) len_next_graphemes = len(next_graphemes) if self.preloaded: # Was pre-loaded in __init__ edge_idxs = self.out_edges[node] else: # Build cache during search maybe_edge_idxs = self.out_edges.get(node) if maybe_edge_idxs is None: edge_idx = int(np.searchsorted(self.edges[:, 0], node)) edge_idxs = [] while self.edges[edge_idx][0] == node: edge_idxs.append(edge_idx) edge_idx += 1 # Cache self.out_edges[node] = edge_idxs else: edge_idxs = maybe_edge_idxs for edge_idx in edge_idxs: _, to_node, ilabel_idx, olabel_idx = self.edges[edge_idx] out_prob = self.edge_probs[edge_idx] len_igraphemes, igraphemes = self.symbols[ilabel_idx] if len_igraphemes > len_next_graphemes: continue if igraphemes == [eps]: item = (prob + out_prob, to_node, next_graphemes, output, False) q_next.append(item) else: sub_graphemes = next_graphemes[:len_igraphemes] if igraphemes == sub_graphemes: _, olabel = self.symbols[olabel_idx] item = ( prob + out_prob, to_node, next_graphemes[len(sub_graphemes) :], output + olabel, False, ) q_next.append(item) if done_with_word: break q_next = sorted(q_next, key=lambda item: item[0])[:current_beam] q = q_next current_beam = max(min_beam, (int(current_beam * beam_scale))) # Yield guesses if best_heap: for _, guess_phonemes in sorted(best_heap, key=lambda item: item[0])[ :max_guesses ]: yield graphemes, [p for p in guess_phonemes if p] else: # No guesses yield graphemes, [] # ----------------------------------------------------------------------------- if __name__ == "__main__": main()