from typing import Callable, Optional, Tuple, cast

from ..backends import Ops
from ..config import registry
from ..model import Model
from ..types import Padded

InT = Padded
OutT = Padded


@registry.layers("bidirectional.v1")
def bidirectional(
    l2r: Model[InT, OutT], r2l: Optional[Model[InT, OutT]] = None
) -> Model[InT, OutT]:
    """Stitch two RNN models into a bidirectional layer. Expects squared sequences."""
    if r2l is None:
        # No reverse model given: use a copy of the forward model.
        r2l = l2r.copy()
    return Model(f"bi{l2r.name}", forward, layers=[l2r, r2l], init=init)
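
# Usage sketch (illustrative, not part of this module): any Padded-to-Padded
# recurrent layer can be wrapped; `LSTM` and the hyperparameters below are
# assumptions for the example.
#
#     from thinc.api import LSTM
#
#     model = bidirectional(LSTM(nO=64, nI=32))
#     model.initialize(X=padded_batch)  # padded_batch: some Padded input
#     Y, backprop = model(padded_batch, is_train=True)  # Y width is 2 * 64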


def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
    l2r, r2l = model.layers
    # Run the l2r model on the sequence as-is and the r2l model on the
    # time-reversed sequence, then join the outputs feature-wise.
    X_rev = _reverse(model.ops, X)
    l2r_Z, bp_l2r_Z = l2r(X, is_train)
    r2l_Z, bp_r2l_Z = r2l(X_rev, is_train)
    Z = _concatenate(model.ops, l2r_Z, r2l_Z)

    def backprop(dZ: OutT) -> InT:
        # Split the gradient into each direction's half, backprop through
        # both models, and sum the resulting input gradients.
        d_l2r_Z, d_r2l_Z = _split(model.ops, dZ)
        dX_l2r = bp_l2r_Z(d_l2r_Z)
        dX_r2l = bp_r2l_Z(d_r2l_Z)
        return _sum(dX_l2r, dX_r2l)

    return Z, backprop


def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    # Give each direction the matching half of the target output, if provided.
    (Y1, Y2) = _split(model.ops, Y) if Y is not None else (None, None)
    model.layers[0].initialize(X=X, Y=Y1)
    model.layers[1].initialize(X=X, Y=Y2)


def _reverse(ops: Ops, Xp: Padded) -> Padded:
    # Flip the timestep (first) axis so the r2l model sees the sequence backwards.
    return Padded(Xp.data[::-1], Xp.size_at_t, Xp.lengths, Xp.indices)


def _concatenate(ops: Ops, l2r: Padded, r2l: Padded) -> Padded:
    # Concatenate the two outputs along the feature (last) axis.
    return Padded(
        ops.xp.concatenate((l2r.data, r2l.data), axis=-1),
        l2r.size_at_t,
        l2r.lengths,
        l2r.indices,
    )
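
# Shape note (an assumption, matching thinc's documented Padded layout): data
# is (timesteps, batch, width), so concatenating two (T, B, W) outputs above
# yields (T, B, 2 * W) while the padding metadata is shared unchanged.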


def _split(ops: Ops, Xp: Padded) -> Tuple[Padded, Padded]:
    half = Xp.data.shape[-1] // 2
    # I don't know how to write these ellipsis in the overloads :(
    X_l2r = Xp.data[cast(Tuple[slice, slice], (..., slice(None, half)))]
    X_r2l = Xp.data[cast(Tuple[slice, slice], (..., slice(half, None)))]
    return (
        Padded(X_l2r, Xp.size_at_t, Xp.lengths, Xp.indices),
        Padded(X_r2l, Xp.size_at_t, Xp.lengths, Xp.indices),
    )
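
# Round-trip sketch (illustrative, plain numpy rather than Ops): `_split`
# undoes `_concatenate` on the data array. Shapes are assumptions for the
# example.
#
#     import numpy
#
#     left = numpy.ones((5, 2, 4), dtype="f")   # one direction's (T, B, W)
#     right = numpy.zeros((5, 2, 4), dtype="f")
#     both = numpy.concatenate((left, right), axis=-1)  # (5, 2, 8)
#     assert (both[..., :4] == left).all()
#     assert (both[..., 4:] == right).all()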


def _sum(Xp: Padded, Yp: Padded) -> Padded:
    # Element-wise sum of the two gradients; metadata is taken from the first.
    return Padded(Xp.data + Yp.data, Xp.size_at_t, Xp.lengths, Xp.indices)