ai-content-maker/.venv/Lib/site-packages/thinc/tests/layers/test_mxnet_wrapper.py

217 lines
6.1 KiB
Python

from typing import cast
import numpy
import pytest
from thinc.api import (
Adam,
ArgsKwargs,
Model,
MXNetWrapper,
Ops,
get_current_ops,
mxnet2xp,
xp2mxnet,
)
from thinc.compat import has_cupy_gpu, has_mxnet
from thinc.types import Array1d, Array2d, IntsXd
from thinc.util import to_categorical
from ..util import check_input_converters, make_tempdir
@pytest.fixture
def n_hidden() -> int:
    """Unit count handed to the first two Dense layers of the ``mx_model`` fixture."""
    hidden_units = 12
    return hidden_units
@pytest.fixture
def input_size() -> int:
    """Number of columns in the input batch allocated by the ``X`` fixture."""
    n_columns = 784
    return n_columns
@pytest.fixture
def n_classes() -> int:
    """Class count passed to ``to_categorical`` when the ``Y`` fixture is built."""
    class_count = 10
    return class_count
@pytest.fixture
def answer() -> int:
    """Class index encoded by the ``Y`` fixture and expected by the overfit test."""
    target_class = 1
    return target_class
@pytest.fixture
def X(input_size: int) -> Array2d:
    """Allocate a ``(1, input_size)`` input batch with the currently active ops."""
    batch_shape = (1, input_size)
    current_ops: Ops = get_current_ops()
    allocated = current_ops.alloc(shape=batch_shape)
    return cast(Array2d, allocated)
@pytest.fixture
def Y(answer: int, n_classes: int) -> Array2d:
    """Build the training target by running ``to_categorical`` on ``[answer]``."""
    current_ops: Ops = get_current_ops()
    # Wrap the single label in an array owned by the current ops before encoding.
    labels = cast(IntsXd, current_ops.asarray([answer]))
    encoded = to_categorical(labels, n_classes=n_classes)
    return cast(Array2d, encoded)
@pytest.fixture
def mx_model(n_hidden: int, input_size: int, X: Array2d, n_classes: int):
    """Build and initialize a small Gluon ``Sequential`` network for the tests.

    The stack is Dense -> LayerNorm -> Dense(relu) -> LayerNorm -> Dense(softrelu).
    The output layer width comes from the ``n_classes`` fixture rather than a
    hard-coded literal, so it stays in step with the target built by ``Y``.

    ``input_size`` and ``X`` are not referenced in the body; they are kept so
    the fixture still requests the same fixtures as before.
    """
    # Imported lazily so the module can be collected when MXNet is missing.
    import mxnet as mx

    nn = mx.gluon.nn
    net = nn.Sequential()
    net.add(
        nn.Dense(n_hidden),
        nn.LayerNorm(),
        nn.Dense(n_hidden, activation="relu"),
        nn.LayerNorm(),
        nn.Dense(n_classes, activation="softrelu"),
    )
    net.initialize()
    return net
@pytest.fixture
def model(mx_model) -> Model[Array2d, Array2d]:
    """Wrap the ``mx_model`` fixture in an ``MXNetWrapper``."""
    wrapped = MXNetWrapper(mx_model)
    return wrapped
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_roundtrip_conversion():
    """An array sent through ``xp2mxnet`` and back via ``mxnet2xp`` is unchanged."""
    import mxnet as mx

    original = numpy.zeros((2, 3), dtype="f")
    as_mxnet = xp2mxnet(original)
    assert isinstance(as_mxnet, mx.nd.NDArray)
    restored = mxnet2xp(as_mxnet)
    assert numpy.array_equal(original, restored)
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_gluon_sequential():
    """Wrapping a Gluon ``Sequential`` network yields a thinc ``Model``."""
    import mxnet as mx

    network = mx.gluon.nn.Sequential()
    network.add(mx.gluon.nn.Dense(12))
    assert isinstance(MXNetWrapper(network), Model)
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_built_model(
    model: Model[Array2d, Array2d], X: Array2d, Y: Array2d
):
    """A wrapped, initialized model can predict and round-trip through bytes.

    ``Y`` is requested but not used in the body; its annotation matches the
    ``Array2d`` return type declared by the ``Y`` fixture.
    """
    # Built models are validated more and can perform useful operations.
    assert model.predict(X) is not None
    # They can also be serialized and deserialized.
    assert model.from_bytes(model.to_bytes()) is not None
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_predict(model: Model[Array2d, Array2d], X: Array2d):
    """Smoke test: calling ``predict`` on the input fixture must not raise."""
    _ = model.predict(X)
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_train_overfits(
    model: Model[Array2d, Array2d], X: Array2d, Y: Array2d, answer: int
):
    """After 100 Adam updates on one example, the model predicts ``answer``.

    ``Y`` is annotated ``Array2d`` to match the return type declared by the
    ``Y`` fixture.
    """
    optimizer = Adam()
    for _ in range(100):
        guesses, backprop = model(X, is_train=True)
        # Difference between guesses and target, divided by the number of rows.
        d_guesses = (guesses - Y) / guesses.shape[0]
        backprop(d_guesses)
        model.finish_update(optimizer)
    predicted = model.predict(X).argmax()
    assert predicted == answer
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_can_copy_model(model: Model[Array2d, Array2d], X: Array2d):
    """A model that has already run ``predict`` can be copied."""
    model.predict(X)
    duplicate: Model[Array2d, Array2d] = model.copy()
    assert duplicate is not None
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_to_bytes(model: Model[Array2d, Array2d], X: Array2d):
    """``to_bytes`` returns a payload that ``from_bytes`` accepts without raising."""
    model.predict(X)
    serialized = model.to_bytes()
    assert serialized is not None
    # Feed the payload straight back into the same model.
    model.from_bytes(serialized)
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_to_from_disk(model: Model[Array2d, Array2d], X: Array2d):
    """The model can be written to a temporary file and read back from it."""
    model.predict(X)
    with make_tempdir() as tmp_dir:
        target = tmp_dir / "model.bytes"
        model.to_disk(target)
        reloaded = model.from_disk(target)
    assert reloaded is not None
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_from_bytes(model: Model[Array2d, Array2d], X: Array2d):
    """``from_bytes`` returns a non-``None`` model when given ``to_bytes`` output."""
    model.predict(X)
    restored = model.from_bytes(model.to_bytes())
    assert restored is not None
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_to_cpu(mx_model, X: Array2d):
    """Smoke test: ``to_cpu`` can be called on a wrapper after ``predict``."""
    wrapped = MXNetWrapper(mx_model)
    wrapped.predict(X)
    wrapped.to_cpu()
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
@pytest.mark.skipif(not has_cupy_gpu, reason="needs GPU/cupy")
def test_mxnet_wrapper_to_gpu(model: Model[Array2d, Array2d], X: Array2d):
    """Smoke test: ``to_gpu(0)`` can be called on the model after ``predict``."""
    model.predict(X)
    gpu_id = 0
    model.to_gpu(gpu_id)
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
@pytest.mark.parametrize(
    "data,n_args,kwargs_keys",
    [
        # fmt: off
        (numpy.zeros((2, 3), dtype="f"), 1, []),
        ([numpy.zeros((2, 3), dtype="f"), numpy.zeros((2, 3), dtype="f")], 2, []),
        ((numpy.zeros((2, 3), dtype="f"), numpy.zeros((2, 3), dtype="f")), 2, []),
        ({"a": numpy.zeros((2, 3), dtype="f"), "b": numpy.zeros((2, 3), dtype="f")}, 0, ["a", "b"]),
        (ArgsKwargs((numpy.zeros((2, 3), dtype="f"), numpy.zeros((2, 3), dtype="f")), {"c": numpy.zeros((2, 3), dtype="f")}), 2, ["c"]),
        # fmt: on
    ],
)
def test_mxnet_wrapper_convert_inputs(data, n_args, kwargs_keys):
    """Run the wrapper's ``convert_inputs`` hook over each parametrized input.

    Each case supplies the raw ``data`` plus the positional-argument count and
    keyword names that ``check_input_converters`` is given to verify against.
    """
    import mxnet as mx

    network = mx.gluon.nn.Sequential()
    network.add(mx.gluon.nn.Dense(12))
    network.initialize()
    wrapped = MXNetWrapper(network)
    # The converter is stored on the model's attrs under "convert_inputs".
    converter = wrapped.attrs["convert_inputs"]
    converted, convert_back = converter(wrapped, data, is_train=True)
    check_input_converters(
        converted, convert_back, data, n_args, kwargs_keys, mx.nd.NDArray
    )
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_thinc_model_subclass(mx_model):
    """``model_class`` makes ``MXNetWrapper`` return an instance of that subclass."""

    class CustomModel(Model):
        def fn(self) -> int:
            return 1337

    wrapped = MXNetWrapper(mx_model, model_class=CustomModel)
    assert isinstance(wrapped, CustomModel)
    # The subclass's own method must be reachable on the wrapped model.
    assert wrapped.fn() == 1337
@pytest.mark.skipif(not has_mxnet, reason="needs MXNet")
def test_mxnet_wrapper_thinc_set_model_name(mx_model):
    """``model_name`` passed to ``MXNetWrapper`` becomes the model's ``name``."""
    wrapped = MXNetWrapper(mx_model, model_name="cool")
    assert wrapped.name == "cool"