from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    cast,
)

from ..backends import NumpyOps
from ..config import registry
from ..model import Model
from ..types import Array2d, Ragged, XY_XY_OutT
from ..util import get_width
from .noop import noop

NUMPY_OPS = NumpyOps()


InT = TypeVar("InT", bound=Any)
OutT = TypeVar("OutT", bound=Union[Array2d, Sequence[Array2d], Ragged])

@registry.layers("concatenate.v1")
|
||
|
def concatenate(*layers: Model) -> Model[InT, XY_XY_OutT]:
|
||
|
"""Compose two or more models `f`, `g`, etc, such that their outputs are
|
||
|
concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`.
|
||
|
Also supports chaining more than 2 layers.
|
||
|
"""
|
||
|
if not layers:
|
||
|
return cast(Model[InT, XY_XY_OutT], noop())
|
||
|
elif len(layers) == 1:
|
||
|
return layers[0]
|
||
|
elif layers[0]._func is forward:
|
||
|
layers[0].layers.extend(layers[1:])
|
||
|
return layers[0]
|
||
|
|
||
|
# only add an nI dimension if each sub-layer has one
|
||
|
dims: Dict[str, Optional[int]] = {"nO": None}
|
||
|
if all(node.has_dim("nI") in [True, None] for node in layers):
|
||
|
dims = {"nO": None, "nI": None}
|
||
|
|
||
|
return Model(
|
||
|
"|".join(layer.name for layer in layers),
|
||
|
forward,
|
||
|
init=init,
|
||
|
dims=dims,
|
||
|
layers=layers,
|
||
|
)
|
||
|
|
||
|
|
||
|
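
# Usage sketch (illustrative only; not executed as part of this module).
# It assumes the public re-exports in `thinc.api`:
#
#     from thinc.api import Linear, concatenate
#     import numpy
#
#     model = concatenate(Linear(nO=64), Linear(nO=32))
#     model.initialize(X=numpy.zeros((8, 16), dtype="f"))
#     assert model.get_dim("nO") == 96  # 64 + 32, set by init() below
#     Y = model.predict(numpy.zeros((8, 16), dtype="f"))  # shape (8, 96)
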
def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
    Ys, callbacks = zip(*[layer(X, is_train=is_train) for layer in model.layers])
    # Dispatch on the output type of the first sub-layer; the sub-layers are
    # expected to all produce the same kind of output.
    if isinstance(Ys[0], list):
        data_l, backprop = _list_forward(model, X, Ys, callbacks, is_train)
        return cast(OutT, data_l), backprop
    elif isinstance(Ys[0], Ragged):
        data_r, backprop = _ragged_forward(model, X, Ys, callbacks, is_train)
        return cast(OutT, data_r), backprop
    else:
        data_a, backprop = _array_forward(model, X, Ys, callbacks, is_train)
        return cast(OutT, data_a), backprop

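
# Shape sketch for the array path below (hypothetical widths): with two
# sub-layers producing outputs of shape (n, 64) and (n, 32), forward()
# hstacks them into one (n, 96) array; backprop() slices d_output[:, :64]
# and d_output[:, 64:96], passes each slice to the matching sub-layer's
# callback, and sums the resulting input gradients.
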
def _array_forward(
    model: Model[InT, OutT], X, Ys: List, callbacks, is_train: bool
) -> Tuple[Array2d, Callable]:
    widths = [Y.shape[1] for Y in Ys]
    output = model.ops.xp.hstack(Ys)

    def backprop(d_output: Array2d) -> InT:
        dY = model.ops.as_contig(d_output[:, : widths[0]])
        dX = callbacks[0](dY)
        start = widths[0]
        # The input gradient may not support addition directly; in that case
        # fall back to accumulating into its underlying `.data` array.
        add_gradients = hasattr(dX, "__add__") or hasattr(dX, "__iadd__")
        add_gradients_data = hasattr(dX, "data") and (
            hasattr(dX.data, "__add__") or hasattr(dX.data, "__iadd__")
        )
        for bwd, width in zip(callbacks[1:], widths[1:]):
            dY = model.ops.as_contig(d_output[:, start : start + width])
            gradient = bwd(dY)
            if add_gradients:
                dX += gradient
            elif add_gradients_data:
                dX.data += gradient.data
            start += width
        return dX

    return output, backprop

def _ragged_forward(
    model: Model[InT, OutT], X, Ys: List, callbacks, is_train: bool
) -> Tuple[Ragged, Callable]:
    widths = [Y.dataXd.shape[1] for Y in Ys]
    output = Ragged(model.ops.xp.hstack([y.data for y in Ys]), Ys[0].lengths)

    def backprop(d_output: Ragged) -> InT:
        d_array = d_output.data
        dY = Ragged(model.ops.as_contig(d_array[:, : widths[0]]), d_output.lengths)
        dX = callbacks[0](dY)
        start = widths[0]
        for bwd, width in zip(callbacks[1:], widths[1:]):
            dY = Ragged(
                model.ops.as_contig(d_array[:, start : start + width]), d_output.lengths
            )
            dX += bwd(dY)
            start += width
        return dX

    return output, backprop

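
# The list path (below) handles sub-layers that return a list of 2d arrays
# (one per input sequence): each sub-layer's list is first flattened along
# axis 0, the flat arrays are hstacked column-wise, and the result is
# unflattened back into per-item arrays using the input lengths.
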
def _list_forward(
    model: Model[InT, OutT], X, Ys: List, callbacks, is_train: bool
) -> Tuple[Sequence[Array2d], Callable]:
    def backprop(d_output: Sequence[Array2d]) -> InT:
        d_out_array = model.ops.xp.concatenate(d_output, axis=0)
        dY = model.ops.as_contig(d_out_array[:, : widths[0]])
        # We want to generalize unflatten later.
        dY = model.ops.unflatten(dY, lengths)
        dX = callbacks[0](dY)
        start = widths[0]
        for bwd, width in zip(callbacks[1:], widths[1:]):
            dY = model.ops.as_contig(d_out_array[:, start : start + width])
            dY = model.ops.unflatten(dY, lengths)
            dX += bwd(dY)
            start += width
        return dX

    lengths = NUMPY_OPS.asarray1i([len(x) for x in X])
    Ys = [model.ops.xp.concatenate(Y, axis=0) for Y in Ys]
    widths = [Y.shape[1] for Y in Ys]
    out_array = model.ops.xp.hstack(Ys)
    return model.ops.unflatten(out_array, lengths), backprop

def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    if X is not None:
        if model.has_dim("nI") is not False:
            model.set_dim("nI", get_width(X))
        for layer in model.layers:
            if layer.has_dim("nI") is not False:
                layer.set_dim("nI", get_width(X))
    for layer in model.layers:
        layer.initialize(X=X, Y=Y)
    if all([layer.has_dim("nO") for layer in model.layers]):
        # The concatenated output width is the sum of the sub-layers' widths.
        model.set_dim("nO", sum(layer.get_dim("nO") for layer in model.layers))