ai-content-maker/.venv/Lib/site-packages/thinc/layers/with_array.py

111 lines
3.5 KiB
Python

from typing import Callable, Optional, Tuple, TypeVar, Union, cast
from ..backends import NumpyOps
from ..config import registry
from ..model import Model
from ..types import Array3d, ArrayXd, ListXd, Padded, Ragged
# Module-level NumpyOps instance, created once at import time. _list_forward
# uses it to build the integer array of per-sequence lengths.
NUMPY_OPS = NumpyOps()
# Type variable for the array type the wrapped layer consumes and produces.
ArrayTXd = TypeVar("ArrayTXd", bound=ArrayXd)
# Type variable for the containers the wrapper accepts and returns: Padded,
# Ragged, a list of arrays, or a plain array.
SeqT = TypeVar("SeqT", bound=Union[Padded, Ragged, ListXd, ArrayXd])
@registry.layers("with_array.v1")
def with_array(layer: Model[ArrayTXd, ArrayTXd], pad: int = 0) -> Model[SeqT, SeqT]:
    """Wrap an array-to-array layer so it can be applied to sequence inputs.

    The returned model converts list, Padded and Ragged inputs into a single
    contiguous array before calling ``layer`` and restores the original
    container on the way out. A plain array input is handed to ``layer``
    unchanged.

    layer: The wrapped model, which operates on arrays.
    pad: Stored under the "pad" attribute and forwarded to the ops'
        flatten/unflatten calls when the input is a list.
    RETURNS: A model named "with_array(<layer name>)" whose dims are copied
        from ``layer`` (unset dims are copied as None).
    """
    wrapper_name = f"with_array({layer.name})"
    # Snapshot whatever dimensions the wrapped layer already knows about.
    inherited_dims = {}
    for dim_name in layer.dim_names:
        inherited_dims[dim_name] = layer.maybe_get_dim(dim_name)
    wrapper: Model[SeqT, SeqT] = Model(
        wrapper_name,
        forward,
        init=init,
        layers=[layer],
        attrs={"pad": pad},
        dims=inherited_dims,
    )
    return wrapper
def forward(
    model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool
) -> Tuple[SeqT, Callable]:
    """Run the wrapped layer on ``Xseq``, dispatching on its container type.

    Ragged, Padded and list/tuple inputs are routed to their dedicated
    helpers; any other input is passed straight to the wrapped layer.

    model: The with_array wrapper; its first child is the wrapped layer.
    Xseq: The input batch.
    is_train: Forwarded unchanged to the wrapped layer.
    RETURNS: The output (in the same container kind as the input) and the
        backprop callback.
    """
    if isinstance(Xseq, Ragged):
        ragged_result = _ragged_forward(model, Xseq, is_train)
        return cast(Tuple[SeqT, Callable], ragged_result)
    if isinstance(Xseq, Padded):
        padded_result = _padded_forward(model, Xseq, is_train)
        return cast(Tuple[SeqT, Callable], padded_result)
    if isinstance(Xseq, (list, tuple)):
        list_result = _list_forward(model, Xseq, is_train)
        return cast(Tuple[SeqT, Callable], list_result)
    # Not a recognised sequence container: call the wrapped layer directly.
    return model.layers[0](Xseq, is_train)
def init(
    model: Model[SeqT, SeqT], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
) -> None:
    """Initialize the wrapped layer and copy its known dims onto the wrapper.

    model: The with_array wrapper; its first child is the wrapped layer.
    X: Optional sample input. When given it is converted to an array with
        _get_array before being passed to the wrapped layer.
    Y: Optional sample output, converted the same way as ``X``.
    """
    wrapped: Model[ArrayXd, ArrayXd] = model.layers[0]
    X_sample = None if X is None else _get_array(model, X)
    Y_sample = None if Y is None else _get_array(model, Y)
    wrapped.initialize(X=X_sample, Y=Y_sample)
    # Mirror every dimension the wrapped layer has a value for; dims that are
    # still unset on the wrapped layer are left alone on the wrapper.
    for dim_name in wrapped.dim_names:
        dim_value = wrapped.maybe_get_dim(dim_name)
        if dim_value is None:
            continue
        model.set_dim(dim_name, dim_value)
def _get_array(model, X: SeqT) -> ArrayXd:
    """Return the array representation of ``X`` for initialization.

    Ragged inputs yield their ``dataXd``, Padded inputs their ``data``, and
    list/tuple inputs are flattened with ``model.ops.flatten``. Any other
    input is returned as-is.
    """
    if isinstance(X, Ragged):
        return X.dataXd
    if isinstance(X, Padded):
        return X.data
    if isinstance(X, (list, tuple)):
        return model.ops.flatten(X)
    # Neither a known container nor a list/tuple: hand it back unchanged.
    return cast(ArrayXd, X)
def _list_forward(
    model: Model[SeqT, SeqT], Xs: ListXd, is_train: bool
) -> Tuple[ListXd, Callable]:
    """Apply the wrapped layer to a list of arrays.

    The list is flattened with the wrapped layer's ops (using the wrapper's
    "pad" attribute), run through the layer, and the result is unflattened
    using the original per-sequence lengths. The backprop callback applies
    the same flatten/unflatten steps to the gradients.
    """
    wrapped: Model[ArrayXd, ArrayXd] = model.layers[0]
    padding = model.attrs["pad"]
    # Lengths are recorded up front so outputs and gradients can be split
    # back into the same sequence boundaries.
    seq_lengths = NUMPY_OPS.asarray1i(list(map(len, Xs)))
    flat_X = wrapped.ops.flatten(Xs, pad=padding)
    flat_Y, backprop_flat = wrapped(flat_X, is_train)

    def backprop(dYs: ListXd) -> ListXd:
        flat_dY = wrapped.ops.flatten(dYs, pad=padding)
        flat_dX = backprop_flat(flat_dY)
        dXs = wrapped.ops.unflatten(flat_dX, seq_lengths, pad=padding)
        return dXs

    Ys = wrapped.ops.unflatten(flat_Y, seq_lengths, pad=padding)
    return Ys, backprop
def _ragged_forward(
    model: Model[SeqT, SeqT], Xr: Ragged, is_train: bool
) -> Tuple[Ragged, Callable]:
    """Apply the wrapped layer to the data array of a Ragged input.

    The output is rewrapped as a Ragged with the input's lengths. The
    backprop callback does the same for gradients, reusing the lengths of
    the incoming gradient.
    """
    wrapped: Model[ArrayXd, ArrayXd] = model.layers[0]
    Y_data, backprop_data = wrapped(Xr.dataXd, is_train)

    def backprop(dYr: Ragged) -> Ragged:
        dX_data = backprop_data(dYr.dataXd)
        return Ragged(dX_data, dYr.lengths)

    Yr = Ragged(Y_data, Xr.lengths)
    return Yr, backprop
def _padded_forward(
    model: Model[SeqT, SeqT], Xp: Padded, is_train: bool
) -> Tuple[Padded, Callable]:
    """Apply the wrapped layer to the data array of a Padded input.

    The output is rewrapped as a Padded carrying the input's ``size_at_t``,
    ``lengths`` and ``indices``. The backprop callback rewraps the gradient
    using the metadata of the incoming Padded gradient.
    """
    wrapped: Model[Array3d, Array3d] = model.layers[0]
    Y_data, backprop_data = wrapped(Xp.data, is_train)

    def backprop(dYp: Padded) -> Padded:
        assert isinstance(dYp, Padded)
        dX_data = backprop_data(dYp.data)
        dXp = Padded(dX_data, dYp.size_at_t, dYp.lengths, dYp.indices)
        return dXp

    Yp = Padded(Y_data, Xp.size_at_t, Xp.lengths, Xp.indices)
    return Yp, backprop