# ai-content-maker/.venv/Lib/site-packages/torchaudio/pipelines/_wav2vec2/aligner.py


from abc import ABC, abstractmethod
from typing import Dict, List
import torch
import torchaudio.functional as F
from torch import Tensor
from torchaudio.functional import TokenSpan


class ITokenizer(ABC):
    @abstractmethod
    def __call__(self, transcript: List[str]) -> List[List[int]]:
        """Tokenize the given transcript (list of words).

        .. note::

           The transcript must be normalized.

        Args:
            transcript (list of str): Transcript (list of words).

        Returns:
            (list of int sequences): List of token sequences.
        """


class Tokenizer(ITokenizer):
    def __init__(self, dictionary: Dict[str, int]):
        self.dictionary = dictionary

    def __call__(self, transcript: List[str]) -> List[List[int]]:
        return [[self.dictionary[c] for c in word] for word in transcript]
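
# A minimal usage sketch (the toy dictionary below is hypothetical; a real
# one comes from e.g. Wav2Vec2FABundle.get_dict()):
#
#     tokenizer = Tokenizer({"h": 4, "i": 5})
#     tokenizer(["hi", "hi"])  # -> [[4, 5], [4, 5]]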


def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0):
    device = emission.device
    # forced_align expects batched inputs, so add a batch dimension of 1.
    emission = emission.unsqueeze(0)
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)

    aligned_tokens, scores = F.forced_align(emission, targets, blank=blank)

    scores = scores.exp()  # convert back to probability
    aligned_tokens, scores = aligned_tokens[0], scores[0]  # remove batch dimension
    return aligned_tokens, scores
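
# A sketch of how the helper is exercised (shapes and names here are
# assumptions for illustration, not part of this module):
#
#     num_frames, vocab_size = 40, 32
#     emission = torch.randn(num_frames, vocab_size).log_softmax(dim=-1)
#     aligned, probs = _align_emission_and_tokens(emission, [4, 5], blank=0)
#     # `aligned` holds one token id per frame (blanks included);
#     # `probs` holds the corresponding per-frame probabilities.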


class IAligner(ABC):
    @abstractmethod
    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
        """Generate list of time-stamped token sequences

        Args:
            emission (Tensor): Sequence of token probability distributions in log-domain.
                Shape: `(time, tokens)`.
            tokens (list of integer sequence): Tokenized transcript.
                Output from :py:class:`torchaudio.pipelines.Wav2Vec2FABundle.Tokenizer`.

        Returns:
            (list of TokenSpan sequence): Tokens with time stamps and scores.
        """


def _unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret


def _flatten(nested_list):
    return [item for list_ in nested_list for item in list_]
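
# Round-trip sketch for the two pure-Python helpers:
#
#     _flatten([[4, 5], [6]])        # -> [4, 5, 6]
#     _unflatten([4, 5, 6], [2, 1])  # -> [[4, 5], [6]]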


class Aligner(IAligner):
    def __init__(self, blank):
        self.blank = blank

    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
        if emission.ndim != 2:
            raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")

        # Align against the flattened token sequence, then merge per-frame
        # repeats into TokenSpans and regroup them word-by-word.
        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
        spans = F.merge_tokens(aligned_tokens, scores)
        return _unflatten(spans, [len(ts) for ts in tokens])
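
# End-to-end sketch of how the pieces compose (mirrors the torchaudio
# forced-alignment tutorial; `waveform` is a hypothetical 1 x num_samples
# tensor resampled to the bundle's sample rate):
#
#     bundle = torchaudio.pipelines.MMS_FA
#     model = bundle.get_model()
#     tokenizer = bundle.get_tokenizer()  # Tokenizer over bundle.get_dict()
#     aligner = bundle.get_aligner()      # Aligner with the bundle's blank id
#
#     with torch.inference_mode():
#         emission, _ = model(waveform)
#     spans = aligner(emission[0], tokenizer(["hello", "world"]))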