# ai-content-maker/.venv/Lib/site-packages/thinc/layers/residual.py
# Listing metadata: 65 lines, 2.0 KiB, Python; timestamp 2024-05-03 04:18:51 +03:00
from typing import Callable, List, Optional, Tuple, TypeVar
from ..config import registry
from ..model import Model
from ..types import Floats1d, Floats2d, Floats3d, Floats4d, FloatsXd, Padded, Ragged
# Constrained type variable naming the data types that the functions in this
# file accept and return (they are all annotated Model[InT, InT]).
# fmt: off
InT = TypeVar(
    "InT",
    List[Floats1d], List[Floats2d], List[Floats3d], List[Floats4d],
    Ragged, Padded,
    FloatsXd, Floats1d, Floats2d, Floats3d, Floats4d,
)
# fmt: on
@registry.layers("residual.v1")
def residual(layer: Model[InT, InT]) -> Model[InT, InT]:
    """Wrap `layer` in a new model driven by this file's `forward` and `init`.

    The new model is named ``residual(<layer.name>)``, holds `layer` as its
    only sub-layer, and is registered under ``"residual.v1"``.

    layer: The model to wrap.
    RETURNS: The wrapping model. Its "nO" and "nI" dimensions are copied from
        `layer` when `layer.has_dim(...)` is truthy and are None otherwise.
    """
    name = f"residual({layer.name})"
    # Mirror whichever of the wrapped layer's dimensions are already known;
    # anything not (yet) available is passed through as None.
    dims = {
        dim: layer.get_dim(dim) if layer.has_dim(dim) else None
        for dim in ("nO", "nI")
    }
    return Model(name, forward, init=init, layers=[layer], dims=dims)
def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
    """Run the wrapped layer on `X` and add `X` to the result.

    model: The residual model; only `model.layers[0]` is used.
    X: The input. How the sum is formed depends on its type:
        * list   -> a new list with `X[i] + Y[i]` for each position of `X`
        * Ragged -> a new Ragged of `X.data + Y.data` with `X.lengths`
        * Padded -> `X.data` is added onto `Y.data` with `+=`, and the
                    wrapped layer's own output object `Y` is returned
        * other  -> `X + Y`
    is_train: Passed through unchanged to the wrapped layer.
    RETURNS: A tuple of the summed output and the backprop callback.
    """
    wrapped = model.layers[0]
    Y, backprop_layer = wrapped(X, is_train)

    def backprop(d_output: InT) -> InT:
        """Return `d_output` plus the wrapped layer's gradient for `d_output`."""
        dX = backprop_layer(d_output)
        if isinstance(d_output, list):
            return [grad + dX[i] for i, grad in enumerate(d_output)]
        if isinstance(d_output, Ragged):
            # The lengths are taken from the wrapped layer's gradient.
            return Ragged(d_output.data + dX.data, dX.lengths)
        # NOTE: this branch is selected by the type of the forward input `X`
        # (captured by the closure), not by the type of `d_output`.
        if isinstance(X, Padded):
            # Accumulate onto the wrapped layer's gradient object and reuse it.
            dX.data += d_output.data
            return dX
        return d_output + dX

    if isinstance(X, list):
        return [x_item + Y[i] for i, x_item in enumerate(X)], backprop
    if isinstance(X, Ragged):
        return Ragged(X.data + Y.data, X.lengths), backprop
    if isinstance(X, Padded):
        # Accumulate onto the wrapped layer's output object and reuse it.
        Y.data += X.data
        return Y, backprop
    return X + Y, backprop
def init(
    model: Model[InT, InT], X: Optional[InT] = None, Y: Optional[InT] = None
) -> None:
    """Initialize the wrapped layer, then copy its dimensions onto `model`.

    model: The residual model; `model.layers[0]` is the layer initialized.
    X: Optional sample input, always forwarded to the wrapped layer.
    Y: Optional sample output, forwarded only when the wrapped layer's
        `has_dim("nO")` returns None.
    """
    wrapped = model.layers[0]
    # The comparison is against None specifically, not plain falsiness:
    # per thinc's Model.has_dim contract, None means the dimension is
    # declared but has no value yet, and only then is Y passed along.
    if wrapped.has_dim("nO") is not None:
        wrapped.initialize(X=X)
    else:
        wrapped.initialize(X=X, Y=Y)
    # After initialization, propagate whichever dimensions the wrapped layer
    # now reports so the wrapper advertises the same sizes.
    for dim_name in ("nO", "nI"):
        if wrapped.has_dim(dim_name):
            model.set_dim(dim_name, wrapped.get_dim(dim_name))