# Source: ai-content-maker/.venv/Lib/site-packages/thinc/layers/map_list.py
# Web-view metadata: 37 lines, 1.0 KiB, Python; revision dated 2024-05-03 04:18:51 +03:00
from typing import Callable, List, Optional, Tuple, TypeVar
from ..model import Model
# Type of a single input item; the map_list model consumes List[InT].
InT = TypeVar("InT")
# Type of a single output item; the map_list model produces List[OutT].
OutT = TypeVar("OutT")
def map_list(layer: Model[InT, OutT]) -> Model[List[InT], List[OutT]]:
    """Create a model that maps a child layer across list inputs.

    layer: The child layer to apply to each item of the input list.
    RETURNS: A model named "map_list" with `layer` as its only sublayer,
        using this module's `forward` and `init` callbacks.
    """
    return Model("map_list", forward, layers=[layer], init=init)
def forward(
    model: Model[List[InT], List[OutT]], Xs: List[InT], is_train: bool
) -> Tuple[List[OutT], Callable[[List[OutT]], List[InT]]]:
    """Apply the wrapped layer to every item of `Xs`, one item at a time.

    model: The map_list model; its first sublayer is called on each item.
    Xs: The input items. An empty list yields an empty output list.
    is_train: Passed through unchanged to every call of the wrapped layer.
    RETURNS: A tuple of (list of per-item outputs, callback). The callback
        takes a list `dYs` and applies, position by position, the callback
        the wrapped layer returned for that item.
    """
    layer = model.layers[0]
    Ys = []
    callbacks = []
    for X in Xs:
        # Each call returns the item's output and a callback for that item.
        Y, get_dX = layer(X, is_train)
        Ys.append(Y)
        callbacks.append(get_dX)

    def backprop_map_list(dYs: List[OutT]) -> List[InT]:
        # zip() stops at the shorter sequence, so a `dYs` that is longer or
        # shorter than the forward input is silently truncated to the overlap.
        return [callback(dY) for callback, dY in zip(callbacks, dYs)]

    return Ys, backprop_map_list
def init(
    model: Model[List[InT], List[OutT]],
    X: Optional[List[InT]] = None,
    Y: Optional[List[OutT]] = None,
) -> None:
    """Initialize the wrapped layer using the first sample item, if present.

    model: The map_list model; its first sublayer is the one initialized.
    X: Optional list of sample inputs. Only the first item is forwarded;
        `None` or an empty list forwards `None`.
    Y: Optional list of sample outputs, handled the same way as `X`.
    """
    # The child layer works on single items, so it receives one item from each
    # sample list rather than the list itself.
    model.layers[0].initialize(X=X[0] if X else None, Y=Y[0] if Y else None)