28 lines
702 B
Python
28 lines
702 B
Python
from typing import Callable, Tuple, TypeVar, cast
|
|
|
|
from ..config import registry
|
|
from ..model import Model
|
|
from ..types import List2d, Padded
|
|
|
|
# Input type of the layer: a Padded batch.
InT = Padded
|
|
# Output type of the layer: a type variable bounded by List2d.
OutT = TypeVar("OutT", bound=List2d)
|
|
|
|
|
|
@registry.layers("padded2list.v1")
def padded2list() -> Model[InT, OutT]:
    """Create a layer to convert a Padded input into a list of arrays.

    The factory is registered in the layers registry as "padded2list.v1".
    It takes no arguments; the returned Model is named "padded2list" and
    uses `forward` as its forward function.
    """
    # Plain string literal: the name has no placeholders, so the f-prefix
    # the code used previously was redundant.
    return Model("padded2list", forward)
|
|
|
|
|
|
def forward(
    model: Model[InT, OutT], Xp: InT, is_train: bool
) -> Tuple[OutT, Callable[[OutT], InT]]:
    """Unpad a Padded batch and return the matching backward callback.

    model: The layer; its `ops` perform both conversions.
    Xp: The Padded input.
    is_train: Unused by this layer.

    RETURNS: A tuple of the result of `model.ops.padded2list(Xp)` and a
        callback that converts a list of gradients back to a Padded.
    """

    def backprop(dYs: OutT) -> InT:
        """Re-pad the gradient list so it matches the layer's input type."""
        d_padded = model.ops.list2padded(dYs)
        # Runtime check that also narrows the type for static checkers.
        assert isinstance(d_padded, Padded)
        return d_padded

    # The cast only informs the type checker; it does nothing at runtime.
    return cast(OutT, model.ops.padded2list(Xp)), backprop
|