from functools import partial
import pytest
from thinc.api import Linear, resizable
from thinc.layers.resizable import resize_linear_weighted, resize_model


@pytest.fixture
def model():
    """Return a ``resizable`` wrapper around a ``Linear`` layer.

    The wrapped ``Linear`` is created with ``nO`` and ``nI`` both left as
    ``None``. The wrapper's ``resize_layer`` callback is
    ``resize_linear_weighted`` with ``fill_defaults`` bound to 0 for both
    ``"b"`` and ``"W"``.
    """
    # NOTE(review): fill_defaults presumably sets the value used for newly
    # added parameter entries on resize; confirm in resize_linear_weighted.
    inner_layer = Linear(nO=None, nI=None)
    resize_callback = partial(resize_linear_weighted, fill_defaults={"b": 0, "W": 0})
    return resizable(inner_layer, resize_layer=resize_callback)


def test_resizable_linear_default_name(model):
|
|
assert model.name == "resizable(linear)"


def test_resize_model(model):
    """Test that resizing the model doesn't cause an exception.

    The test resizes the fixture's wrapper several times with
    ``resize_model``. It first resizes back-to-back, then resizes again after
    force-setting the wrapper's ``nO`` dimension, and finally after
    force-setting both ``nI`` and ``nO``. There are no assertions; the test
    passes as long as nothing raises.
    """
    # Two consecutive resizes to different output sizes.
    for target_nO in (10, 11):
        resize_model(model, new_nO=target_nO)

    # NOTE(review): set_dim here targets the wrapper itself; whether the
    # wrapped Linear layer observes these values depends on thinc's
    # Model.set_dim / resize_model behavior — confirm.
    model.set_dim("nO", 0, force=True)
    resize_model(model, new_nO=10)

    # Same again, with nI also force-set before zeroing nO.
    model.set_dim("nI", 10, force=True)
    model.set_dim("nO", 0, force=True)
    resize_model(model, new_nO=10)