# mypy: ignore-errors

"""Assorted utilities, which do not need anything other than torch and stdlib."""

import operator

import torch

from . import _dtypes_impl


# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
def is_sequence(seq):
    """Return True if ``seq`` supports ``len()`` and is not a ``str``."""
    if isinstance(seq, str):
        return False
    try:
        len(seq)
    except Exception:
        return False
    return True


class AxisError(ValueError, IndexError):
    pass


class UFuncTypeError(TypeError, RuntimeError):
    pass


def cast_if_needed(tensor, dtype):
    """Return ``tensor`` converted to ``dtype``; no-op if ``dtype`` is None or already matches."""
    # NB: no casting if dtype=None
    if dtype is not None and tensor.dtype != dtype:
        tensor = tensor.to(dtype)
    return tensor


def cast_int_to_float(x):
    """Cast integer and bool tensors to the default float dtype; other dtypes pass through."""
    # cast integers and bools to the default float dtype
    if _dtypes_impl._category(x.dtype) < 2:
        x = x.to(_dtypes_impl.default_dtypes().float_dtype)
    return x


# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
def normalize_axis_index(ax, ndim, argname=None):
    """Map a possibly-negative axis ``ax`` into ``[0, ndim)``.

    Raises ``AxisError`` if ``ax`` is outside ``[-ndim, ndim)``.
    ``argname`` is accepted for signature compatibility with NumPy and is
    currently not used in the error message.
    """
    if not (-ndim <= ax < ndim):
        raise AxisError(f"axis {ax} is out of bounds for array of dimension {ndim}")
    if ax < 0:
        ax += ndim
    return ax


# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    """
    Normalizes an axis argument into a tuple of non-negative integer axes.

    This handles shorthands such as ``1`` and converts them to ``(1,)``,
    as well as performing the handling of negative indices covered by
    `normalize_axis_index`.

    By default, this forbids axes from being specified multiple times.
    Used internally by multi-axis-checking logic.

    Parameters
    ----------
    axis : int, iterable of int
        The un-normalized index or indices of the axis.
    ndim : int
        The number of dimensions of the array that `axis` should be normalized
        against.
    argname : str, optional
        A prefix to put before the error message, typically the name of the
        argument.
    allow_duplicate : bool, optional
        If False, the default, disallow an axis from being specified twice.

    Returns
    -------
    normalized_axes : tuple of int
        The normalized axis index, such that `0 <= normalized_axis < ndim`

    Raises
    ------
    AxisError
        If any axis is out of bounds.
    ValueError
        If an axis is repeated and ``allow_duplicate`` is False.
    """
    # Optimization to speed-up the most common cases.
    if type(axis) not in (tuple, list):
        try:
            axis = [operator.index(axis)]
        except TypeError:
            # Not an integer: fall through and treat it as an iterable of axes.
            pass
    # Going via an iterator directly is slower than via list comprehension.
    axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
    if not allow_duplicate and len(set(axis)) != len(axis):
        if argname:
            raise ValueError(f"repeated axis in `{argname}` argument")
        else:
            raise ValueError("repeated axis")
    return axis


def allow_only_single_axis(axis):
    """Return ``None`` for ``axis=None``, else the sole entry of a length-1 ``axis``.

    Raises ``NotImplementedError`` when ``axis`` has more (or fewer) than one
    entry. Presumably ``axis`` has already been normalized to a tuple by the
    caller (an int would fail on ``len``) -- TODO confirm at call sites.
    """
    if axis is None:
        return axis
    if len(axis) != 1:
        raise NotImplementedError("does not handle tuple axis")
    return axis[0]


def expand_shape(arr_shape, axis):
    """Return ``arr_shape`` (as a list) with size-1 entries inserted at ``axis``.

    ``axis`` is an int or a list/tuple of ints indexing the *output* shape,
    whose length is ``len(arr_shape) + len(axis)``.
    """
    # taken from numpy 1.23.x, expand_dims function
    if type(axis) not in (list, tuple):
        axis = (axis,)
    out_ndim = len(axis) + len(arr_shape)
    axis = normalize_axis_tuple(axis, out_ndim)
    shape_it = iter(arr_shape)
    shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]
    return shape


def apply_keepdims(tensor, axis, ndim):
    """Re-insert size-1 dimensions into ``tensor`` (the ``keepdims=True`` convention).

    With ``axis=None`` the tensor is expanded to ``ndim`` dimensions of size 1
    and made contiguous; otherwise size-1 dims are inserted at ``axis``.
    """
    if axis is None:
        # tensor was a scalar
        shape = (1,) * ndim
        tensor = tensor.expand(shape).contiguous()
    else:
        shape = expand_shape(tensor.shape, axis)
        tensor = tensor.reshape(shape)
    return tensor


def axis_none_flatten(*tensors, axis=None):
    """Flatten the arrays if axis is None.

    Returns a ``(tensors, axis)`` pair: with ``axis=None`` the tensors are
    flattened and the axis becomes ``0``; otherwise both pass through as-is.
    """
    if axis is None:
        tensors = tuple(ar.flatten() for ar in tensors)
        return tensors, 0
    else:
        return tensors, axis


def typecast_tensor(t, target_dtype, casting):
    """Dtype-cast tensor to target_dtype.

    Parameters
    ----------
    t : torch.Tensor
        The tensor to cast
    target_dtype : torch dtype object
        The dtype to cast the tensor to
    casting : str
        The casting mode, see `np.can_cast`

    Returns
    -------
    `torch.Tensor` of the `target_dtype` dtype

    Raises
    ------
    TypeError
        if the argument cannot be cast according to the `casting` rule

    """
    can_cast = _dtypes_impl.can_cast_impl

    if not can_cast(t.dtype, target_dtype, casting=casting):
        raise TypeError(
            f"Cannot cast array data from {t.dtype} to"
            f" {target_dtype} according to the rule '{casting}'"
        )
    return cast_if_needed(t, target_dtype)


def typecast_tensors(tensors, target_dtype, casting):
    """Apply `typecast_tensor` to each tensor in ``tensors``; return a tuple."""
    return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors)


def _try_convert_to_tensor(obj):
    """Convert ``obj`` with ``torch.as_tensor``.

    Any conversion failure is re-raised as ``NotImplementedError``, chained to
    the original exception so the root cause stays visible in tracebacks.
    """
    try:
        tensor = torch.as_tensor(obj)
    except Exception as e:
        mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}."
        raise NotImplementedError(mesg) from e
    return tensor


def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
    """The core logic of the array(...) function.

    Parameters
    ----------
    obj : tensor_like
        The thing to coerce
    dtype : torch.dtype object or None
        Coerce to this torch dtype
    copy : bool
        Copy or not
    ndmin : int
        The result has at least this many dimensions; missing dimensions are
        prepended with size 1.

    Returns
    -------
    tensor : torch.Tensor
        a tensor object with requested dtype, ndim and copy semantics.

    Raises
    ------
    NotImplementedError
        if ``obj`` is not a tensor and ``torch.as_tensor`` cannot convert it.

    Notes
    -----
    This is almost a "tensor_like" coercion function. Does not handle wrapper
    ndarrays (those should be handled in the ndarray-aware layer prior to
    invoking this function).
    """
    if isinstance(obj, torch.Tensor):
        tensor = obj
    else:
        # tensor.dtype is the pytorch default, typically float32. If obj's elements
        # are not exactly representable in float32, we've lost precision:
        # >>> torch.as_tensor(1e12).item() - 1e12
        # -4096.0
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
        try:
            tensor = _try_convert_to_tensor(obj)
        finally:
            # Always restore the global default dtype, even if conversion failed.
            torch.set_default_dtype(default_dtype)

    # type cast if requested
    tensor = cast_if_needed(tensor, dtype)

    # adjust ndim if needed
    ndim_extra = ndmin - tensor.ndim
    if ndim_extra > 0:
        tensor = tensor.view((1,) * ndim_extra + tensor.shape)

    # copy if requested
    if copy:
        tensor = tensor.clone()

    return tensor


def ndarrays_to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors. (other things are intact)

    A single argument is converted directly; tuples are converted recursively
    element by element. Several arguments are treated as one tuple.

    Raises
    ------
    ValueError
        if called with no arguments.
    """
    from ._ndarray import ndarray

    if len(inputs) == 0:
        # Previously this *returned* a ValueError instance, which callers would
        # silently treat as a result; an empty call is a usage error.
        raise ValueError("ndarrays_to_tensors requires at least one input")
    elif len(inputs) == 1:
        input_ = inputs[0]
        if isinstance(input_, ndarray):
            return input_.tensor
        elif isinstance(input_, tuple):
            return tuple(ndarrays_to_tensors(sub_input) for sub_input in input_)
        else:
            return input_
    else:
        assert isinstance(inputs, tuple)  # sanity check
        return ndarrays_to_tensors(inputs)