"""Utility methods for gruut"""
import itertools
import logging
import os
import re
import ssl
import typing
import xml.etree.ElementTree as etree
from pathlib import Path
from urllib.request import urlopen

import networkx as nx
from gruut_ipa import IPA

from gruut.const import (
    DATA_PROP,
    LANG_ALIASES,
    NODE_TYPE,
    EndElement,
    GraphType,
    InlineLexicon,
    Lexeme,
    Node,
    WordRole,
)
from gruut.resources import _DIR

_LOGGER = logging.getLogger("gruut.utils")

# -----------------------------------------------------------------------------
# Language utilities
# -----------------------------------------------------------------------------

LANG_SPLIT_PATTERN = re.compile(r"[-_]")


def resolve_lang(lang: str) -> str:
    """
    Try to resolve language using aliases.

    The name is lower-cased and underscores are replaced with dashes before
    the alias lookup, so "EN_US" and "en-us" resolve identically.

    Args:
        lang: Language name or alias

    Returns:
        Resolved language name (the normalized input if no alias matches)
    """
    lang = lang.lower().replace("_", "-")

    return LANG_ALIASES.get(lang, lang)


def find_lang_dir(
    lang: str,
    search_dirs: typing.Optional[typing.Iterable[typing.Union[str, Path]]] = None,
) -> typing.Optional[Path]:
    """
    Search for a language's model directory by name.

    Tries to find a directory by:

    #. Importing a module name ``gruut_lang_<short_lang>`` where short_lang is "en" for "en-us", etc.
    #. Looking for ``<lang>/lexicon.db`` in each directory in order:

        * ``search_dirs``
        * ``$XDG_CONFIG_HOME/gruut`` (or ``~/.config/gruut`` if the variable is unset or empty)
        * A "data" directory next to the gruut module

    Args:
        lang: Full language name (e.g., en-us)
        search_dirs: Optional iterable of directory paths to search first

    Returns:
        Path to the language model directory or None if it can't be found
    """
    base_lang = LANG_SPLIT_PATTERN.split(lang)[0].lower()
    lang_module_name = f"gruut_lang_{base_lang}"

    try:
        lang_module = __import__(lang_module_name)
        _LOGGER.debug("(%s) successfully imported %s", lang, lang_module_name)

        return lang_module.get_lang_dir()
    except ImportError:
        # Language module is not installed; fall back to directory search
        _LOGGER.debug("(%s) couldn't import module %s", lang, lang_module_name)

    search_dirs = typing.cast(typing.List[Path], [Path(p) for p in search_dirs or []])

    # ${XDG_CONFIG_HOME}/gruut or ${HOME}/.config/gruut
    maybe_config_home = os.environ.get("XDG_CONFIG_HOME")
    if maybe_config_home:
        search_dirs.append(Path(maybe_config_home) / "gruut")
    else:
        search_dirs.append(Path.home() / ".config" / "gruut")

    # Data directory *next to* gruut
    search_dirs.append(_DIR.parent / "data")

    _LOGGER.debug("(%s) searching %s for language file(s)", lang, search_dirs)

    for check_dir in search_dirs:
        lang_dir = check_dir / lang
        lexicon_path = lang_dir / "lexicon.db"

        # A directory only counts if it actually contains a lexicon database
        if lexicon_path.is_file():
            _LOGGER.debug("(%s) found language file(s) in %s", lang, lang_dir)
            return lang_dir

    return None


# -----------------------------------------------------------------------------
# Babel
# -----------------------------------------------------------------------------


def get_currency_names(locale_str: str) -> typing.Dict[str, str]:
    """
    Try to get currency names and symbols for a Babel locale.

    This is best-effort: if Babel is not installed, or the lookup fails for
    any reason, an empty dictionary is returned instead of raising.

    Args:
        locale_str: Locale name passed to ``babel.Locale``

    Returns:
        Dictionary whose keys are currency symbols (like "$") and whose values are currency names (like "USD")
    """
    currency_names: typing.Dict[str, str] = {}

    try:
        # Optional dependency, imported lazily
        import babel
        import babel.numbers

        locale = babel.Locale(locale_str)

        # Pass the locale explicitly so symbols come from the requested locale
        # rather than from whatever the process environment's default is.
        currency_names = {
            babel.numbers.get_currency_symbol(cn, locale=locale): cn
            for cn in locale.currency_symbols
        }
    except ImportError:
        # Expected if babel is not installed
        pass
    except Exception:
        # Deliberately broad (best effort), but keep the traceback in the log
        _LOGGER.warning("get_currency_names", exc_info=True)

    return currency_names


# -----------------------------------------------------------------------------
# Iteration
# -----------------------------------------------------------------------------


def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = itertools.tee(iterable)

    # Advance the second iterator by one so the two are offset
    next(b, None)

    return zip(a, b)


def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    # The same iterator object is repeated n times, so zip_longest pulls
    # n consecutive items for each output tuple.
    args = [iter(iterable)] * n

    return itertools.zip_longest(*args, fillvalue=fillvalue)


def sliding_window(iterable, n=2):
    """Returns a sliding window of size n over an iterable"""
    iterables = itertools.tee(iterable, n)

    # Offset the i-th iterator by i items
    for win_iter, num_skipped in zip(iterables, itertools.count()):
        for _ in range(num_skipped):
            next(win_iter, None)

    return zip(*iterables)


# -----------------------------------------------------------------------------
# XML
# -----------------------------------------------------------------------------

NO_NAMESPACE_PATTERN = re.compile(r"^{[^}]+}")


def tag_no_namespace(tag: str) -> str:
    """Remove namespace from XML tag"""
    return NO_NAMESPACE_PATTERN.sub("", tag)


def attrib_no_namespace(
    element: etree.Element, name: str, default: typing.Any = None
) -> typing.Any:
    """
    Search for an attribute by key without namespaces.

    Args:
        element: XML element whose attributes are searched
        name: Attribute name without a namespace prefix
        default: Value returned when no attribute matches

    Returns:
        Value of the first attribute whose namespace-stripped key equals name,
        otherwise default
    """
    for key, value in element.attrib.items():
        if tag_no_namespace(key) == name:
            return value

    return default


def text_and_elements(element, is_last=False):
    """
    Yields element, text, sub-elements, end element, and tail.

    Items are produced in document order:

    #. A tuple of (element, metadata), where metadata is ``{"is_last": True}``
       if is_last is set and None otherwise
    #. The element's text, unless it is empty or only whitespace
    #. Everything yielded recursively for each child element
    #. An ``EndElement`` wrapping the element
    #. The element's tail text, unless it is empty or only whitespace
    """
    element_metadata = None

    if is_last:
        # True if this is the last child element of a parent.
        # Used to preserve whitespace.
        element_metadata = {"is_last": True}

    yield element, element_metadata

    # Text before any tags (or end tag)
    text = element.text if element.text is not None else ""
    if text.strip():
        yield text

    children = list(element)
    last_child_idx = len(children) - 1

    for child_idx, child in enumerate(children):
        # Sub-elements
        is_last = child_idx == last_child_idx
        yield from text_and_elements(child, is_last=is_last)

    # End of current element
    yield EndElement(element)

    # Text after the current tag
    tail = element.tail if element.tail is not None else ""
    if tail.strip():
        yield tail


def load_lexicon(
    uri: str,
    lexicon: InlineLexicon,
    ssl_context: typing.Optional[ssl.SSLContext] = None,
):
    """
    Loads a pronunciation lexicon from a URI.

    Each ``lexeme`` child of the document root that has both a grapheme and a
    phoneme is added to ``lexicon.words`` (modified in place), keyed by
    grapheme and then by role.

    Args:
        uri: Location of the lexicon XML document
        lexicon: Lexicon whose ``words`` mapping is updated
        ssl_context: Optional SSL context (a default context is created if None)
    """
    if ssl_context is None:
        ssl_context = ssl.create_default_context()

    # NOTE(review): urlopen accepts any URL scheme it has a handler for
    # (presumably including file://). Confirm callers only pass trusted URIs.
    with urlopen(uri, context=ssl_context) as response:
        tree = etree.parse(response)

    for lexeme_elem in tree.getroot():
        if tag_no_namespace(lexeme_elem.tag) != "lexeme":
            continue

        lexeme = Lexeme()

        # Roles are given as a whitespace-separated list in a single attribute
        role_str = attrib_no_namespace(lexeme_elem, "role")
        if role_str:
            lexeme.roles = set(role_str.strip().split())

        for lexeme_child in lexeme_elem:
            child_tag = tag_no_namespace(lexeme_child.tag)
            if child_tag == "grapheme":
                if lexeme_child.text:
                    lexeme.grapheme = lexeme_child.text.strip()
            elif child_tag == "phoneme":
                if lexeme_child.text:
                    lexeme.phonemes = maybe_split_ipa(lexeme_child.text.strip())

        # Incomplete lexemes (missing grapheme or phonemes) are skipped
        if lexeme.grapheme and lexeme.phonemes:
            role_phonemes: typing.Any = lexicon.words.get(lexeme.grapheme)
            if role_phonemes is None:
                role_phonemes = {}
                lexicon.words[lexeme.grapheme] = role_phonemes

            # Without explicit roles, register under the default role
            roles = lexeme.roles or [WordRole.DEFAULT]
            for role in roles:
                role_phonemes[role] = lexeme.phonemes


# -----------------------------------------------------------------------------
# Text
# -----------------------------------------------------------------------------

NON_WORDS_PATTERN = re.compile(r"\W")


def remove_non_word_chars(s: str) -> str:
    """Removes non-word characters from a string"""
    return NON_WORDS_PATTERN.sub("", s)


def maybe_split_ipa(s: str) -> typing.List[str]:
    """
    Split on whitespace if a space is present, otherwise return string as list of graphemes.

    Only a literal space character triggers manual separation; other
    whitespace alone (e.g., a tab) does not.
    """
    if " " in s:
        # Manual separation
        return s.split()

    # Automatic separation
    return IPA.graphemes(s)


# -----------------------------------------------------------------------------
# Graph
# -----------------------------------------------------------------------------


def print_graph(
    graph: GraphType,
    node: typing.Union[NODE_TYPE, Node],
    indent: str = "--",
    level: int = 1,
    print_func=print,
):
    """
    Prints a graph to the console.

    Args:
        graph: Graph to print
        node: Node to start from, given either as a Node or as a graph node key
        indent: String repeated once per level before each printed node
        level: Depth of the current node (controls indentation)
        print_func: Callable used for output (defaults to print)
    """
    if isinstance(node, Node):
        n_data = node
        graph_node = node.node
    else:
        # Look up the Node object stored on the graph node
        graph_node = node
        n_data = typing.cast(Node, graph.nodes[graph_node][DATA_PROP])

    print_func(indent * level, graph_node, n_data)

    for succ_node in graph.successors(graph_node):
        print_graph(
            graph, succ_node, indent=indent, level=level + 1, print_func=print_func
        )


def leaves(graph: GraphType, node: Node):
    """Iterate through the leaves of a graph in depth-first order"""
    for dfs_node in nx.dfs_preorder_nodes(graph, node.node):
        # A leaf is a node without outgoing edges
        if graph.out_degree(dfs_node) == 0:
            yield graph.nodes[dfs_node][DATA_PROP]


def pipeline_split(split_func, graph: GraphType, parent_node: Node,) -> bool:
    """
    Splits leaf nodes of tree into zero or more sub-nodes.

    Args:
        split_func: Called as split_func(graph, leaf_node); yields
            (node_class, node_kwargs) pairs, one per sub-node to create
        graph: Graph that new nodes and edges are added to
        parent_node: Node whose leaves are split

    Returns:
        True if at least one sub-node was added
    """
    was_changed = False

    # Materialize the leaves first since the graph is modified in the loop
    for leaf_node in list(leaves(graph, parent_node)):
        for node_class, node_kwargs in split_func(graph, leaf_node):
            # The current node count is used as the key of the new node
            new_node = node_class(node=len(graph), **node_kwargs)
            graph.add_node(new_node.node, data=new_node)
            graph.add_edge(leaf_node.node, new_node.node)
            was_changed = True

    return was_changed


def pipeline_transform(transform_func, graph: GraphType, parent_node: Node,) -> bool:
    """
    Transforms leaves of tree with a custom function.

    Args:
        transform_func: Called as transform_func(graph, leaf_node) for every
            leaf; a truthy result marks the graph as changed
        graph: Graph passed through to transform_func
        parent_node: Node whose leaves are transformed

    Returns:
        True if transform_func returned a truthy value for at least one leaf
    """
    was_changed = False

    # Leaves are collected up front; transform_func is called on every leaf
    # (no short-circuit) even after a change has been reported.
    for leaf_node in list(leaves(graph, parent_node)):
        if transform_func(graph, leaf_node):
            was_changed = True

    return was_changed