67 lines
1.9 KiB
Python
67 lines
1.9 KiB
Python
|
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, cast
|
||
|
|
||
|
from ..config import registry
|
||
|
from ..model import Model
|
||
|
|
||
|
# Type variables for the items the wrapped layer consumes (InItemT) and
# produces (OutItemT); ItemT is the generic one used by the helpers that work
# on either side.
InItemT = TypeVar("InItemT")
OutItemT = TypeVar("OutItemT")
ItemT = TypeVar("ItemT")

# A batch grouped into inner lists, and the same items as one flat list.
NestedT = List[List[ItemT]]
FlatT = List[ItemT]
|
||
|
|
||
|
|
||
|
@registry.layers("with_flatten.v2")
def with_flatten_v2(
    layer: Model[FlatT[InItemT], FlatT[OutItemT]]
) -> Model[NestedT[InItemT], NestedT[OutItemT]]:
    """Wrap ``layer`` so it can be applied to nested (list-of-lists) input.

    The returned model concatenates the inner lists into one flat list, runs
    ``layer`` on it, and splits the output back into inner lists using the
    lengths of the input's inner lists (see ``forward``). Initialization is
    handled by ``init``.

    layer: The model to wrap; it maps a flat list to a flat list.
    RETURNS: A model named ``with_flatten(<layer.name>)`` whose only
        sub-layer is ``layer``.
    """
    return Model(f"with_flatten({layer.name})", forward, layers=[layer], init=init)
|
||
|
|
||
|
|
||
|
def forward(
    model: Model[NestedT[InItemT], NestedT[OutItemT]],
    Xnest: NestedT[InItemT],
    is_train: bool,
) -> Tuple[NestedT[OutItemT], Callable]:
    """Run the wrapped layer on the flattened input and re-nest its output.

    model: The wrapping model; its first sub-layer is the wrapped layer.
    Xnest: The nested input.
    is_train: Passed through unchanged to the wrapped layer.
    RETURNS: ``(Ynest, backprop)``. ``Ynest`` is the wrapped layer's output
        split with the inner lengths of ``Xnest``; ``backprop`` maps a nested
        output gradient to a nested input gradient.
    """
    layer: Model[FlatT[InItemT], FlatT[OutItemT]] = model.layers[0]
    # Keep the inner lengths: they are reused to split both the output and,
    # in ``backprop``, the input gradient.
    Xflat, lens = _flatten(Xnest)
    Yflat, backprop_layer = layer(Xflat, is_train)
    # NOTE(review): the output is split with the *input* lengths, so this
    # presumes the wrapped layer returns one output item per input item.
    Ynest = _unflatten(Yflat, lens)

    # The callback receives the gradient of the output and returns the
    # gradient of the input, hence OutItemT -> InItemT.
    def backprop(dYnest: NestedT[OutItemT]) -> NestedT[InItemT]:
        dYflat, _ = _flatten(dYnest)  # type: ignore[arg-type, var-annotated]
        # type ignore necessary for older versions of Mypy/Pydantic
        dXflat = backprop_layer(dYflat)
        return _unflatten(dXflat, lens)

    return Ynest, backprop
|
||
|
|
||
|
|
||
|
def _flatten(nested: NestedT[ItemT]) -> Tuple[FlatT[ItemT], List[int]]:
    """Concatenate the inner lists of ``nested`` into a single new list.

    nested: An iterable of sized iterables (normally a list of lists).
    RETURNS: ``(flat, lens)`` where ``flat`` holds every item in order and
        ``lens[i]`` is the length of the i-th inner list, i.e. exactly what
        ``_unflatten`` needs to rebuild the nesting.
    """
    # Annotate the accumulator precisely instead of building an untyped list
    # and casting it afterwards; local annotations are not evaluated at
    # runtime, so this costs nothing per call.
    flat: FlatT[ItemT] = []
    lens: List[int] = []
    for item in nested:
        flat.extend(item)
        lens.append(len(item))
    return flat, lens
|
||
|
|
||
|
|
||
|
def _unflatten(flat: FlatT[ItemT], lens: List[int]) -> NestedT[ItemT]:
    """Split ``flat`` into consecutive chunks of the given lengths.

    flat: The flat sequence to split; it is only sliced, never modified.
    lens: Chunk lengths, expected to be non-negative (as produced by
        ``_flatten``).
    RETURNS: A list with one slice of ``flat`` per entry in ``lens``, in
        order. Items beyond ``sum(lens)`` are ignored; if ``flat`` is too
        short, the trailing chunks are shorter or empty.
    """
    nested: NestedT[ItemT] = []
    start = 0
    for length in lens:
        # Slice by running offset rather than re-slicing the remainder, which
        # would copy the whole tail on every step.
        end = start + length
        nested.append(flat[start:end])
        start = end
    return nested
|
||
|
|
||
|
|
||
|
def init(
    model: Model[NestedT[InItemT], NestedT[OutItemT]],
    X: Optional[NestedT[InItemT]] = None,
    Y: Optional[NestedT[OutItemT]] = None,
) -> None:
    """Initialize the wrapped layer with un-nested versions of the samples.

    model: The wrapping model; its first sub-layer is the one initialized.
    X: Optional nested sample input; flattened with ``_flatten`` before
        being passed on.
    Y: Optional nested sample output; combined with the wrapped layer's
        ``ops.xp.hstack`` before being passed on.

    A sample that is ``None`` is forwarded as ``None``.
    """
    layer = model.layers[0]
    Xflat = _flatten(X)[0] if X is not None else None
    # NOTE(review): X is flattened into a plain list while Y goes through the
    # array backend's ``hstack``, so the two samples are not treated the same
    # way -- confirm this asymmetry is intended before changing it.
    Yflat = layer.ops.xp.hstack(Y) if Y is not None else None
    layer.initialize(Xflat, Yflat)
|