31 lines
815 B
Python
31 lines
815 B
Python
|
from typing import Callable, Tuple, cast
|
||
|
|
||
|
from ..config import registry
|
||
|
from ..model import Model
|
||
|
from ..types import Floats2d, Ragged
|
||
|
from ..util import ArrayInfo
|
||
|
|
||
|
# Input type of the layer: `forward` receives a Ragged and `backprop` returns one.
InT = Ragged
|
||
|
# Output type of the layer: `forward` returns a Floats2d and `backprop` takes its gradient.
OutT = Floats2d
|
||
|
|
||
|
|
||
|
@registry.layers("reduce_first.v1")
def reduce_first() -> Model[InT, OutT]:
    """Create the layer that reduces ragged sequences to their first element.

    The layer is registered under the name "reduce_first.v1". It takes no
    configuration: the returned model is named "reduce_first" and uses the
    module-level `forward` function as its forward pass.
    """
    layer: Model[InT, OutT] = Model("reduce_first", forward)
    return layer
|
||
|
|
||
|
|
||
|
def forward(
    model: Model[InT, OutT], Xr: InT, is_train: bool
) -> Tuple[OutT, Callable[[OutT], InT]]:
    """Run the reduce-first forward pass and build its backward callback.

    model: The layer instance; its `ops` backend performs the reduction.
    Xr: Ragged input whose `data` and `lengths` are handed to the backend.
    is_train: Part of the forward signature; not consulted by this layer.

    RETURNS: A tuple of the reduced output and a `backprop` callable that
        maps a gradient of the output back to a Ragged gradient carrying
        the same lengths as the input.
    """
    take_first = model.ops.reduce_first
    first_rows, boundaries = take_first(cast(Floats2d, Xr.data), Xr.lengths)

    # Record the output's array info so the incoming gradient can be
    # validated against it before it is propagated.
    expected_info = ArrayInfo.from_array(first_rows)

    def backprop(dY: OutT) -> InT:
        expected_info.check_consistency(dY)
        grad_data = model.ops.backprop_reduce_first(dY, boundaries)
        return Ragged(grad_data, Xr.lengths)

    return first_rows, backprop
|