94 lines
2.7 KiB
Python
94 lines
2.7 KiB
Python
from typing import Callable, Optional, TypeVar
|
|
|
|
from ..config import registry
|
|
from ..model import Model
|
|
from ..types import Floats2d
|
|
|
|
InT = TypeVar("InT")
|
|
OutT = TypeVar("OutT")
|
|
|
|
|
|
@registry.layers("resizable.v1")
|
|
def resizable(layer, resize_layer: Callable) -> Model[InT, OutT]:
|
|
"""Container that holds one layer that can change dimensions."""
|
|
return Model(
|
|
f"resizable({layer.name})",
|
|
forward,
|
|
init=init,
|
|
layers=[layer],
|
|
attrs={"resize_layer": resize_layer},
|
|
dims={name: layer.maybe_get_dim(name) for name in layer.dim_names},
|
|
)
|
|
|
|
|
|
def forward(model: Model[InT, OutT], X: InT, is_train: bool):
|
|
layer = model.layers[0]
|
|
Y, callback = layer(X, is_train=is_train)
|
|
|
|
def backprop(dY: OutT) -> InT:
|
|
return callback(dY)
|
|
|
|
return Y, backprop
|
|
|
|
|
|
def init(
|
|
model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
|
|
) -> None:
|
|
layer = model.layers[0]
|
|
layer.initialize(X, Y)
|
|
|
|
|
|
def resize_model(model: Model[InT, OutT], new_nO):
|
|
old_layer = model.layers[0]
|
|
new_layer = model.attrs["resize_layer"](old_layer, new_nO)
|
|
model.layers[0] = new_layer
|
|
return model
|
|
|
|
|
|
def resize_linear_weighted(
|
|
layer: Model[Floats2d, Floats2d], new_nO, *, fill_defaults=None
|
|
) -> Model[Floats2d, Floats2d]:
|
|
"""Create a resized copy of a layer that has parameters W and b and dimensions nO and nI."""
|
|
assert not layer.layers
|
|
assert not layer.ref_names
|
|
assert not layer.shims
|
|
|
|
# return the original layer if it wasn't initialized or if nO didn't change
|
|
if layer.has_dim("nO") is None:
|
|
layer.set_dim("nO", new_nO)
|
|
return layer
|
|
elif new_nO == layer.get_dim("nO"):
|
|
return layer
|
|
elif layer.has_dim("nI") is None:
|
|
layer.set_dim("nO", new_nO, force=True)
|
|
return layer
|
|
|
|
dims = {name: layer.maybe_get_dim(name) for name in layer.dim_names}
|
|
dims["nO"] = new_nO
|
|
new_layer: Model[Floats2d, Floats2d] = Model(
|
|
layer.name,
|
|
layer._func,
|
|
dims=dims,
|
|
params={name: None for name in layer.param_names},
|
|
init=layer.init,
|
|
attrs=layer.attrs,
|
|
refs={},
|
|
ops=layer.ops,
|
|
)
|
|
new_layer.initialize()
|
|
for name in layer.param_names:
|
|
if layer.has_param(name):
|
|
filler = 0 if not fill_defaults else fill_defaults.get(name, 0)
|
|
_resize_parameter(name, layer, new_layer, filler=filler)
|
|
return new_layer
|
|
|
|
|
|
def _resize_parameter(name, layer, new_layer, filler=0):
|
|
larger = new_layer.get_param(name)
|
|
smaller = layer.get_param(name)
|
|
# copy the original weights
|
|
larger[: len(smaller)] = smaller
|
|
# set the new weights
|
|
larger[len(smaller) :] = filler
|
|
new_layer.set_param(name, larger)
|