# mypy: ignore-errors

"""Assorted utilities, which do not need anything other than torch and stdlib."""

import operator

import torch

from . import _dtypes_impl


# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
def is_sequence(seq):
    if isinstance(seq, str):
        return False
    try:
        len(seq)
    except Exception:
        return False
    return True
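
# Illustrative behavior (comment sketch, not executed):
#   is_sequence([1, 2]) -> True     is_sequence((1,)) -> True
#   is_sequence("abc")  -> False    (strings are deliberately excluded)
#   is_sequence(5)      -> False    (len() raises TypeError)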


class AxisError(ValueError, IndexError):
    pass


class UFuncTypeError(TypeError, RuntimeError):
    pass


def cast_if_needed(tensor, dtype):
    # NB: no casting if dtype=None
    if dtype is not None and tensor.dtype != dtype:
        tensor = tensor.to(dtype)
    return tensor
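
# E.g. cast_if_needed(t, None) returns t unchanged, while
# cast_if_needed(torch.ones(2, dtype=torch.int64), torch.float32)
# produces a new float32 tensor.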


def cast_int_to_float(x):
    # cast integers and bools to the default float dtype
    if _dtypes_impl._category(x.dtype) < 2:
        x = x.to(_dtypes_impl.default_dtypes().float_dtype)
    return x
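
# Sketch of the intent (the default float dtype mirrors NumPy's choice,
# typically torch.float64): bool/int tensors are promoted, floats pass through.
#   cast_int_to_float(torch.tensor([1, 2]))   -> floating-point tensor
#   cast_int_to_float(torch.tensor([1.0]))    -> unchanged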


# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
def normalize_axis_index(ax, ndim, argname=None):
    if not (-ndim <= ax < ndim):
        raise AxisError(f"axis {ax} is out of bounds for array of dimension {ndim}")
    if ax < 0:
        ax += ndim
    return ax
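
# For example: normalize_axis_index(-1, 3) -> 2, normalize_axis_index(1, 3) -> 1,
# and normalize_axis_index(3, 3) raises AxisError (axis out of bounds).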


# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    """
    Normalizes an axis argument into a tuple of non-negative integer axes.

    This handles shorthands such as ``1`` and converts them to ``(1,)``,
    as well as performing the handling of negative indices covered by
    `normalize_axis_index`.

    By default, this forbids axes from being specified multiple times.
    Used internally by multi-axis-checking logic.

    Parameters
    ----------
    axis : int, iterable of int
        The un-normalized index or indices of the axis.
    ndim : int
        The number of dimensions of the array that `axis` should be normalized
        against.
    argname : str, optional
        A prefix to put before the error message, typically the name of the
        argument.
    allow_duplicate : bool, optional
        If False, the default, disallow an axis from being specified twice.

    Returns
    -------
    normalized_axes : tuple of int
        The normalized axis index, such that `0 <= normalized_axis < ndim`.
    """
    # Optimization to speed-up the most common cases.
    if type(axis) not in (tuple, list):
        try:
            axis = [operator.index(axis)]
        except TypeError:
            pass
    # Going via an iterator directly is slower than via list comprehension.
    axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
    if not allow_duplicate and len(set(axis)) != len(axis):
        if argname:
            raise ValueError(f"repeated axis in `{argname}` argument")
        else:
            raise ValueError("repeated axis")
    return axis
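
# For example: normalize_axis_tuple(-1, ndim=3) -> (2,);
# normalize_axis_tuple((0, -1), ndim=3) -> (0, 2);
# normalize_axis_tuple((0, 0), ndim=3) raises ValueError("repeated axis").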


def allow_only_single_axis(axis):
    if axis is None:
        return axis
    if len(axis) != 1:
        raise NotImplementedError("does not handle tuple axis")
    return axis[0]
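
# E.g. allow_only_single_axis((2,)) -> 2 and allow_only_single_axis(None) -> None;
# a multi-element tuple such as (0, 1) raises NotImplementedError.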


def expand_shape(arr_shape, axis):
    # taken from numpy 1.23.x, expand_dims function
    if type(axis) not in (list, tuple):
        axis = (axis,)
    out_ndim = len(axis) + len(arr_shape)
    axis = normalize_axis_tuple(axis, out_ndim)
    shape_it = iter(arr_shape)
    shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]
    return shape
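
# Worked example: expand_shape((2, 3), axis=0) -> [1, 2, 3], and
# expand_shape((2, 3), axis=(0, 2)) -> [1, 2, 1, 3]: out_ndim is 4, axes 0 and 2
# receive length-1 entries, and the remaining slots consume (2, 3) in order.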


def apply_keepdims(tensor, axis, ndim):
    if axis is None:
        # tensor was a scalar
        shape = (1,) * ndim
        tensor = tensor.expand(shape).contiguous()
    else:
        shape = expand_shape(tensor.shape, axis)
        tensor = tensor.reshape(shape)
    return tensor
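
# Sketch of the keepdims logic: after t.sum(1) turns a (2, 3) tensor into shape
# (2,), apply_keepdims(result, 1, ndim=2) reshapes it back to (2, 1); a full
# reduction (axis=None) is broadcast back up to (1,) * ndim.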


def axis_none_flatten(*tensors, axis=None):
    """Flatten the tensors if axis is None; return the tensors and the axis to work on."""
    if axis is None:
        tensors = tuple(ar.flatten() for ar in tensors)
        return tensors, 0
    else:
        return tensors, axis
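
# E.g. with t of shape (2, 3): axis_none_flatten(t) -> ((t.flatten(),), 0),
# while axis_none_flatten(t, axis=1) returns the input untouched with axis=1.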


def typecast_tensor(t, target_dtype, casting):
    """Dtype-cast tensor to target_dtype.

    Parameters
    ----------
    t : torch.Tensor
        The tensor to cast
    target_dtype : torch dtype object
        The dtype to cast the tensor to
    casting : str
        The casting mode, see `np.can_cast`

    Returns
    -------
    `torch.Tensor` of the `target_dtype` dtype

    Raises
    ------
    TypeError
        if the argument cannot be cast according to the `casting` rule

    """
    can_cast = _dtypes_impl.can_cast_impl

    if not can_cast(t.dtype, target_dtype, casting=casting):
        raise TypeError(
            f"Cannot cast array data from {t.dtype} to"
            f" {target_dtype} according to the rule '{casting}'"
        )
    return cast_if_needed(t, target_dtype)
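
# Illustration (mirrors NumPy's casting rules): casting float64 -> int64 is
# rejected under "same_kind" but allowed under "unsafe":
#   typecast_tensor(torch.tensor([1.5]), torch.int64, "same_kind")  # TypeError
#   typecast_tensor(torch.tensor([1.5]), torch.int64, "unsafe")     # tensor([1])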


def typecast_tensors(tensors, target_dtype, casting):
    return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors)


def _try_convert_to_tensor(obj):
    try:
        tensor = torch.as_tensor(obj)
    except Exception as e:
        mesg = f"failed to convert {obj} to ndarray.\nInternal error is: {e}."
        raise NotImplementedError(mesg)  # noqa: TRY200
    return tensor


def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
    """The core logic of the array(...) function.

    Parameters
    ----------
    obj : tensor_like
        The thing to coerce
    dtype : torch.dtype object or None
        Coerce to this torch dtype
    copy : bool
        Copy or not
    ndmin : int
        The result has at least this many dimensions

    Returns
    -------
    tensor : torch.Tensor
        a tensor object with requested dtype, ndim and copy semantics.

    Notes
    -----
    This is almost a "tensor_like" coercion function. Does not handle wrapper
    ndarrays (those should be handled in the ndarray-aware layer prior to
    invoking this function).
    """
    if isinstance(obj, torch.Tensor):
        tensor = obj
    else:
        # tensor.dtype is the pytorch default, typically float32. If obj's elements
        # are not exactly representable in float32, we've lost precision:
        # >>> torch.as_tensor(1e12).item() - 1e12
        # -4096.0
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
        try:
            tensor = _try_convert_to_tensor(obj)
        finally:
            torch.set_default_dtype(default_dtype)

    # type cast if requested
    tensor = cast_if_needed(tensor, dtype)

    # adjust ndim if needed
    ndim_extra = ndmin - tensor.ndim
    if ndim_extra > 0:
        tensor = tensor.view((1,) * ndim_extra + tensor.shape)

    # copy if requested
    if copy:
        tensor = tensor.clone()

    return tensor
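
# E.g. _coerce_to_tensor([1, 2], ndmin=2) yields a tensor of shape (1, 2):
# the list converts to a 1-d int64 tensor, then a leading length-1 dim pads it
# out to ndmin dimensions. copy=True additionally clones the result.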


def ndarrays_to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors. (other objects pass through unchanged)"""
    from ._ndarray import ndarray

    if len(inputs) == 0:
        raise ValueError("expected at least one input")
    elif len(inputs) == 1:
        input_ = inputs[0]
        if isinstance(input_, ndarray):
            return input_.tensor
        elif isinstance(input_, tuple):
            result = []
            for sub_input in input_:
                sub_result = ndarrays_to_tensors(sub_input)
                result.append(sub_result)
            return tuple(result)
        else:
            return input_
    else:
        # multiple positional arguments: pack them into a tuple and recurse
        # into the single-argument tuple branch above
        assert isinstance(inputs, tuple)  # sanity check
        return ndarrays_to_tensors(inputs)
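
# Sketch of the recursion: a single ndarray unwraps to its .tensor, a tuple is
# converted elementwise, and multiple arguments are repacked as a tuple, e.g.
#   ndarrays_to_tensors(nd_a, nd_b)  ->  (nd_a.tensor, nd_b.tensor)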