44 lines
1.1 KiB
Python
44 lines
1.1 KiB
Python
from typing import Any, Callable, Optional, Tuple
|
|
|
|
from ..config import registry
|
|
from ..model import Model
|
|
|
|
InT = Tuple[Any, ...]
|
|
OutT = Tuple[Any, ...]
|
|
|
|
|
|
@registry.layers("with_getitem.v1")
|
|
def with_getitem(idx: int, layer: Model) -> Model[InT, OutT]:
|
|
"""Transform data on the way into and out of a layer, by plucking an item
|
|
from a tuple.
|
|
"""
|
|
return Model(
|
|
f"with_getitem({layer.name})",
|
|
forward,
|
|
init=init,
|
|
layers=[layer],
|
|
attrs={"idx": idx},
|
|
)
|
|
|
|
|
|
def forward(
|
|
model: Model[InT, OutT], items: InT, is_train: bool
|
|
) -> Tuple[OutT, Callable]:
|
|
idx = model.attrs["idx"]
|
|
Y_i, backprop_item = model.layers[0](items[idx], is_train)
|
|
|
|
def backprop(d_output: OutT) -> InT:
|
|
dY_i = backprop_item(d_output[idx])
|
|
return d_output[:idx] + (dY_i,) + d_output[idx + 1 :]
|
|
|
|
return items[:idx] + (Y_i,) + items[idx + 1 :], backprop
|
|
|
|
|
|
def init(
|
|
model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
|
|
) -> None:
|
|
idx = model.attrs["idx"]
|
|
X_i = X[idx] if X is not None else X
|
|
Y_i = Y[idx] if Y is not None else Y
|
|
model.layers[0].initialize(X=X_i, Y=Y_i)
|