183 lines
6.3 KiB
Python
183 lines
6.3 KiB
Python
from functools import partial
|
|
from typing import Callable, Optional, Tuple, cast
|
|
|
|
from ..backends import Ops
|
|
from ..config import registry
|
|
from ..initializers import glorot_uniform_init, zero_init
|
|
from ..model import Model
|
|
from ..types import Floats1d, Floats2d, Floats4d, Padded, Ragged
|
|
from ..util import get_width
|
|
from .noop import noop
|
|
|
|
|
|
@registry.layers("LSTM.v1")
def LSTM(
    nO: Optional[int] = None,
    nI: Optional[int] = None,
    *,
    bi: bool = False,
    depth: int = 1,
    dropout: float = 0.0,
    init_W: Optional[Callable] = None,
    init_b: Optional[Callable] = None,
) -> Model[Padded, Padded]:
    """Create an LSTM layer that maps a Padded batch to a Padded batch.

    nO: Output width. May be left as None and set later by the model's
        init callback from sample output data.
    nI: Input width. May be left as None and set later by the model's
        init callback from sample input data.
    bi: Whether the LSTM is bidirectional. Stored as the "dirs" dimension
        (1 for unidirectional, 2 for bidirectional).
    depth: Number of stacked layers. Must be at least 1.
    dropout: Dropout rate, stored in the "dropout_rate" attribute.
    init_W: Weight initializer. Defaults to glorot_uniform_init.
    init_b: Bias initializer. Defaults to zero_init.

    RAISES: ValueError if depth is smaller than 1.
    """
    # Reject negative depths too: they would otherwise only fail later, during
    # initialization, when there are no parameter arrays to concatenate.
    if depth < 1:
        msg = "LSTM depth must be at least 1. Maybe we should make this a noop?"
        raise ValueError(msg)
    if init_W is None:
        init_W = glorot_uniform_init
    if init_b is None:
        init_b = zero_init

    model: Model[Padded, Padded] = Model(
        "lstm",
        forward,
        dims={"nO": nO, "nI": nI, "depth": depth, "dirs": 1 + int(bi)},
        attrs={"registry_name": "LSTM.v1", "dropout_rate": dropout},
        # Both parameters are allocated by the init callback.
        params={"LSTM": None, "HC0": None},
        init=partial(init, init_W, init_b),
    )
    return model
|
|
|
|
|
|
@registry.layers("PyTorchLSTM.v1")
def PyTorchLSTM(
    nO: int, nI: int, *, bi: bool = False, depth: int = 1, dropout: float = 0.0
) -> Model[Padded, Padded]:
    """Create an LSTM backed by torch.nn.LSTM, wrapped to work on Padded data.

    nO: Output width, recorded as the wrapper's "nO" dimension.
    nI: Input width, recorded as the wrapper's "nI" dimension.
    bi: Whether to build a bidirectional LSTM. If so, each direction gets
        nO // 2 hidden units.
    depth: Number of stacked layers. A depth of 0 returns a noop layer.
    dropout: Dropout rate passed through to torch.nn.LSTM.
    """
    # Imported lazily so that PyTorch is only needed when this layer is used.
    import torch.nn

    from .pytorchwrapper import PyTorchRNNWrapper
    from .with_padded import with_padded

    if depth == 0:
        return noop()  # type: ignore[misc]
    # NOTE(review): with bi=True and an odd nO, the integer division drops one
    # unit while the "nO" dimension below still records the full nO -- confirm
    # whether callers are expected to pass an even width.
    hidden_width = nO // 2 if bi else nO
    torch_lstm = torch.nn.LSTM(
        nI, hidden_width, depth, bidirectional=bi, dropout=dropout
    )
    wrapped = PyTorchRNNWrapper(torch_lstm)
    wrapped.set_dim("nO", nO)
    wrapped.set_dim("nI", nI)
    return with_padded(wrapped)
|
|
|
|
|
|
def init(
    init_W: Callable,
    init_b: Callable,
    model: Model,
    X: Optional[Padded] = None,
    Y: Optional[Padded] = None,
) -> None:
    """Allocate and set the "LSTM" and "HC0" parameters of the model.

    init_W: Weight initializer, called as init_W(ops, shape).
    init_b: Bias initializer, called as init_b(ops, shape).
    model: The model to initialize. Its "depth" and "dirs" dimensions must
        be set; "nI" and "nO" must be set unless X and Y are provided.
    X: Optional sample input, used to set the "nI" dimension.
    Y: Optional sample output, used to set the "nO" dimension.
    """
    if X is not None:
        model.set_dim("nI", get_width(X))
    if Y is not None:
        model.set_dim("nO", get_width(Y))
    n_out = model.get_dim("nO")
    dirs = model.get_dim("dirs")
    # The output width is divided between the directions.
    nH = int(n_out / dirs)
    nI = model.get_dim("nI")
    depth = model.get_dim("depth")
    ops = model.ops
    # Convenience: bind the ops so the initializers only need a shape.
    make_W = partial(init_W, ops)
    make_b = partial(init_b, ops)
    n_gates = 4
    # It's easiest to use the initializers if we allocate the pieces separately
    # and join them afterwards. The order matters here: we need to keep the
    # same format that CuDNN expects.
    pieces = []
    input_width = nI
    for _layer in range(depth):
        for _direction in range(dirs):
            # Input-to-gates weights and biases.
            pieces.extend(make_W((nH, input_width)) for _ in range(n_gates))
            pieces.extend(make_b((nH,)) for _ in range(n_gates))
            # Hidden-to-gates weights and biases.
            pieces.extend(make_W((nH, nH)) for _ in range(n_gates))
            pieces.extend(make_b((nH,)) for _ in range(n_gates))
        # Layers after the first read the joined outputs of every direction.
        input_width = nH * dirs
    flat_params = ops.xp.concatenate([piece.ravel() for piece in pieces])
    model.set_param("LSTM", flat_params)
    model.set_param("HC0", zero_init(ops, (2, depth, dirs, nH)))
    # Sanity check: the flat parameter vector must have the size implied by
    # the dimensions.
    size = model.get_param("LSTM").size
    bias_count = dirs * (8 * nH)
    first_layer = 4 * dirs * nH * (nH + nI) + bias_count
    later_layer = 4 * dirs * (nH + nH * dirs) * nH + bias_count
    expected = first_layer + max(depth - 1, 0) * later_layer
    assert size == expected, (size, expected)
|
|
|
|
|
|
def forward(
    model: Model[Padded, Padded], Xp: Padded, is_train: bool
) -> Tuple[Padded, Callable]:
    """Run the LSTM over a padded batch and return the output with a callback.

    model: The LSTM model, providing the "LSTM" and "HC0" parameters and the
        "dropout_rate" attribute.
    Xp: The padded input batch.
    is_train: Whether to run the training pass. Only the training pass keeps
        the state needed for the backward pass.
    RETURNS: The padded output and the backprop callback. The callback
        asserts that forward state is available, so it fails when the
        forward pass ran with is_train=False.
    """
    ops = model.ops
    dropout = model.attrs["dropout_rate"]
    Xr = _padded_to_packed(ops, Xp)
    weights = cast(Floats1d, model.get_param("LSTM"))
    initial = cast(Floats4d, model.get_param("HC0"))
    H0, C0 = initial[0], initial[1]
    packed_X = cast(Floats2d, Xr.data)
    if not is_train:
        Y = ops.lstm_forward_inference(weights, H0, C0, packed_X, Xr.lengths)
        # Empty state marks that no backward pass is possible.
        fwd_state = tuple()
    else:
        # Apply dropout over *weights*, not *activations*. RNN activations are
        # heavily correlated, so dropout over the activations is less effective.
        # This trick was explained in Yarin Gal's thesis, and popularised by
        # Smerity in the AWD-LSTM. It also means we can do the dropout outside
        # of the backend, improving compatibility.
        mask = cast(Floats1d, ops.get_dropout_mask(weights.shape, dropout))
        weights = weights * mask
        Y, fwd_state = ops.lstm_forward_training(
            weights, H0, C0, packed_X, Xr.lengths
        )
    assert Y.shape == (Xr.data.shape[0], Y.shape[1]), (Xr.data.shape, Y.shape)
    Yp = _packed_to_padded(ops, Ragged(Y, Xr.lengths), Xp)

    def backprop(dYp: Padded) -> Padded:
        assert fwd_state
        dYr = _padded_to_packed(ops, dYp)
        # Use the masked weights from the forward pass, not the raw parameter.
        dX, d_weights = ops.backprop_lstm(
            cast(Floats2d, dYr.data), dYr.lengths, weights, fwd_state
        )
        # Apply the same mask to the gradient that was applied to the weights.
        d_weights *= mask
        model.inc_grad("LSTM", d_weights)
        return _packed_to_padded(ops, Ragged(dX, dYr.lengths), dYp)

    return Yp, backprop
|
|
|
|
|
|
def _padded_to_packed(ops: Ops, Xp: Padded) -> Ragged:
    """Copy the non-padding rows of a padded batch into one 2d array.

    For each index t along the first axis of Xp.data, the first
    Xp.size_at_t[t] rows are copied, in order. The result is returned as a
    Ragged whose lengths are Xp.size_at_t.
    """
    n_rows = Xp.lengths.sum()
    n_rows_by_step = Xp.size_at_t.sum()
    assert n_rows == n_rows_by_step, (
        n_rows,
        n_rows_by_step,
    )
    packed = ops.alloc2f(n_rows, Xp.data.shape[2])
    n_steps = Xp.size_at_t.shape[0]
    offset = 0
    for t in range(n_steps):
        n_active = Xp.size_at_t[t]
        stop = offset + n_active
        packed[offset:stop] = Xp.data[t, :n_active]  # type: ignore[assignment]
        offset = stop
    return Ragged(packed, Xp.size_at_t)
|
|
|
|
|
|
def _packed_to_padded(ops: Ops, Xr: Ragged, Xp: Padded) -> Padded:
    """Scatter packed rows back into a freshly allocated padded array.

    The output takes its first two dimensions from Xp.data and its width
    from Xr.data. For each index t, the next Xp.size_at_t[t] rows of Xr.data
    are written to the start of step t. The returned Padded reuses the
    size_at_t, lengths and indices of Xp.
    """
    rows = cast(Floats2d, Xr.data)
    padded = ops.alloc3f(Xp.data.shape[0], Xp.data.shape[1], Xr.data.shape[1])
    n_steps = Xp.size_at_t.shape[0]
    offset = 0
    for t in range(n_steps):
        n_active = Xp.size_at_t[t]
        stop = offset + n_active
        padded[t, :n_active] = rows[offset:stop]
        offset = stop
    return Padded(
        padded, size_at_t=Xp.size_at_t, lengths=Xp.lengths, indices=Xp.indices
    )
|