57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
from typing import Callable, List, Optional, Tuple, TypeVar, cast
|
|
|
|
from ..config import registry
|
|
from ..model import Model
|
|
from ..types import Array2d, Array3d
|
|
|
|
# Type variable bounded by Array3d; annotates the input and output of the
# wrapper model built by with_reshape.
InT = TypeVar("InT", bound=Array3d)
|
|
# Type variable bounded by Array2d; annotates the input and output of the
# layer that with_reshape wraps.
OutT = TypeVar("OutT", bound=Array2d)
|
|
|
|
|
|
@registry.layers("with_reshape.v1")
def with_reshape(layer: Model[OutT, OutT]) -> Model[InT, InT]:
    """Wrap a layer so its data is reshaped on the way in and on the way out.

    layer: The layer to wrap. It becomes the only child of the returned model.
    RETURNS: A model named "with_reshape(<layer name>)" that uses this
        module's forward and init callbacks.
    """
    wrapper_name = f"with_reshape({layer.name})"
    # Both dims are declared but left unset here; init fills them in from the
    # wrapped layer when that layer defines them.
    unset_dims = {"nO": None, "nI": None}
    wrapper: Model[InT, InT] = Model(
        wrapper_name,
        forward,
        init=init,
        layers=[layer],
        dims=unset_dims,
    )
    return wrapper
|
|
|
|
|
|
def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
    """Apply the wrapped layer to a 3d array by flattening it to 2d.

    The first two axes of X are merged, the wrapped layer is run on the
    resulting 2d array, and its output is reshaped back to 3d. The last axis
    of the output is sized by the wrapped layer's "nO" dimension.

    model: The wrapper model; model.layers[0] is the wrapped layer.
    X: The 3d input array.
    is_train: Passed through unchanged to the wrapped layer.
    RETURNS: The 3d output and a backprop callback. The callback takes a
        gradient shaped like the output and returns a gradient shaped like X.
    """
    layer = model.layers[0]
    initial_shape = X.shape
    final_shape = list(initial_shape[:-1]) + [layer.get_dim("nO")]
    nB = initial_shape[0]
    nT = initial_shape[1]
    X2d = model.ops.reshape(X, (-1, initial_shape[2]))
    Y2d, Y2d_backprop = layer(X2d, is_train=is_train)
    Y = model.ops.reshape3(Y2d, *final_shape)

    def backprop(dY: InT) -> InT:
        # The wrapped layer saw 2d data, so its backprop must receive a 2d
        # gradient. Only the gradient it returns (with respect to X2d) is
        # restored to the original 3d input shape.
        dY2d = model.ops.reshape2(dY, nB * nT, -1)
        dX2d = Y2d_backprop(dY2d)
        return model.ops.reshape3(dX2d, *initial_shape)

    return cast(InT, Y), backprop
|
|
|
|
|
|
def init(
    model: Model[InT, InT], X: Optional[Array3d] = None, Y: Optional[Array3d] = None
) -> None:
    """Initialize the wrapped layer and copy its dimensions to the wrapper.

    Sample arrays, when given, are flattened to 2d (keeping the last axis)
    before being handed to the wrapped layer, matching what forward feeds it.

    model: The wrapper model; model.layers[0] is the wrapped layer.
    X: Optional 3d sample input.
    Y: Optional 3d sample output.
    """
    layer = model.layers[0]
    X2d: Optional[Array2d] = None
    Y2d: Optional[Array2d] = None
    if X is not None:
        X2d = cast(Array2d, model.ops.reshape(X, (-1, X.shape[-1])))
    if Y is not None:
        Y2d = cast(Array2d, model.ops.reshape(Y, (-1, Y.shape[-1])))
    # Initialize exactly once. When no samples are given both arguments are
    # None, so a separate no-argument call beforehand would only repeat the
    # same initialization.
    layer.initialize(X=X2d, Y=Y2d)
    # Mirror whichever dimensions the wrapped layer defines onto the wrapper.
    for dim_name in ("nI", "nO"):
        if layer.has_dim(dim_name):
            model.set_dim(dim_name, layer.get_dim(dim_name))
|