# (removed scraped page metadata: "672 lines / 21 KiB / Python")
import contextlib
|
|
import functools
|
|
import inspect
|
|
import os
|
|
import platform
|
|
import random
|
|
import tempfile
|
|
import threading
|
|
from contextvars import ContextVar
|
|
from dataclasses import dataclass
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
TypeVar,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
import numpy
|
|
from packaging.version import Version
|
|
|
|
try:
|
|
from pydantic.v1 import ValidationError, create_model
|
|
except ImportError:
|
|
from pydantic import ValidationError, create_model # type: ignore
|
|
|
|
from wasabi import table
|
|
|
|
from .compat import (
|
|
cupy,
|
|
cupy_from_dlpack,
|
|
has_cupy,
|
|
has_cupy_gpu,
|
|
has_gpu,
|
|
has_mxnet,
|
|
has_tensorflow,
|
|
has_torch,
|
|
has_torch_cuda_gpu,
|
|
has_torch_mps,
|
|
)
|
|
from .compat import mxnet as mx
|
|
from .compat import tensorflow as tf
|
|
from .compat import torch
|
|
|
|
# Context-local flag controlling whether runtime data validation is enabled;
# toggled by the data_validation() context manager and read during
# Model.initialize (see validate_fwd_input_output).
DATA_VALIDATION: ContextVar[bool] = ContextVar("DATA_VALIDATION", default=False)
from typing import TYPE_CHECKING
|
|
|
|
from . import types # noqa: E402
|
|
from .types import ArgsKwargs, ArrayXd, FloatsXd, IntsXd, Padded, Ragged # noqa: E402
|
|
|
|
if TYPE_CHECKING:
|
|
from .api import Ops
|
|
|
|
|
|
def get_torch_default_device() -> "torch.device":
    """Return the Torch device matching Thinc's currently active ops backend.

    CupyOps maps to the current CUDA device, MPSOps to "mps", anything else
    to "cpu". Raises ValueError when Torch is not installed.
    """
    if torch is None:
        raise ValueError("Cannot get default Torch device when Torch is not available.")

    from .backends import get_current_ops
    from .backends.cupy_ops import CupyOps
    from .backends.mps_ops import MPSOps

    current_ops = get_current_ops()
    if isinstance(current_ops, CupyOps):
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    if isinstance(current_ops, MPSOps):
        return torch.device("mps")
    return torch.device("cpu")
def get_array_module(arr):  # pragma: no cover
    """Return the array namespace (numpy or cupy) that *arr* belongs to."""
    if is_numpy_array(arr):
        return numpy
    if is_cupy_array(arr):
        return cupy
    raise ValueError(
        "Only numpy and cupy arrays are supported"
        f", but found {type(arr)} instead. If "
        "get_array_module module wasn't called "
        "directly, this might indicate a bug in Thinc."
    )
def gpu_is_available():
    # Thin alias over the compat flag set at import time; True when either a
    # CUDA (CuPy) or an MPS (PyTorch on Apple silicon) device was detected.
    return has_gpu
def fix_random_seed(seed: int = 0) -> None:  # pragma: no cover
    """Set the random seed across random, numpy.random and cupy.random."""
    random.seed(seed)
    numpy.random.seed(seed)
    if has_torch:
        torch.manual_seed(seed)
    if has_cupy_gpu:
        cupy.random.seed(seed)
    if has_torch and has_torch_cuda_gpu:
        torch.cuda.manual_seed_all(seed)
        # Also pin cuDNN to deterministic kernels; benchmark mode would pick
        # algorithms nondeterministically, trading reproducibility for speed.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def is_xp_array(obj: Any) -> bool:
    """Check whether an object is a numpy or cupy array."""
    # "xp" follows the numpy/cupy convention for a backend-agnostic array.
    return is_numpy_array(obj) or is_cupy_array(obj)
def is_cupy_array(obj: Any) -> bool:  # pragma: no cover
    """Check whether an object is a cupy array."""
    # Guard first: when CuPy isn't installed, `cupy` may not be importable.
    if not has_cupy:
        return False
    return isinstance(obj, cupy.ndarray)
def is_numpy_array(obj: Any) -> bool:
    """Check whether an object is a numpy array."""
    # isinstance already yields a bool; the original if/else ladder was
    # redundant.
    return isinstance(obj, numpy.ndarray)
def is_torch_array(obj: Any) -> bool:  # pragma: no cover
    """Check whether an object is a PyTorch tensor."""
    # `torch` is None when PyTorch isn't installed (see .compat).
    if torch is None:
        return False
    return isinstance(obj, torch.Tensor)
def is_torch_cuda_array(obj: Any) -> bool:  # pragma: no cover
    """Check whether an object is a PyTorch tensor stored on a CUDA device."""
    if not is_torch_array(obj):
        return False
    return obj.is_cuda
def is_torch_gpu_array(obj: Any) -> bool:  # pragma: no cover
    """Check whether an object is a PyTorch tensor on a GPU (CUDA or MPS)."""
    if is_torch_cuda_array(obj):
        return True
    return is_torch_mps_array(obj)
def is_torch_mps_array(obj: Any) -> bool:  # pragma: no cover
    """Check whether an object is a PyTorch tensor on an Apple MPS device."""
    if not is_torch_array(obj):
        return False
    # `is_mps` does not exist on all torch versions, hence the hasattr guard.
    return hasattr(obj, "is_mps") and obj.is_mps
def is_tensorflow_array(obj: Any) -> bool:  # pragma: no cover
    """Check whether an object is a TensorFlow tensor."""
    if not has_tensorflow:
        return False
    return isinstance(obj, tf.Tensor)  # type: ignore
def is_tensorflow_gpu_array(obj: Any) -> bool:  # pragma: no cover
    """Check whether an object is a TensorFlow tensor placed on a GPU."""
    if not is_tensorflow_array(obj):
        return False
    # TF device strings look like "/job:.../device:GPU:0".
    return "GPU:" in obj.device
def is_mxnet_array(obj: Any) -> bool:  # pragma: no cover
    """Check whether an object is an MXNet array."""
    if not has_mxnet:
        return False
    return isinstance(obj, mx.nd.NDArray)  # type: ignore
def is_mxnet_gpu_array(obj: Any) -> bool:  # pragma: no cover
    """Check whether an object is an MXNet array on a non-CPU device."""
    if not is_mxnet_array(obj):
        return False
    return obj.context.device_type != "cpu"
def to_numpy(data):  # pragma: no cover
    """Coerce *data* to a numpy array, copying back from the GPU if needed."""
    if isinstance(data, numpy.ndarray):
        # Already numpy: return the same object, no copy.
        return data
    if has_cupy and isinstance(data, cupy.ndarray):
        # Device -> host copy.
        return data.get()
    return numpy.array(data)
def set_active_gpu(gpu_id: int) -> "cupy.cuda.Device":  # pragma: no cover
    """Set the current GPU device for cupy and torch (if available).

    gpu_id: the CUDA device index to activate.
    Returns the activated cupy Device. Raises ValueError when no CUDA GPU
    is available.
    """
    if not has_cupy_gpu:
        raise ValueError("No CUDA GPU devices detected")

    device = cupy.cuda.device.Device(gpu_id)
    device.use()

    # Keep PyTorch pointed at the same device so tensors interoperate.
    if has_torch_cuda_gpu:
        torch.cuda.set_device(gpu_id)

    return device
def require_cpu() -> bool:  # pragma: no cover
    """Use CPU through best available backend."""
    from .backends import get_ops, set_current_ops

    set_current_ops(get_ops("cpu"))
    return True
def prefer_gpu(gpu_id: int = 0) -> bool:  # pragma: no cover
    """Use GPU if it's available. Returns True if so, False otherwise."""
    # Unlike require_gpu(), this never raises when no GPU is present.
    if has_gpu:
        require_gpu(gpu_id=gpu_id)
    return has_gpu
def require_gpu(gpu_id: int = 0) -> bool:  # pragma: no cover
    """Switch the current ops to a GPU backend, raising if none is usable.

    On macOS the GPU backend is MPS (via PyTorch); elsewhere it is CUDA
    (via CuPy).
    """
    from .backends import CupyOps, MPSOps, set_current_ops

    is_mac = platform.system() == "Darwin"
    if is_mac and not has_torch_mps:
        if has_torch:
            raise ValueError("Cannot use GPU, installed PyTorch does not support MPS")
        raise ValueError("Cannot use GPU, PyTorch is not installed")
    if not is_mac and not has_cupy:
        raise ValueError("Cannot use GPU, CuPy is not installed")
    if not has_gpu:
        raise ValueError("No GPU devices detected")

    if has_cupy_gpu:
        set_current_ops(CupyOps())
        set_active_gpu(gpu_id)
    else:
        set_current_ops(MPSOps())

    return True
def copy_array(dst: ArrayXd, src: ArrayXd) -> None:  # pragma: no cover
    """Copy the contents of *src* into *dst* in place (numpy or cupy)."""
    if isinstance(dst, numpy.ndarray) and isinstance(src, numpy.ndarray):
        dst[:] = src
    elif is_cupy_array(dst):
        # Make sure src is on the device before the device-side copy;
        # copy=False avoids duplicating data already on the device.
        cupy.copyto(dst, cupy.array(src, copy=False))
    else:
        numpy.copyto(dst, src)  # type: ignore
def to_categorical(
    Y: IntsXd,
    n_classes: Optional[int] = None,
    *,
    label_smoothing: float = 0.0,
) -> FloatsXd:
    """One-hot encode integer class labels, optionally with label smoothing.

    Y: integer class labels.
    n_classes: number of classes; inferred as max(Y) + 1 when omitted.
    label_smoothing: probability mass taken from the gold label and spread
        uniformly over the other classes. Must satisfy
        0 <= label_smoothing < (n_classes - 1) / n_classes.

    Returns a float32 array with one probability distribution per label.
    """
    if n_classes is None:
        n_classes = int(numpy.max(Y) + 1)  # type: ignore

    if label_smoothing < 0.0:
        raise ValueError(
            "Label-smoothing parameter has to be greater than or equal to 0"
        )

    if label_smoothing == 0.0:
        if n_classes == 0:
            raise ValueError("n_classes should be at least 1")
        nongold_prob = 0.0
    else:
        if not n_classes > 1:
            raise ValueError(
                # Bug fix: the original message was missing the space between
                # "enabled," and "but", rendering as "enabled,but 1 was ...".
                "n_classes should be greater than 1 when label smoothing is enabled, "
                f"but {n_classes} was provided."
            )
        nongold_prob = label_smoothing / (n_classes - 1)

    # The largest valid smoothing still leaves the gold label the most likely
    # class: 1 - s > s / (n_classes - 1)  <=>  s < (n_classes - 1) / n_classes.
    max_smooth = (n_classes - 1) / n_classes
    if n_classes > 1 and label_smoothing >= max_smooth:
        raise ValueError(
            f"For {n_classes} classes "
            "label_smoothing parameter has to be less than "
            f"{max_smooth}, but found {label_smoothing}."
        )

    xp = get_array_module(Y)
    # Row i of label_distr is the smoothed distribution for class i; indexing
    # with Y broadcasts it over the full label array.
    label_distr = xp.full((n_classes, n_classes), nongold_prob, dtype="float32")
    xp.fill_diagonal(label_distr, 1 - label_smoothing)
    return label_distr[Y]
def get_width(
    X: Union[ArrayXd, Ragged, Padded, Sequence[ArrayXd]], *, dim: int = -1
) -> int:
    """Infer the 'width' of a batch of data, which could be any of: Array,
    Ragged, Padded or Sequence of Arrays.
    """
    # Both container types carry their payload on .data.
    if isinstance(X, (Ragged, Padded)):
        return get_width(X.data, dim=dim)
    if hasattr(X, "shape") and hasattr(X, "ndim"):
        arr = cast(ArrayXd, X)
        rank = len(arr.shape)
        if rank == 0:
            return 0
        if rank == 1:
            # 1d arrays report max value + 1 (e.g. a class-label vector).
            return int(arr.max()) + 1
        return arr.shape[dim]
    if isinstance(X, (list, tuple)):
        # Empty batch has no width; otherwise use the first element.
        return get_width(X[0], dim=dim) if X else 0
    raise ValueError("Cannot get width of object: has neither shape nor __getitem__")
def assert_tensorflow_installed() -> None:  # pragma: no cover
    """Raise an ImportError if TensorFlow is not installed."""
    if has_tensorflow:
        return
    template = "TensorFlow support requires {pkg}: pip install thinc[tensorflow]\n\nEnable TensorFlow support with thinc.api.enable_tensorflow()"
    raise ImportError(template.format(pkg="tensorflow>=2.0.0,<2.6.0"))
def assert_mxnet_installed() -> None:  # pragma: no cover
    """Raise an ImportError if MXNet is not installed."""
    if has_mxnet:
        return
    raise ImportError(
        "MXNet support requires mxnet: pip install thinc[mxnet]\n\nEnable MXNet support with thinc.api.enable_mxnet()"
    )
def assert_pytorch_installed() -> None:  # pragma: no cover
    """Raise an ImportError if PyTorch is not installed."""
    if has_torch:
        return
    raise ImportError("PyTorch support requires torch: pip install thinc[torch]")
def convert_recursive(
    is_match: Callable[[Any], bool], convert_item: Callable[[Any], Any], obj: Any
) -> Any:
    """Either convert a single value if it matches a given function, or
    recursively walk over potentially nested lists, tuples and dicts applying
    the conversion, and returns the same type. Also supports the ArgsKwargs
    dataclass.
    """

    def recurse(value: Any) -> Any:
        # Shorthand so every branch reads as a plain map over its children.
        return convert_recursive(is_match, convert_item, value)

    if is_match(obj):
        return convert_item(obj)
    if isinstance(obj, ArgsKwargs):
        return ArgsKwargs.from_items(recurse(list(obj.items())))
    if isinstance(obj, dict):
        out = {}
        for key, value in obj.items():
            # Convert the key before the value, matching the original order
            # in case the callbacks have side effects.
            new_key = recurse(key)
            new_value = recurse(value)
            out[new_key] = new_value
        return out
    if isinstance(obj, list):
        return [recurse(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(recurse(item) for item in obj)
    return obj
def iterate_recursive(is_match: Callable[[Any], bool], obj: Any) -> Any:
    """Either yield a single value if it matches a given function, or recursively
    walk over potentially nested lists, tuples and dicts yielding matching
    values. Also supports the ArgsKwargs dataclass.
    """
    if is_match(obj):
        yield obj
        return
    if isinstance(obj, ArgsKwargs):
        yield from iterate_recursive(is_match, list(obj.items()))
    elif isinstance(obj, dict):
        for key, value in obj.items():
            # Keys may match too, so walk both sides of every entry.
            yield from iterate_recursive(is_match, key)
            yield from iterate_recursive(is_match, value)
    elif isinstance(obj, (list, tuple)):
        for element in obj:
            yield from iterate_recursive(is_match, element)
def xp2torch(
    xp_tensor: ArrayXd,
    requires_grad: bool = False,
    device: Optional["torch.device"] = None,
) -> "torch.Tensor":  # pragma: no cover
    """Convert a numpy or cupy tensor to a PyTorch tensor.

    xp_tensor: the source array.
    requires_grad: whether to enable gradient tracking on the result.
    device: target device; defaults to the device matching the current ops.
    """
    assert_pytorch_installed()
    if device is None:
        device = get_torch_default_device()

    # Prefer DLPack-based exchange where the source array supports it.
    if hasattr(xp_tensor, "toDlpack"):
        # Older CuPy API.
        torch_tensor = torch.utils.dlpack.from_dlpack(xp_tensor.toDlpack())  # type: ignore
    elif hasattr(xp_tensor, "__dlpack__"):
        # Standard __dlpack__ protocol.
        torch_tensor = torch.utils.dlpack.from_dlpack(xp_tensor)
    else:
        torch_tensor = torch.from_numpy(xp_tensor)

    torch_tensor = torch_tensor.to(device)

    if requires_grad:
        torch_tensor.requires_grad_()

    return torch_tensor
def torch2xp(
    torch_tensor: "torch.Tensor", *, ops: Optional["Ops"] = None
) -> ArrayXd:  # pragma: no cover
    """Convert a torch tensor to a numpy or cupy tensor depending on the `ops` parameter.
    If `ops` is `None`, the type of the resultant tensor will be determined by the source tensor's device.
    """
    from .api import NumpyOps

    assert_pytorch_installed()
    want_numpy = isinstance(ops, NumpyOps)
    if is_torch_cuda_array(torch_tensor):
        if want_numpy:
            return torch_tensor.detach().cpu().numpy()
        # CUDA tensor -> cupy via DLPack, staying on the device.
        return cupy_from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))
    if want_numpy or ops is None:
        return torch_tensor.detach().cpu().numpy()
    return cupy.asarray(torch_tensor)
def xp2tensorflow(
    xp_tensor: ArrayXd, requires_grad: bool = False, as_variable: bool = False
) -> "tf.Tensor":  # type: ignore # pragma: no cover
    """Convert a numpy or cupy tensor to a TensorFlow Tensor or Variable

    xp_tensor: the source array.
    requires_grad: whether the result should participate in gradients.
    as_variable: wrap the result in a tf.Variable instead of a plain Tensor.
    """
    assert_tensorflow_installed()
    # Prefer DLPack-based exchange when the source supports it; otherwise
    # fall back to a plain conversion.
    if hasattr(xp_tensor, "toDlpack"):
        # Older CuPy API.
        dlpack_tensor = xp_tensor.toDlpack()  # type: ignore
        tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_tensor)  # type: ignore
    elif hasattr(xp_tensor, "__dlpack__"):
        # Standard __dlpack__ protocol.
        dlpack_tensor = xp_tensor.__dlpack__()  # type: ignore
        tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_tensor)  # type: ignore
    else:
        tf_tensor = tf.convert_to_tensor(xp_tensor)  # type: ignore
    if as_variable:
        # tf.Variable() automatically puts in GPU if available.
        # So we need to control it using the context manager
        with tf.device(tf_tensor.device):  # type: ignore
            tf_tensor = tf.Variable(tf_tensor, trainable=requires_grad)  # type: ignore
    if requires_grad is False and as_variable is False:
        # tf.stop_gradient() automatically puts in GPU if available.
        # So we need to control it using the context manager
        with tf.device(tf_tensor.device):  # type: ignore
            tf_tensor = tf.stop_gradient(tf_tensor)  # type: ignore
    return tf_tensor
def tensorflow2xp(
    tf_tensor: "tf.Tensor", *, ops: Optional["Ops"] = None  # type: ignore
) -> ArrayXd:  # pragma: no cover
    """Convert a Tensorflow tensor to numpy or cupy tensor depending on the `ops` parameter.
    If `ops` is `None`, the type of the resultant tensor will be determined by the source tensor's device.
    """
    from .api import NumpyOps

    assert_tensorflow_installed()
    want_numpy = isinstance(ops, NumpyOps)
    if is_tensorflow_gpu_array(tf_tensor):
        if want_numpy:
            return tf_tensor.numpy()
        # GPU tensor -> cupy via DLPack, staying on the device.
        return cupy_from_dlpack(tf.experimental.dlpack.to_dlpack(tf_tensor))  # type: ignore
    if want_numpy or ops is None:
        return tf_tensor.numpy()
    return cupy.asarray(tf_tensor.numpy())
def xp2mxnet(
    xp_tensor: ArrayXd, requires_grad: bool = False
) -> "mx.nd.NDArray":  # type: ignore # pragma: no cover
    """Convert a numpy or cupy tensor to a MXNet tensor."""
    assert_mxnet_installed()
    if hasattr(xp_tensor, "toDlpack"):
        # CuPy array: hand over via DLPack.
        mx_tensor = mx.nd.from_dlpack(xp_tensor.toDlpack())  # type: ignore
    else:
        mx_tensor = mx.nd.from_numpy(xp_tensor)  # type: ignore
    if requires_grad:
        mx_tensor.attach_grad()
    return mx_tensor
def mxnet2xp(
    mx_tensor: "mx.nd.NDArray", *, ops: Optional["Ops"] = None  # type: ignore
) -> ArrayXd:  # pragma: no cover
    """Convert a MXNet tensor to a numpy or cupy tensor."""
    from .api import NumpyOps

    assert_mxnet_installed()
    want_numpy = isinstance(ops, NumpyOps)
    if is_mxnet_gpu_array(mx_tensor):
        if want_numpy:
            return mx_tensor.detach().asnumpy()
        # GPU array -> cupy via DLPack, staying on the device.
        return cupy_from_dlpack(mx_tensor.to_dlpack_for_write())
    if want_numpy or ops is None:
        return mx_tensor.detach().asnumpy()
    return cupy.asarray(mx_tensor.asnumpy())
# This is how functools.partials seems to do it, too, to retain the return type
# TypeVar for the wrapped callable's return type, used by partial() below.
PartialT = TypeVar("PartialT")
def partial(
    func: Callable[..., PartialT], *args: Any, **kwargs: Any
) -> Callable[..., PartialT]:
    """Wrapper around functools.partial that retains docstrings and can include
    other workarounds if needed.
    """
    wrapped = functools.partial(func, *args, **kwargs)
    # functools.partial objects carry a generic __doc__; copy the original's.
    wrapped.__doc__ = func.__doc__
    return wrapped
class DataValidationError(ValueError):
    def __init__(
        self,
        name: str,
        X: Any,
        Y: Any,
        # Bug fix: the default was a mutable list literal ([]); use an
        # immutable empty tuple instead (it is only iterated, never mutated,
        # so this is fully backward-compatible).
        errors: Union[Sequence[Mapping[str, Any]], List[Dict[str, Any]]] = (),
    ) -> None:
        """Custom error for validating inputs / outputs at runtime.

        name: name of the layer/model the data failed validation for.
        X / Y: the offending input and output samples (only types reported).
        errors: pydantic-style error dicts with "loc" and "msg" keys.
        """
        message = f"Data validation error in '{name}'"
        type_info = f"X: {type(X)} Y: {type(Y)}"
        # One table row per validation error: "location path" -> message.
        data = []
        for error in errors:
            err_loc = " -> ".join([str(p) for p in error.get("loc", [])])
            data.append((err_loc, error.get("msg")))
        result = [message, type_info, table(data)]
        ValueError.__init__(self, "\n\n" + "\n".join(result))
class _ArgModelConfig:
    # Pydantic (v1) config for the ad-hoc ArgModel used during forward-pass
    # validation: reject unexpected fields and allow arbitrary (non-pydantic)
    # types such as arrays.
    extra = "forbid"
    arbitrary_types_allowed = True
def validate_fwd_input_output(
    name: str, func: Callable[[Any, Any, bool], Any], X: Any, Y: Any
) -> None:
    """Validate the input and output of a forward function against the type
    annotations, if available. Used in Model.initialize with the input and
    output samples as they pass through the network.

    name: the model name, used in error messages.
    func: the forward function; must take (model, X, is_train).
    X / Y: input and output samples; either may be None to skip that check.

    Raises DataValidationError when the signature is wrong or the samples
    don't match the annotations.
    """
    sig = inspect.signature(func)
    empty = inspect.Signature.empty
    params = list(sig.parameters.values())
    if len(params) != 3:
        bad_params = f"{len(params)} ({', '.join([p.name for p in params])})"
        err = f"Invalid forward function. Expected 3 arguments (model, X , is_train), got {bad_params}"
        raise DataValidationError(name, X, Y, [{"msg": err}])
    annot_x = params[1].annotation
    annot_y = sig.return_annotation
    # Build an ad-hoc pydantic model with one field per annotated sample.
    sig_args: Dict[str, Any] = {"__config__": _ArgModelConfig}
    args = {}
    if X is not None and annot_x != empty:
        # Only validate a small sample to keep pydantic's work bounded.
        if isinstance(X, list) and len(X) > 5:
            X = X[:5]
        sig_args["X"] = (annot_x, ...)
        args["X"] = X
    if Y is not None and annot_y != empty:
        if isinstance(Y, list) and len(Y) > 5:
            Y = Y[:5]
        sig_args["Y"] = (annot_y, ...)
        # Pair Y with a dummy callback — presumably to mirror the
        # (output, backprop) tuple that the return annotation describes.
        args["Y"] = (Y, lambda x: x)
    ArgModel = create_model("ArgModel", **sig_args)
    # Make sure the forward refs are resolved and the types used by them are
    # available in the correct scope. See #494 for details.
    ArgModel.update_forward_refs(**types.__dict__)
    try:
        ArgModel.parse_obj(args)
    except ValidationError as e:
        # Re-raise with pydantic's errors rendered into a readable table.
        raise DataValidationError(name, X, Y, e.errors()) from None
@contextlib.contextmanager
def make_tempfile(mode="r"):
    """Context manager yielding a named temporary file that is always closed
    and deleted on exit, even when the with-body raises.

    mode: the mode to open the file with (as for open()).
    """
    f = tempfile.NamedTemporaryFile(mode=mode, delete=False)
    try:
        yield f
    finally:
        # Bug fix: without try/finally, an exception in the with-body leaked
        # the handle and left the file behind on disk.
        f.close()
        os.remove(f.name)
@contextlib.contextmanager
def data_validation(validation):
    """Context manager that sets DATA_VALIDATION for the enclosed block and
    restores the previous value afterwards, even when the body raises.

    validation: the boolean value to set while the block runs.
    """
    # NOTE(review): this lock is created per call, so it does not actually
    # synchronize concurrent callers; kept as-is for backwards compatibility.
    with threading.Lock():
        prev = DATA_VALIDATION.get()
        DATA_VALIDATION.set(validation)
        try:
            yield
        finally:
            # Bug fix: restore even when the body raises, so the validation
            # flag can't leak past the with-block.
            DATA_VALIDATION.set(prev)
@contextlib.contextmanager
def use_nvtx_range(message: str, id_color: int = -1):
    """Context manager to register the executed code as an NVTX range. The
    ranges can be used as markers in CUDA profiling.

    message: the range label shown in the profiler.
    id_color: NVTX color id; -1 uses the default.
    """
    if has_cupy:
        cupy.cuda.nvtx.RangePush(message, id_color)
        try:
            yield
        finally:
            # Bug fix: pop the range even when the body raises, keeping the
            # NVTX push/pop stack balanced.
            cupy.cuda.nvtx.RangePop()
    else:
        yield
@dataclass
class ArrayInfo:
    """Container for info for checking array compatibility."""

    # Shape and dtype captured from a reference array (Y, per the error
    # messages below) so the corresponding gradient (dY) can be checked.
    shape: types.Shape
    dtype: types.DTypes

    @classmethod
    def from_array(cls, arr: ArrayXd):
        """Capture the shape and dtype of *arr*."""
        return cls(shape=arr.shape, dtype=arr.dtype)

    def check_consistency(self, arr: ArrayXd):
        """Raise ValueError unless *arr* matches the recorded shape and dtype."""
        if arr.shape != self.shape:
            raise ValueError(
                f"Shape mismatch in backprop. Y: {self.shape}, dY: {arr.shape}"
            )
        if arr.dtype != self.dtype:
            raise ValueError(
                f"Type mismatch in backprop. Y: {self.dtype}, dY: {arr.dtype}"
            )
# fmt: off
# Public API of this module; note that not every definition above is exported.
__all__ = [
    "get_array_module",
    "get_torch_default_device",
    "fix_random_seed",
    "is_cupy_array",
    "is_numpy_array",
    "set_active_gpu",
    "prefer_gpu",
    "require_gpu",
    "copy_array",
    "to_categorical",
    "get_width",
    "xp2torch",
    "torch2xp",
    "tensorflow2xp",
    "xp2tensorflow",
    "validate_fwd_input_output",
    "DataValidationError",
    "make_tempfile",
    "use_nvtx_range",
    "ArrayInfo",
    "has_cupy",
    "has_torch",
]
# fmt: on