ai-content-maker/.venv/Lib/site-packages/thinc/layers/expand_window.py

58 lines
1.6 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from typing import Callable, Tuple, TypeVar, Union, cast
from ..config import registry
from ..model import Model
from ..types import Floats2d, Ragged
InT = TypeVar("InT", Floats2d, Ragged)
@registry.layers("expand_window.v1")
def expand_window(window_size: int = 1) -> Model[InT, InT]:
"""For each vector in an input, construct an output vector that contains the
input and a window of surrounding vectors. This is one step in a convolution.
"""
return Model("expand_window", forward, attrs={"window_size": window_size})
def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
if isinstance(X, Ragged):
return _expand_window_ragged(model, X)
else:
return _expand_window_floats(model, X)
def _expand_window_floats(
model: Model[InT, InT], X: Floats2d
) -> Tuple[Floats2d, Callable]:
nW = model.attrs["window_size"]
if len(X) > 0:
Y = model.ops.seq2col(X, nW)
else:
assert len(X) == 0
Y = model.ops.tile(X, (nW * 2) + 1)
def backprop(dY: Floats2d) -> Floats2d:
return model.ops.backprop_seq2col(dY, nW)
return Y, backprop
def _expand_window_ragged(
model: Model[InT, InT], Xr: Ragged
) -> Tuple[Ragged, Callable]:
nW = model.attrs["window_size"]
Y = Ragged(
model.ops.seq2col(cast(Floats2d, Xr.data), nW, lengths=Xr.lengths), Xr.lengths
)
def backprop(dYr: Ragged) -> Ragged:
return Ragged(
model.ops.backprop_seq2col(
cast(Floats2d, dYr.data), nW, lengths=Xr.lengths
),
Xr.lengths,
)
return Y, backprop