from typing import Callable, Optional, Tuple, cast

from ..backends import Ops
from ..config import registry
from ..model import Model
from ..types import Floats2d
from ..util import get_width

# The layer consumes and produces arrays of the same type.
InT = Floats2d


@registry.layers("LayerNorm.v1")
def LayerNorm(nI: Optional[int] = None) -> Model[InT, InT]:
    """Build the "layernorm" model.

    nI: Width of the input. It is also used as the output width ("nO").
        May be left as None, in which case `init` sets it from sample data.

    The model is created with two uninitialised parameters, "G" and "b",
    which `init` allocates.
    """
    dims = {"nI": nI, "nO": nI}
    params = {"G": None, "b": None}
    return Model("layernorm", forward, init=init, dims=dims, params=params)


def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
    """Normalize each row of X, then apply the learned scale and shift.

    Every row is centred on its own mean and multiplied by the inverse
    square root of its variance (moments are taken along axis 1). The
    result is passed through `_begin_update_scale_shift`.

    `is_train` is accepted but not used by this layer.

    RETURNS: The output array and a callback that maps the gradient of the
        output to the gradient of X.
    """
    N, mu, var = _get_moments(model.ops, X)
    Xhat = (X - mu) * var ** (-1.0 / 2.0)
    Y, backprop_rescale = _begin_update_scale_shift(model, Xhat)

    def backprop(dY: InT) -> InT:
        # Undo the scale/shift first; this also accumulates the gradients
        # of "G" and "b" on the model.
        d_scaled = backprop_rescale(dY)
        dist, sum_dy, sum_dy_dist = _get_d_moments(model.ops, d_scaled, X, mu)
        # Part of the gradient that flows through the variance term.
        through_var = dist * var ** (-1.0) * sum_dy_dist
        dX = N * d_scaled - sum_dy - through_var
        dX *= var ** (-1.0 / 2)
        dX /= N
        return dX

    return Y, backprop


def init(
    model: Model[InT, InT], X: Optional[InT] = None, Y: Optional[InT] = None
) -> None:
    """Set missing dimensions and allocate the "G" and "b" parameters.

    If sample data is given, its width (via `get_width`) is written to both
    "nI" and "nO". X takes priority; Y is only consulted when X is None.
    """
    sample = X if X is not None else Y
    if sample is not None:
        width = get_width(sample)
        model.set_dim("nI", width)
        model.set_dim("nO", width)

    nI = model.get_dim("nI")
    if not model.has_dim("nO"):
        model.set_dim("nO", nI)

    # NOTE(review): presumably alloc1f returns zeros, which would make the
    # scale start at 1 and the shift at 0 -- confirm against Ops.
    model.set_param("G", model.ops.alloc1f(nI) + 1)
    model.set_param("b", model.ops.alloc1f(nI))
    assert model.get_dim("nO") is not None


def _begin_update_scale_shift(model: Model[InT, InT], X: InT) -> Tuple[InT, Callable]:
    """Compute `X * G + b` and return it with a backward callback.

    The callback increments the gradients of "b" and "G" on the model
    (both summed over axis 0) and returns the gradient with respect to X.
    """
    gain = model.get_param("G")
    bias = model.get_param("b")
    scaled = X * gain
    scaled += bias

    def backprop_scale_shift(dY: InT) -> InT:
        model.inc_grad("b", dY.sum(axis=0))
        model.inc_grad("G", (dY * X).sum(axis=0))
        return dY * gain

    return scaled, backprop_scale_shift


def _get_moments(ops: Ops, X: Floats2d) -> Tuple[Floats2d, Floats2d, Floats2d]:
    """Return the row width of X and its per-row mean and variance.

    The width is `X.shape[1]` wrapped in a one-element float array. The
    mean and variance are taken along axis 1 with `keepdims=True`, and
    1e-08 is added to the variance.
    """
    # TODO: Do mean methods
    row_mean: Floats2d = X.mean(axis=1, keepdims=True)
    row_var: Floats2d = X.var(axis=1, keepdims=True) + 1e-08
    width = cast(Floats2d, ops.asarray_f([X.shape[1]]))
    return width, row_mean, row_var


def _get_d_moments(
    ops: Ops, dy: Floats2d, X: Floats2d, mu: Floats2d
) -> Tuple[Floats2d, Floats2d, Floats2d]:
    """Return the terms needed to backpropagate through the normalization.

    RETURNS: A tuple of `X - mu`, the sum of dy along axis 1, and the sum
        of `dy * (X - mu)` along axis 1. Both sums use `keepdims=True`.
    """
    centered = X - mu
    xp = ops.xp
    sum_dy = xp.sum(dy, axis=1, keepdims=True)
    sum_dy_centered = xp.sum(dy * centered, axis=1, keepdims=True)
    return centered, sum_dy, sum_dy_centered