"""Utility methods for gruut"""
import itertools
import logging
import os
import re
import ssl
import typing
import xml.etree.ElementTree as etree
from pathlib import Path
from urllib.request import urlopen
import networkx as nx
from gruut_ipa import IPA
from gruut.const import (
DATA_PROP,
LANG_ALIASES,
NODE_TYPE,
EndElement,
GraphType,
InlineLexicon,
Lexeme,
Node,
WordRole,
)
from gruut.resources import _DIR
_LOGGER = logging.getLogger("gruut.utils")
# -----------------------------------------------------------------------------
# Language utilities
# -----------------------------------------------------------------------------

LANG_SPLIT_PATTERN = re.compile(r"[-_]")


def resolve_lang(lang: str) -> str:
    """
    Try to resolve language using aliases.

    Args:
        lang: Language name or alias

    Returns:
        Resolved language name
    """
    lang = lang.lower().replace("_", "-")
    return LANG_ALIASES.get(lang, lang)
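
# Behavior sketch (alias coverage depends on LANG_ALIASES in the installed
# gruut version; case and underscores are normalized before the lookup):
#
#     >>> resolve_lang("en_US")
#     'en-us'

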
def find_lang_dir(
    lang: str,
    search_dirs: typing.Optional[typing.Iterable[typing.Union[str, Path]]] = None,
) -> typing.Optional[Path]:
    """
    Search for a language's model directory by name.

    Tries to find a directory by:

    #. Importing a module named ``gruut_lang_<short_lang>`` where short_lang is "en" for "en-us", etc.
    #. Looking for ``<lang>/lexicon.db`` in each directory in order:

       * ``search_dirs``
       * ``$XDG_CONFIG_HOME/gruut``
       * A "data" directory next to the gruut module

    Args:
        lang: Full language name (e.g., en-us)
        search_dirs: Optional iterable of directory paths to search first

    Returns:
        Path to the language model directory or None if it can't be found
    """
    base_lang = LANG_SPLIT_PATTERN.split(lang)[0].lower()
    lang_module_name = f"gruut_lang_{base_lang}"

    try:
        lang_module = __import__(lang_module_name)
        _LOGGER.debug("(%s) successfully imported %s", lang, lang_module_name)
        return lang_module.get_lang_dir()
    except ImportError:
        _LOGGER.debug("(%s) couldn't import module %s", lang, lang_module_name)

    search_dirs = typing.cast(typing.List[Path], [Path(p) for p in search_dirs or []])

    # ${XDG_CONFIG_HOME}/gruut or ${HOME}/.config/gruut
    maybe_config_home = os.environ.get("XDG_CONFIG_HOME")
    if maybe_config_home:
        search_dirs.append(Path(maybe_config_home) / "gruut")
    else:
        search_dirs.append(Path.home() / ".config" / "gruut")

    # Data directory *next to* gruut
    search_dirs.append(_DIR.parent / "data")

    _LOGGER.debug("(%s) searching %s for language file(s)", lang, search_dirs)

    for check_dir in search_dirs:
        lang_dir = check_dir / lang
        lexicon_path = lang_dir / "lexicon.db"
        if lexicon_path.is_file():
            _LOGGER.debug("(%s) found language file(s) in %s", lang, lang_dir)
            return lang_dir

    return None
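
# Usage sketch (the path is hypothetical; the result depends on which
# gruut_lang_* packages and data directories are actually installed):
#
#     lang_dir = find_lang_dir("en-us", search_dirs=["/opt/gruut-data"])
#     if lang_dir is not None:
#         print(lang_dir / "lexicon.db")
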
# -----------------------------------------------------------------------------
# Babel
# -----------------------------------------------------------------------------


def get_currency_names(locale_str: str) -> typing.Dict[str, str]:
    """
    Try to get currency names and symbols for a Babel locale.

    Returns:
        Dictionary whose keys are currency symbols (like "$") and whose values are currency names (like "USD")
    """
    currency_names = {}

    try:
        import babel
        import babel.numbers

        locale = babel.Locale(locale_str)
        currency_names = {
            babel.numbers.get_currency_symbol(cn, locale=locale): cn
            for cn in locale.currency_symbols
        }
    except ImportError:
        # Expected if babel is not installed
        pass
    except Exception:
        _LOGGER.exception("Failed to get currency names for locale: %s", locale_str)

    return currency_names
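
# Usage sketch (requires the optional "babel" dependency; exact symbols
# come from babel's CLDR data and may vary by version):
#
#     names = get_currency_names("en_US")
#     # names.get("$") -> "USD"
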
# -----------------------------------------------------------------------------
# Iteration
# -----------------------------------------------------------------------------


def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)


def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)


def sliding_window(iterable, n=2):
    """Returns a sliding window of size n over an iterable"""
    iterables = itertools.tee(iterable, n)

    for win_iter, num_skipped in zip(iterables, itertools.count()):
        for _ in range(num_skipped):
            next(win_iter, None)

    return zip(*iterables)
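
# Behavior of the recipes above (standard itertools patterns):
#
#     >>> list(pairwise("abc"))
#     [('a', 'b'), ('b', 'c')]
#     >>> list(grouper("ABCDEFG", 3, "x"))
#     [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
#     >>> list(sliding_window("abcd", n=3))
#     [('a', 'b', 'c'), ('b', 'c', 'd')]
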
# -----------------------------------------------------------------------------
# XML
# -----------------------------------------------------------------------------

NO_NAMESPACE_PATTERN = re.compile(r"^{[^}]+}")


def tag_no_namespace(tag: str) -> str:
    """Remove namespace from XML tag"""
    return NO_NAMESPACE_PATTERN.sub("", tag)


def attrib_no_namespace(
    element: etree.Element, name: str, default: typing.Any = None
) -> typing.Any:
    """Search for an attribute by key without namespaces"""
    for key, value in element.attrib.items():
        key_no_ns = NO_NAMESPACE_PATTERN.sub("", key)
        if key_no_ns == name:
            return value

    return default
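
# Example with a namespaced document:
#
#     >>> elem = etree.fromstring('<s xmlns="urn:x" xml:lang="en">hi</s>')
#     >>> tag_no_namespace(elem.tag)  # elem.tag == '{urn:x}s'
#     's'
#     >>> attrib_no_namespace(elem, "lang")
#     'en'

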
def text_and_elements(element, is_last=False):
    """Yields element, text, sub-elements, end element, and tail"""
    element_metadata = None

    if is_last:
        # True if this is the last child element of a parent.
        # Used to preserve whitespace.
        element_metadata = {"is_last": True}

    yield element, element_metadata

    # Text before any tags (or end tag)
    text = element.text if element.text is not None else ""
    if text.strip():
        yield text

    children = list(element)
    last_child_idx = len(children) - 1

    for child_idx, child in enumerate(children):
        # Sub-elements
        is_last = child_idx == last_child_idx
        yield from text_and_elements(child, is_last=is_last)

    # End of current element
    yield EndElement(element)

    # Text after the current tag
    tail = element.tail if element.tail is not None else ""
    if tail.strip():
        yield tail
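
# Event-stream sketch for "<p>hi <b>there</b> world</p>" (whitespace-only
# text/tail is skipped; EndElement marks close tags):
#
#     (p, None), "hi ", (b, {"is_last": True}), "there",
#     EndElement(b), " world", EndElement(p)

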
def load_lexicon(
    uri: str,
    lexicon: InlineLexicon,
    ssl_context: typing.Optional[ssl.SSLContext] = None,
):
    """Loads a pronunciation lexicon from a URI"""
    if ssl_context is None:
        ssl_context = ssl.create_default_context()

    with urlopen(uri, context=ssl_context) as response:
        tree = etree.parse(response)
        for lexeme_elem in tree.getroot():
            if tag_no_namespace(lexeme_elem.tag) != "lexeme":
                continue

            lexeme = Lexeme()

            role_str = attrib_no_namespace(lexeme_elem, "role")
            if role_str:
                lexeme.roles = set(role_str.strip().split())

            for lexeme_child in lexeme_elem:
                child_tag = tag_no_namespace(lexeme_child.tag)
                if child_tag == "grapheme":
                    if lexeme_child.text:
                        lexeme.grapheme = lexeme_child.text.strip()
                elif child_tag == "phoneme":
                    if lexeme_child.text:
                        lexeme.phonemes = maybe_split_ipa(lexeme_child.text.strip())

            if lexeme.grapheme and lexeme.phonemes:
                role_phonemes = lexicon.words.get(lexeme.grapheme)
                if role_phonemes is None:
                    role_phonemes = {}
                    lexicon.words[lexeme.grapheme] = role_phonemes

                assert role_phonemes is not None
                roles = lexeme.roles or [WordRole.DEFAULT]
                for role in roles:
                    role_phonemes[role] = lexeme.phonemes
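
# The document shape this expects (a PLS-style lexicon; tag names are
# matched with namespaces stripped, and the optional role attribute holds
# whitespace-separated role tokens):
#
#     <lexicon>
#       <lexeme role="...">
#         <grapheme>tomato</grapheme>
#         <phoneme>t ə m ɑ t oʊ</phoneme>
#       </lexeme>
#     </lexicon>
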
# -----------------------------------------------------------------------------
# Text
# -----------------------------------------------------------------------------

NON_WORDS_PATTERN = re.compile(r"\W")


def remove_non_word_chars(s: str) -> str:
    """Removes non-word characters from a string"""
    return NON_WORDS_PATTERN.sub("", s)


def maybe_split_ipa(s: str) -> typing.List[str]:
    """Split on whitespace if a space is present, otherwise return string as list of graphemes"""
    if " " in s:
        # Manual separation
        return s.split()

    # Automatic separation
    return IPA.graphemes(s)
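
# Behavior sketch (the no-space case defers to gruut_ipa's grapheme
# splitter, so exact clusters depend on that library):
#
#     >>> remove_non_word_chars("don't!")
#     'dont'
#     >>> maybe_split_ipa("t ə m ɑ t oʊ")
#     ['t', 'ə', 'm', 'ɑ', 't', 'oʊ']
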
# -----------------------------------------------------------------------------
# Graph
# -----------------------------------------------------------------------------


def print_graph(
    graph: GraphType,
    node: typing.Union[NODE_TYPE, Node],
    indent: str = "--",
    level: int = 1,
    print_func=print,
):
    """Prints a graph to the console"""
    if isinstance(node, Node):
        n_data = node
        graph_node = node.node
    else:
        graph_node = node
        n_data = typing.cast(Node, graph.nodes[graph_node][DATA_PROP])

    print_func(indent * level, graph_node, n_data)

    for succ_node in graph.successors(graph_node):
        print_graph(
            graph, succ_node, indent=indent, level=level + 1, print_func=print_func
        )


def leaves(graph: GraphType, node: Node):
    """Iterate through the leaves of a graph in depth-first order"""
    for dfs_node in nx.dfs_preorder_nodes(graph, node.node):
        if graph.out_degree(dfs_node) != 0:
            continue

        yield graph.nodes[dfs_node][DATA_PROP]
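
# Note: a "leaf" is any reachable node with out-degree 0; what is yielded is
# the Node object stored under DATA_PROP on the graph node (the "data"
# attribute, matching graph.add_node(..., data=...) in pipeline_split below),
# in depth-first preorder from node.node.

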
def pipeline_split(split_func, graph: GraphType, parent_node: Node) -> bool:
    """Splits leaf nodes of tree into zero or more sub-nodes"""
    was_changed = False

    for leaf_node in list(leaves(graph, parent_node)):
        for node_class, node_kwargs in split_func(graph, leaf_node):
            new_node = node_class(node=len(graph), **node_kwargs)
            graph.add_node(new_node.node, data=new_node)
            graph.add_edge(leaf_node.node, new_node.node)
            was_changed = True

    return was_changed


def pipeline_transform(transform_func, graph: GraphType, parent_node: Node) -> bool:
    """Transforms leaves of tree with a custom function"""
    was_changed = False

    for leaf_node in list(leaves(graph, parent_node)):
        if transform_func(graph, leaf_node):
            was_changed = True

    return was_changed
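
# Usage sketch for pipeline_split (split_on_dash is hypothetical, and
# WordNode stands in for a Node subclass from gruut.const that accepts a
# "text" field; the split children become successors of the original leaf):
#
#     def split_on_dash(graph, node):
#         if "-" in node.text:
#             for part in node.text.split("-"):
#                 yield WordNode, {"text": part}
#
#     changed = pipeline_split(split_on_dash, graph, sentence_node)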