58 lines
1.6 KiB
Python
58 lines
1.6 KiB
Python
|
from typing import Callable, Tuple, TypeVar, Union, cast
|
||
|
|
||
|
from ..config import registry
|
||
|
from ..model import Model
|
||
|
from ..types import Floats2d, Ragged
|
||
|
|
||
|
InT = TypeVar("InT", Floats2d, Ragged)
|
||
|
|
||
|
|
||
|
@registry.layers("expand_window.v1")
|
||
|
def expand_window(window_size: int = 1) -> Model[InT, InT]:
|
||
|
"""For each vector in an input, construct an output vector that contains the
|
||
|
input and a window of surrounding vectors. This is one step in a convolution.
|
||
|
"""
|
||
|
return Model("expand_window", forward, attrs={"window_size": window_size})
|
||
|
|
||
|
|
||
|
def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
|
||
|
if isinstance(X, Ragged):
|
||
|
return _expand_window_ragged(model, X)
|
||
|
else:
|
||
|
return _expand_window_floats(model, X)
|
||
|
|
||
|
|
||
|
def _expand_window_floats(
|
||
|
model: Model[InT, InT], X: Floats2d
|
||
|
) -> Tuple[Floats2d, Callable]:
|
||
|
nW = model.attrs["window_size"]
|
||
|
if len(X) > 0:
|
||
|
Y = model.ops.seq2col(X, nW)
|
||
|
else:
|
||
|
assert len(X) == 0
|
||
|
Y = model.ops.tile(X, (nW * 2) + 1)
|
||
|
|
||
|
def backprop(dY: Floats2d) -> Floats2d:
|
||
|
return model.ops.backprop_seq2col(dY, nW)
|
||
|
|
||
|
return Y, backprop
|
||
|
|
||
|
|
||
|
def _expand_window_ragged(
|
||
|
model: Model[InT, InT], Xr: Ragged
|
||
|
) -> Tuple[Ragged, Callable]:
|
||
|
nW = model.attrs["window_size"]
|
||
|
Y = Ragged(
|
||
|
model.ops.seq2col(cast(Floats2d, Xr.data), nW, lengths=Xr.lengths), Xr.lengths
|
||
|
)
|
||
|
|
||
|
def backprop(dYr: Ragged) -> Ragged:
|
||
|
return Ragged(
|
||
|
model.ops.backprop_seq2col(
|
||
|
cast(Floats2d, dYr.data), nW, lengths=Xr.lengths
|
||
|
),
|
||
|
Xr.lengths,
|
||
|
)
|
||
|
|
||
|
return Y, backprop
|