ai-content-maker/.venv/Lib/site-packages/thinc/layers/with_flatten_v2.py

67 lines
1.9 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, cast
from ..config import registry
from ..model import Model
# Element type of the sequences fed into the wrapped layer.
InItemT = TypeVar("InItemT")
# Element type of the sequences produced by the wrapped layer.
OutItemT = TypeVar("OutItemT")
# Generic element type used by the flatten/unflatten helpers.
ItemT = TypeVar("ItemT")
# A single flat list of items ...
FlatT = List[ItemT]
# ... and a list of such lists (the nested form this wrapper accepts).
NestedT = List[FlatT[ItemT]]
@registry.layers("with_flatten.v2")
def with_flatten_v2(
    layer: Model[FlatT[InItemT], FlatT[OutItemT]]
) -> Model[NestedT[InItemT], NestedT[OutItemT]]:
    """Wrap ``layer`` so it can be applied to a list of sequences.

    The returned model is named ``with_flatten(<layer name>)``, holds ``layer``
    as its only sublayer, and uses this module's ``forward`` and ``init``
    callbacks. Those callbacks concatenate the nested input into one flat list,
    run ``layer`` on it, and split the result back into the original lengths.
    """
    wrapper_name = f"with_flatten({layer.name})"
    return Model(wrapper_name, forward, init=init, layers=[layer])
def forward(
    model: Model[NestedT[InItemT], NestedT[OutItemT]],
    Xnest: NestedT[InItemT],
    is_train: bool,
) -> Tuple[NestedT[OutItemT], Callable]:
    """Run the wrapped layer over the flattened input and re-nest its output.

    model: the ``with_flatten`` wrapper; its first sublayer is the layer run
        on the flat data.
    Xnest: a list of sequences of input items.
    is_train: forwarded unchanged to the wrapped layer.

    Returns the output regrouped into sequences with the same lengths as
    ``Xnest``, plus a callback mapping the nested output gradient to the
    nested input gradient.
    """
    layer: Model[FlatT[InItemT], FlatT[OutItemT]] = model.layers[0]
    Xflat, lens = _flatten(Xnest)
    Yflat, backprop_layer = layer(Xflat, is_train)
    # The output is regrouped by the *input* lengths, so this relies on the
    # wrapped layer producing one output item per input item.
    Ynest = _unflatten(Yflat, lens)

    # Gradients flow output -> input, so dYnest carries OutItemT items and the
    # returned input gradient carries InItemT items.
    def backprop(dYnest: NestedT[OutItemT]) -> NestedT[InItemT]:
        # type ignore necessary for older versions of Mypy/Pydantic
        dYflat, _ = _flatten(dYnest)  # type: ignore[arg-type, var-annotated]
        dXflat = backprop_layer(dYflat)
        return _unflatten(dXflat, lens)

    return Ynest, backprop
def _flatten(nested: NestedT[ItemT]) -> Tuple[FlatT[ItemT], List[int]]:
    """Concatenate the sequences in ``nested`` into one flat list.

    Returns the flat list together with the length of each input sequence,
    in input order, so the original grouping can be restored later.
    """
    lengths: List[int] = []
    flattened: List = []
    for seq in nested:
        # list += iterable extends in place, exactly like list.extend.
        flattened += seq
        lengths.append(len(seq))
    return cast(FlatT[ItemT], flattened), lengths
def _unflatten(flat: FlatT[ItemT], lens: List[int]) -> NestedT[ItemT]:
    """Split ``flat`` into consecutive slices whose lengths are given by ``lens``.

    This is the inverse of ``_flatten``: ``_unflatten(*_flatten(x)) == x`` for
    a list of lists ``x``. Items past ``sum(lens)`` are ignored. If ``flat``
    is too short, the trailing slices come out shorter or empty.
    """
    nested: NestedT[ItemT] = []
    start = 0
    for length in lens:
        end = start + length
        # Slice with a running offset instead of re-slicing the remainder on
        # every iteration. For lists, re-slicing copied the whole tail each
        # time, so total work grew with len(flat) * len(lens).
        nested.append(flat[start:end])
        start = end
    return nested
def init(
    model: Model[NestedT[InItemT], NestedT[OutItemT]],
    X: Optional[NestedT[InItemT]] = None,
    Y: Optional[NestedT[OutItemT]] = None,
) -> None:
    """Initialize the wrapped layer from (optional) nested sample data.

    Both ``X`` and ``Y`` are flattened with ``_flatten`` before being passed
    on. The wrapped layer then sees sample data in the same flat-list form
    that ``forward`` feeds it. ``None`` samples are passed through as ``None``.
    """
    layer = model.layers[0]
    Xflat = _flatten(X)[0] if X is not None else None
    # Flatten Y the same way as X. The previous ``ops.xp.hstack(Y)`` built an
    # array instead of a FlatT list and failed on ragged nestings of arrays,
    # where inner lists hold different numbers of items.
    Yflat = _flatten(Y)[0] if Y is not None else None
    layer.initialize(Xflat, Yflat)