ai-content-maker/.venv/Lib/site-packages/thinc/layers/ragged2list.py

26 lines
773 B
Python

from typing import Callable, Tuple, TypeVar, cast
from ..config import registry
from ..model import Model
from ..types import ListXd, Ragged
# Output type of the layer: a TypeVar bounded by ``ListXd``, so the concrete
# list-of-arrays type can vary between uses of the layer.
OutT = TypeVar("OutT", bound=ListXd)
# Input type of the layer: a single ``Ragged`` batch.
InT = Ragged
@registry.layers("ragged2list.v1")
def ragged2list() -> Model[InT, OutT]:
    """Build a layer that turns a ``Ragged`` batch into a list of arrays.

    The layer is registered as ``"ragged2list.v1"`` and takes no arguments;
    its behavior lives entirely in the module-level ``forward`` function.
    """
    layer_name = "ragged2list"
    return Model(layer_name, forward)
def forward(model: Model[InT, OutT], Xr: InT, is_train: bool) -> Tuple[OutT, Callable]:
    """Split a ``Ragged`` batch into a list of arrays using its lengths.

    Args:
        model: The owning layer; only its ``ops`` backend is used here.
        Xr: The ragged input. Its ``dataXd`` is passed to ``ops.unflatten``
            together with its ``lengths``.
        is_train: Not used by this layer.

    Returns:
        A pair ``(Xs, backprop)``: the list returned by ``ops.unflatten``, and
        a callback that converts list-shaped gradients back into a ``Ragged``
        carrying the same lengths as the input.
    """
    seq_lengths = Xr.lengths
    Xs = cast(OutT, model.ops.unflatten(Xr.dataXd, seq_lengths))

    def backprop(dXs: OutT) -> InT:
        # Re-join the per-item gradients with ``ops.flatten`` (pad=0) and
        # reattach the lengths captured from the forward input.
        # The type ignore is necessary for older versions of Mypy/Pydantic.
        return Ragged(model.ops.flatten(dXs, pad=0), seq_lengths)  # type:ignore[arg-type]

    return Xs, backprop