88 lines
2.9 KiB
Python
88 lines
2.9 KiB
Python
|
from typing import Callable, List, Sequence, Tuple, TypeVar, Union, cast
|
||
|
|
||
|
from ..config import registry
|
||
|
from ..model import Model
|
||
|
from ..types import ArrayXd, Padded, Ragged
|
||
|
|
||
|
InT = TypeVar("InT", bound=Union[ArrayXd, Sequence[ArrayXd], Ragged, Padded])
|
||
|
|
||
|
|
||
|
@registry.layers("Dropout.v1")
|
||
|
def Dropout(rate: float = 0.0) -> Model[InT, InT]:
|
||
|
"""Help prevent overfitting by adding a random distortion to the input data
|
||
|
during training. Specifically, cells of the input are zeroed with
|
||
|
probability determined by the `rate` argument.
|
||
|
"""
|
||
|
return Model("dropout", forward, attrs={"dropout_rate": rate, "is_enabled": True})
|
||
|
|
||
|
|
||
|
def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
|
||
|
rate = model.attrs["dropout_rate"]
|
||
|
is_enabled = model.attrs["is_enabled"] and is_train
|
||
|
if rate == 0 or not is_enabled:
|
||
|
return X, lambda dY: dY
|
||
|
elif isinstance(X, Ragged):
|
||
|
data_r, backprop = _dropout_ragged(model, X, is_train)
|
||
|
return cast(InT, data_r), backprop
|
||
|
elif isinstance(X, Padded):
|
||
|
data_p, backprop = _dropout_padded(model, X, is_train)
|
||
|
return cast(InT, data_p), backprop
|
||
|
elif isinstance(X, Sequence):
|
||
|
data_l, backprop = _dropout_lists(model, X, is_train)
|
||
|
return cast(InT, data_l), backprop
|
||
|
else:
|
||
|
data_a, backprop = _dropout_array(model, cast(ArrayXd, X), is_train)
|
||
|
return cast(InT, data_a), backprop
|
||
|
|
||
|
|
||
|
def _dropout_array(
|
||
|
model: Model[InT, InT], X: ArrayXd, is_train: bool
|
||
|
) -> Tuple[ArrayXd, Callable]:
|
||
|
rate = model.attrs["dropout_rate"]
|
||
|
mask = model.ops.get_dropout_mask(X.shape, rate)
|
||
|
|
||
|
def backprop(dY: ArrayXd) -> ArrayXd:
|
||
|
return dY * mask
|
||
|
|
||
|
return cast(ArrayXd, X * mask), backprop
|
||
|
|
||
|
|
||
|
def _dropout_padded(
|
||
|
model: Model[InT, InT], Xp: Padded, is_train: bool
|
||
|
) -> Tuple[Padded, Callable]:
|
||
|
X = Xp.data
|
||
|
mask = model.ops.get_dropout_mask(X.shape, model.attrs["dropout_rate"])
|
||
|
Y = X * mask
|
||
|
|
||
|
def backprop(dYp: Padded) -> Padded:
|
||
|
return Padded(dYp.data * mask, dYp.size_at_t, dYp.lengths, dYp.indices)
|
||
|
|
||
|
return Padded(Y, Xp.size_at_t, Xp.lengths, Xp.indices), backprop
|
||
|
|
||
|
|
||
|
def _dropout_ragged(
|
||
|
model: Model[InT, InT], Xr: Ragged, is_train: bool
|
||
|
) -> Tuple[Ragged, Callable]:
|
||
|
X = Xr.data
|
||
|
lengths = Xr.lengths
|
||
|
mask = model.ops.get_dropout_mask(X.shape, model.attrs["dropout_rate"])
|
||
|
Y = X * mask
|
||
|
|
||
|
def backprop(dYr: Ragged) -> Ragged:
|
||
|
return Ragged(dYr.data * mask, dYr.lengths)
|
||
|
|
||
|
return Ragged(Y, lengths), backprop
|
||
|
|
||
|
|
||
|
def _dropout_lists(
|
||
|
model: Model[InT, InT], Xs: Sequence[ArrayXd], is_train: bool
|
||
|
) -> Tuple[Sequence[ArrayXd], Callable]:
|
||
|
rate = model.attrs["dropout_rate"]
|
||
|
masks = [model.ops.get_dropout_mask(X.shape, rate) for X in Xs]
|
||
|
Ys = [X * mask for X, mask in zip(Xs, masks)]
|
||
|
|
||
|
def backprop(dYs: List[ArrayXd]) -> List[ArrayXd]:
|
||
|
return [dY * mask for dY, mask in zip(dYs, masks)]
|
||
|
|
||
|
return Ys, backprop
|