122 lines
3.5 KiB
Python
122 lines
3.5 KiB
Python
import contextlib
|
|
import shutil
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import numpy
|
|
import pytest
|
|
|
|
from thinc.api import ArgsKwargs, Linear, Padded, Ragged
|
|
from thinc.util import has_cupy, is_cupy_array, is_numpy_array
|
|
|
|
|
|
@contextlib.contextmanager
def make_tempdir():
    """Yield a freshly created temporary directory as a ``Path``.

    The directory (and everything inside it) is removed when the ``with``
    block exits, whether it exits normally or by raising.
    """
    d = Path(tempfile.mkdtemp())
    try:
        yield d
    finally:
        # Clean up in ``finally`` so a failing test body does not leak the
        # directory on disk.
        shutil.rmtree(str(d))
|
|
|
|
|
|
def get_model(W_b_input, cls=Linear):
    """Build and initialize a ``cls`` layer from a ``(W, b, input)`` triple.

    The layer is constructed as ``cls(nr_out, nr_in)`` with both sizes taken
    from ``W.shape``; its "W" and "b" params are then set from the triple
    before ``initialize()`` is called. The input element is not used.
    """
    weights, bias, _ = W_b_input
    n_outputs, n_inputs = weights.shape
    layer = cls(n_outputs, n_inputs)
    for param_name, param_value in (("W", weights), ("b", bias)):
        layer.set_param(param_name, param_value)
    layer.initialize()
    return layer
|
|
|
|
|
|
def get_shape(W_b_input):
    """Return ``(input.shape[0], W.shape[0], W.shape[1])`` for a ``(W, b, input)`` triple."""
    weights, _bias, batch = W_b_input
    n_rows = batch.shape[0]
    n_out = weights.shape[0]
    n_in = weights.shape[1]
    return n_rows, n_out, n_in
|
|
|
|
|
|
def get_data_checker(inputs):
    """Pick the ``assert_*_match`` helper appropriate for ``inputs``.

    Checks are tried in this order: ``Ragged``, ``Padded``, ``list``,
    4-tuple, 2-tuple. Anything else falls back to ``assert_arrays_match``.
    """
    by_type = (
        (Ragged, assert_raggeds_match),
        (Padded, assert_paddeds_match),
        (list, assert_lists_match),
    )
    for kind, checker in by_type:
        if isinstance(inputs, kind):
            return checker
    if isinstance(inputs, tuple):
        size = len(inputs)
        if size == 4:
            return assert_padded_data_match
        if size == 2:
            return assert_ragged_data_match
    return assert_arrays_match
|
|
|
|
|
|
def assert_arrays_match(X, Y):
    """Assert that ``X`` and ``Y`` share a dtype and a first-axis length.

    Returns ``True`` when both assertions hold.
    """
    dtype_x, dtype_y = X.dtype, Y.dtype
    assert dtype_x == dtype_y
    # Only the leading (batch) dimension is compared: a transformation may
    # legitimately change the last dimension.
    batch_x, batch_y = X.shape[0], Y.shape[0]
    assert batch_x == batch_y
    return True
|
|
|
|
|
|
def assert_lists_match(X, Y):
    """Assert that ``X`` and ``Y`` are equal-length lists of matching arrays.

    Each aligned pair of elements is checked with ``assert_arrays_match``.
    Returns ``True`` when every check passes.
    """
    for candidate in (X, Y):
        assert isinstance(candidate, list)
    assert len(X) == len(Y)
    for pair in zip(X, Y):
        assert_arrays_match(*pair)
    return True
|
|
|
|
|
|
def assert_raggeds_match(X, Y):
    """Assert that two ``Ragged`` objects have matching ``lengths`` and ``data``.

    Both attributes are compared with ``assert_arrays_match``. Returns
    ``True`` when every check passes.
    """
    for candidate in (X, Y):
        assert isinstance(candidate, Ragged)
    for field in ("lengths", "data"):
        assert_arrays_match(getattr(X, field), getattr(Y, field))
    return True
|
|
|
|
|
|
def assert_paddeds_match(X, Y):
    """Assert that two ``Padded`` objects are compatible.

    ``size_at_t``, ``lengths`` and ``indices`` are compared with
    ``assert_arrays_match``; ``data`` must agree on dtype and on its first
    two dimensions. Returns ``True`` when every check passes.
    """
    for candidate in (X, Y):
        assert isinstance(candidate, Padded)
    assert_arrays_match(X.size_at_t, Y.size_at_t)
    for field in ("lengths", "indices"):
        assert assert_arrays_match(getattr(X, field), getattr(Y, field))
    data_x, data_y = X.data, Y.data
    assert data_x.dtype == data_y.dtype
    # Axis 1 is compared before axis 0.
    for axis in (1, 0):
        assert data_x.shape[axis] == data_y.shape[axis]
    return True
|
|
|
|
|
|
def assert_padded_data_match(X, Y):
    """Wrap two argument tuples in ``Padded`` and compare them with ``assert_paddeds_match``."""
    padded_x = Padded(*X)
    padded_y = Padded(*Y)
    return assert_paddeds_match(padded_x, padded_y)
|
|
|
|
|
|
def assert_ragged_data_match(X, Y):
    """Wrap two argument tuples in ``Ragged`` and compare them with ``assert_raggeds_match``."""
    ragged_x = Ragged(*X)
    ragged_y = Ragged(*Y)
    return assert_raggeds_match(ragged_x, ragged_y)
|
|
|
|
|
|
def check_input_converters(Y, backprop, data, n_args, kwargs_keys, type_):
    """Check a converted output ``Y`` and the result of ``backprop(Y)``.

    ``Y`` must be an ``ArgsKwargs`` with ``n_args`` positional values and
    keyword keys equal to ``kwargs_keys``, all of whose values are instances
    of ``type_``. ``dX = backprop(Y)`` must then either be a numpy/cupy
    array or have the same container type as ``data`` (a tuple when ``data``
    is a list), with numpy/cupy arrays as its elements. If ``data`` is none
    of dict, list, tuple, ``ArgsKwargs`` or ``numpy.ndarray``, the test is
    failed via ``pytest.fail``.
    """

    def is_supported_backend_array(arr):
        return is_cupy_array(arr) or is_numpy_array(arr)

    # --- forward output -------------------------------------------------
    assert isinstance(Y, ArgsKwargs)
    assert len(Y.args) == n_args
    assert list(Y.kwargs.keys()) == kwargs_keys
    for converted in (Y.args, Y.kwargs.values()):
        assert all(isinstance(arg, type_) for arg in converted)

    # --- backward output ------------------------------------------------
    dX = backprop(Y)

    # List inputs are expected to come back as tuples.
    if isinstance(data, list):
        input_type = tuple
    else:
        input_type = type(data)
    assert isinstance(dX, input_type) or is_supported_backend_array(dX)

    if isinstance(data, dict):
        assert list(dX.keys()) == kwargs_keys
        assert all(map(is_supported_backend_array, dX.values()))
        return

    if isinstance(data, (list, tuple)):
        assert isinstance(dX, tuple)
        assert all(map(is_supported_backend_array, dX))
        return

    if isinstance(data, ArgsKwargs):
        assert len(dX.args) == n_args
        assert list(dX.kwargs.keys()) == kwargs_keys
        assert all(map(is_supported_backend_array, dX.args))
        assert all(map(is_supported_backend_array, dX.kwargs.values()))
        return

    if not isinstance(data, numpy.ndarray):
        pytest.fail(f"Bad data type: {dX}")
|