ai-content-maker/.venv/Lib/site-packages/thinc/layers/resizable.py

94 lines
2.7 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from typing import Callable, Optional, TypeVar
from ..config import registry
from ..model import Model
from ..types import Floats2d
InT = TypeVar("InT")
OutT = TypeVar("OutT")
@registry.layers("resizable.v1")
def resizable(layer, resize_layer: Callable) -> Model[InT, OutT]:
"""Container that holds one layer that can change dimensions."""
return Model(
f"resizable({layer.name})",
forward,
init=init,
layers=[layer],
attrs={"resize_layer": resize_layer},
dims={name: layer.maybe_get_dim(name) for name in layer.dim_names},
)
def forward(model: Model[InT, OutT], X: InT, is_train: bool):
layer = model.layers[0]
Y, callback = layer(X, is_train=is_train)
def backprop(dY: OutT) -> InT:
return callback(dY)
return Y, backprop
def init(
model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
layer = model.layers[0]
layer.initialize(X, Y)
def resize_model(model: Model[InT, OutT], new_nO):
old_layer = model.layers[0]
new_layer = model.attrs["resize_layer"](old_layer, new_nO)
model.layers[0] = new_layer
return model
def resize_linear_weighted(
layer: Model[Floats2d, Floats2d], new_nO, *, fill_defaults=None
) -> Model[Floats2d, Floats2d]:
"""Create a resized copy of a layer that has parameters W and b and dimensions nO and nI."""
assert not layer.layers
assert not layer.ref_names
assert not layer.shims
# return the original layer if it wasn't initialized or if nO didn't change
if layer.has_dim("nO") is None:
layer.set_dim("nO", new_nO)
return layer
elif new_nO == layer.get_dim("nO"):
return layer
elif layer.has_dim("nI") is None:
layer.set_dim("nO", new_nO, force=True)
return layer
dims = {name: layer.maybe_get_dim(name) for name in layer.dim_names}
dims["nO"] = new_nO
new_layer: Model[Floats2d, Floats2d] = Model(
layer.name,
layer._func,
dims=dims,
params={name: None for name in layer.param_names},
init=layer.init,
attrs=layer.attrs,
refs={},
ops=layer.ops,
)
new_layer.initialize()
for name in layer.param_names:
if layer.has_param(name):
filler = 0 if not fill_defaults else fill_defaults.get(name, 0)
_resize_parameter(name, layer, new_layer, filler=filler)
return new_layer
def _resize_parameter(name, layer, new_layer, filler=0):
larger = new_layer.get_param(name)
smaller = layer.get_param(name)
# copy the original weights
larger[: len(smaller)] = smaller
# set the new weights
larger[len(smaller) :] = filler
new_layer.set_param(name, larger)