69 lines
2.2 KiB
Python
69 lines
2.2 KiB
Python
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar
|
|
|
|
from ..config import registry
|
|
from ..model import Model
|
|
from ..types import ArrayXd, XY_XY_OutT
|
|
from ..util import get_width
|
|
|
|
InT = TypeVar("InT", bound=Any)
|
|
OutT = TypeVar("OutT", bound=ArrayXd)
|
|
|
|
|
|
@registry.layers("add.v1")
|
|
def add(
|
|
layer1: Model[InT, OutT], layer2: Model[InT, OutT], *layers: Model
|
|
) -> Model[InT, XY_XY_OutT]:
|
|
"""Compose two or more models `f`, `g`, etc, such that their outputs are
|
|
added, i.e. `add(f, g)(x)` computes `f(x) + g(x)`.
|
|
"""
|
|
layers = (layer1, layer2) + layers
|
|
if layers[0].name == "add":
|
|
layers[0].layers.extend(layers[1:])
|
|
return layers[0]
|
|
|
|
# only add an nI dimension if each sub-layer has one
|
|
dims: Dict[str, Optional[int]] = {"nO": None}
|
|
if all(node.has_dim("nI") in [True, None] for node in layers):
|
|
dims = {"nO": None, "nI": None}
|
|
|
|
return Model("add", forward, init=init, dims=dims, layers=layers)
|
|
|
|
|
|
def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
|
|
if not model.layers:
|
|
return X, lambda dY: dY
|
|
Y, first_callback = model.layers[0](X, is_train=is_train)
|
|
callbacks = []
|
|
for layer in model.layers[1:]:
|
|
layer_Y, layer_callback = layer(X, is_train=is_train)
|
|
Y += layer_Y
|
|
callbacks.append(layer_callback)
|
|
|
|
def backprop(dY: InT) -> OutT:
|
|
dX = first_callback(dY)
|
|
for callback in callbacks:
|
|
dX += callback(dY)
|
|
return dX
|
|
|
|
return Y, backprop
|
|
|
|
|
|
def init(
|
|
model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
|
|
) -> None:
|
|
if X is not None:
|
|
if model.has_dim("nI") is not False:
|
|
model.set_dim("nI", get_width(X))
|
|
for layer in model.layers:
|
|
if layer.has_dim("nI") is not False:
|
|
layer.set_dim("nI", get_width(X))
|
|
if Y is not None:
|
|
if model.has_dim("nO") is not False:
|
|
model.set_dim("nO", get_width(Y))
|
|
for layer in model.layers:
|
|
if layer.has_dim("nO") is not False:
|
|
layer.set_dim("nO", get_width(Y))
|
|
for layer in model.layers:
|
|
layer.initialize(X=X, Y=Y)
|
|
model.set_dim("nO", model.layers[0].get_dim("nO"))
|