39 lines
988 B
Python
39 lines
988 B
Python
import numpy
|
|
import pytest
|
|
|
|
from thinc import registry
|
|
from thinc.api import (
|
|
NumpyOps,
|
|
glorot_uniform_init,
|
|
normal_init,
|
|
uniform_init,
|
|
zero_init,
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"init_func", [glorot_uniform_init, zero_init, uniform_init, normal_init]
|
|
)
|
|
def test_initializer_func_setup(init_func):
|
|
ops = NumpyOps()
|
|
data = numpy.ndarray([1, 2, 3, 4], dtype="f")
|
|
result = init_func(ops, data.shape)
|
|
assert not numpy.array_equal(data, result)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"name,kwargs",
|
|
[
|
|
("glorot_uniform_init.v1", {}),
|
|
("zero_init.v1", {}),
|
|
("uniform_init.v1", {"lo": -0.5, "hi": 0.5}),
|
|
("normal_init.v1", {"mean": 0.1}),
|
|
],
|
|
)
|
|
def test_initializer_from_config(name, kwargs):
|
|
"""Test that initializers are loaded and configured correctly from registry
|
|
(as partials)."""
|
|
cfg = {"test": {"@initializers": name, **kwargs}}
|
|
func = registry.resolve(cfg)["test"]
|
|
func(NumpyOps(), (1, 2, 3, 4))
|