# mypy: ignore-errors

"""NumPy-compatible ``ndarray`` class whose data lives in a ``torch.Tensor``.

The module also provides the array-creation entry points (``array``,
``asarray``, ``ascontiguousarray``, ``from_dlpack``) and the dtype helpers
``can_cast`` and ``result_type``.
"""

from __future__ import annotations

import builtins
import math
import operator
from typing import Sequence

import torch

from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
from ._normalizations import (
    ArrayLike,
    normalize_array_like,
    normalizer,
    NotImplementedType,
)


newaxis = None

FLAGS = [
    "C_CONTIGUOUS",
    "F_CONTIGUOUS",
    "OWNDATA",
    "WRITEABLE",
    "ALIGNED",
    "WRITEBACKIFCOPY",
    "FNC",
    "FORC",
    "BEHAVED",
    "CARRAY",
    "FARRAY",
]

SHORTHAND_TO_FLAGS = {
    "C": "C_CONTIGUOUS",
    "F": "F_CONTIGUOUS",
    "O": "OWNDATA",
    "W": "WRITEABLE",
    "A": "ALIGNED",
    "X": "WRITEBACKIFCOPY",
    "B": "BEHAVED",
    "CA": "CARRAY",
    "FA": "FARRAY",
}


class Flags:
    """Read-only, NumPy-style view of an array's flags.

    Flags can be read as items, either by full name (``flags["C_CONTIGUOUS"]``)
    or by shorthand (``flags["C"]``), and as lowercase attributes
    (``flags.c_contiguous``). A flag listed in ``FLAGS`` but missing from
    ``flag_to_value`` raises ``NotImplementedError``. Setting any known flag
    raises ``NotImplementedError``.
    """

    def __init__(self, flag_to_value: dict):
        assert all(k in FLAGS for k in flag_to_value.keys())  # sanity check
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        if attr.islower() and attr.upper() in FLAGS:
            return self[attr.upper()]
        else:
            raise AttributeError(f"No flag attribute '{attr}'")

    def __getitem__(self, key):
        if key in SHORTHAND_TO_FLAGS.keys():
            key = SHORTHAND_TO_FLAGS[key]
        if key in FLAGS:
            try:
                return self._flag_to_value[key]
            except KeyError as e:
                raise NotImplementedError(f"{key=}") from e
        else:
            raise KeyError(f"No flag key '{key}'")

    def __setattr__(self, attr, value):
        # Lowercase flag names are routed to __setitem__ (which refuses);
        # everything else, e.g. `_flag_to_value`, is a normal attribute.
        if attr.islower() and attr.upper() in FLAGS:
            self[attr.upper()] = value
        else:
            super().__setattr__(attr, value)

    def __setitem__(self, key, value):
        if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys():
            raise NotImplementedError("Modifying flags is not implemented")
        else:
            raise KeyError(f"No flag key '{key}'")


def create_method(fn, name=None):
    """Return a forwarding wrapper around ``fn`` to be used as an ndarray method.

    The wrapper is a fresh function object, so its ``__name__`` (``name``, or
    ``fn.__name__`` by default) and ``__qualname__`` (``ndarray.<name>``) can
    be set without touching ``fn`` itself.
    """
    name = name or fn.__name__

    def f(*args, **kwargs):
        return fn(*args, **kwargs)

    f.__name__ = name
    f.__qualname__ = f"ndarray.{name}"
    return f


# Map ndarray.name_method -> np.name_func
# If name_func == None, it means that name_method == name_func
methods = {
    "clip": None,
    "nonzero": None,
    "repeat": None,
    "round": None,
    "squeeze": None,
    "swapaxes": None,
    "ravel": None,
    # linalg
    "diagonal": None,
    "dot": None,
    "trace": None,
    # sorting
    "argsort": None,
    "searchsorted": None,
    # reductions
    "argmax": None,
    "argmin": None,
    "any": None,
    "all": None,
    "max": None,
    "min": None,
    "ptp": None,
    "sum": None,
    "prod": None,
    "mean": None,
    "var": None,
    "std": None,
    # scans
    "cumsum": None,
    "cumprod": None,
    # advanced indexing
    "take": None,
    "choose": None,
}

# Unary and comparison dunders: __name__ -> ufunc name (None: same name)
dunder = {
    "abs": "absolute",
    "invert": None,
    "pos": "positive",
    "neg": "negative",
    "gt": "greater",
    "lt": "less",
    "ge": "greater_equal",
    "le": "less_equal",
}

# dunder methods with right-looking and in-place variants
ri_dunder = {
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    "matmul": None,
}


def _upcast_int_indices(index):
    """Cast int8/int16/int32/uint8 index tensors (also inside a tuple) to int64.

    NOTE(review): presumably so that torch treats them as integer indices
    (uint8 tensors are otherwise interpreted as masks by torch) - confirm.
    """
    if isinstance(index, torch.Tensor):
        if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8):
            return index.to(torch.int64)
    elif isinstance(index, tuple):
        return tuple(_upcast_int_indices(i) for i in index)
    return index


def _consumed_dims(entry):
    """Return how many dimensions of the indexed array one index entry uses."""
    if entry is None or isinstance(entry, bool):
        # `None` and Python bools insert a new axis (NumPy semantics), so
        # they do not consume a dimension of the indexed array.
        return 0
    if isinstance(entry, ndarray):
        entry = entry.tensor
    if isinstance(entry, torch.Tensor) and entry.dtype == torch.bool:
        # A boolean mask consumes as many dimensions as it has.
        return entry.ndim
    # Integers, slices and integer (fancy) indices consume one dimension.
    return 1


def _index_dims(index, ndim):
    """Return, for each entry of the multi-index ``index``, the position of the
    first dimension of an ``ndim``-dimensional array that the entry indexes."""
    sizes = [None if entry is Ellipsis else _consumed_dims(entry) for entry in index]
    # `...` expands to whatever dimensions the other entries leave unindexed.
    ellipsis_size = builtins.max(
        ndim - builtins.sum(size for size in sizes if size is not None), 0
    )
    dims = []
    pos = 0
    for size in sizes:
        dims.append(pos)
        pos += ellipsis_size if size is None else size
    return dims


def _flip_neg_slice(s, n):
    """Translate the negative-step slice ``s`` over a length-``n`` dimension into
    the positive-step slice that selects the same elements, in the same order,
    from that dimension reversed."""
    assert isinstance(s.start, int) or s.start is None
    assert isinstance(s.stop, int) or s.stop is None

    def clamp(bound, default):
        # Normalize like `slice.indices` does for a negative step: negative
        # bounds count from the end, and the result lies in [-1, n - 1]
        # (-1 meaning "before the first element").
        if bound is None:
            return default
        if bound < 0:
            bound += n
        return builtins.min(builtins.max(bound, -1), n - 1)

    start = clamp(s.start, n - 1)
    stop = clamp(s.stop, -1)
    # Element j of the original dimension is element n - 1 - j once reversed;
    # an empty selection (start <= stop) stays empty (start' >= stop').
    return slice(n - 1 - start, n - 1 - stop, -s.step)


# Used to indicate that a parameter is unspecified (as opposed to explicitly
# `None`)
class _Unspecified:
    pass


_Unspecified.unspecified = _Unspecified()

###############################################################
#                      ndarray class                          #
###############################################################


class ndarray:
    """NumPy-style array whose data is held in the ``tensor`` attribute.

    Most methods forward to the free functions in ``_funcs`` and ``_ufuncs``;
    the class body below attaches them with ``create_method``.
    """

    def __init__(self, t=None):
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)
        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)
        # `fn=fn` binds the current ufunc; a bare closure would see the last one.
        rvar = f"__r{method}__"
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        return tuple(self.tensor.shape)

    @property
    def size(self):
        return self.tensor.numel()

    @property
    def ndim(self):
        return self.tensor.ndim

    @property
    def dtype(self):
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        # torch strides count elements; NumPy strides count bytes.
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        return self.tensor.element_size()

    @property
    def flags(self):
        # Note contiguous in torch is assumed C-style
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        # The tensor's data pointer (an int), not a NumPy-style buffer object.
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        # Bytes consumed by the array's elements (size * itemsize), as in
        # NumPy. The underlying storage may be larger, e.g. for views or
        # after a shrinking `resize`, so it must not be used here.
        return self.tensor.numel() * self.tensor.element_size()

    @property
    def T(self):
        return self.transpose()

    @property
    def real(self):
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors
    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        """Return a new array cast to ``dtype``; only default options are supported."""
        if order != "K":
            raise NotImplementedError(f"astype(..., order={order}) is not implemented.")
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting}) is not implemented."
            )
        if not subok:
            raise NotImplementedError(f"astype(..., subok={subok}) is not implemented.")
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        t = self.tensor.to(torch_dtype)
        return ndarray(t)

    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        return torch.flatten(self)

    def resize(self, *new_shape, refcheck=False):
        """Resize the array in place; accepts ``resize((2, 2))`` and ``resize(2, 2)``."""
        # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck}) is not implemented."
            )
        if new_shape in [(), (None,)]:
            return

        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)

        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")

        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()

        self.tensor.resize_(new_shape)

        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        """Reinterpret the array's data as ``dtype`` (defaults to the current dtype)."""
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)

    @normalizer
    def fill(self, value: ArrayLike):
        # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on D > 0 arrays
        self.tensor.fill_(value)

    def tolist(self):
        return self.tensor.tolist()

    def __iter__(self):
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    __repr__ = create_method(__str__)

    def __eq__(self, other):
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        """Return True if the single element is integral-valued; False otherwise (incl. errors)."""
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        # Like NumPy, a 0-d array has no length.
        if self.tensor.ndim == 0:
            raise TypeError("len() of unsized object")
        return self.tensor.shape[0]

    def __contains__(self, x):
        # Tensor.__contains__ only accepts tensors and numbers: unwrap ndarrays.
        if isinstance(x, ndarray):
            x = x.tensor
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        # np.transpose(arr, axis=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        """Return one element of the array as a Python scalar.

        Mimics NumPy's three cases: no arguments (the array must have a single
        element), a flat index into the C-ordered array, and a multi-index
        given either as separate arguments or as one tuple.
        https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702
        """
        if not args:
            return self.tensor.item()
        if len(args) == 1 and not isinstance(args[0], tuple):
            # int argument: a flat index
            return self.ravel()[args[0]].tensor.item()
        index = args[0] if len(args) == 1 else args
        return self.__getitem__(index).tensor.item()

    def __getitem__(self, index):
        tensor = self.tensor

        def neg_step(dim, s):
            # torch does not support negative slice steps: flip the indexed
            # dimension and rewrite the slice with a positive step.
            if not (isinstance(s, slice) and s.step is not None and s.step < 0):
                return s

            nonlocal tensor
            tensor = torch.flip(tensor, (dim,))
            return _flip_neg_slice(s, tensor.shape[dim])

        if isinstance(index, Sequence):
            # Position of the dimension each entry indexes; `...`, `None` and
            # boolean masks make it differ from the entry's position.
            dims = _index_dims(index, tensor.ndim)
            index = type(index)(neg_step(d, s) for d, s in zip(dims, index))
        else:
            index = neg_step(0, index)
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        return ndarray(tensor.__getitem__(index))

    def __setitem__(self, index, value):
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)

        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, value)

    # Plain functions assigned as class attributes (`take` overrides the
    # wrapper registered from `methods` above).
    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()


def _tolist(obj):
    """Recursively convert tensors into lists."""
    a1 = []
    for elem in obj:
        if isinstance(elem, (list, tuple)):
            elem = _tolist(elem)
        if isinstance(elem, ndarray):
            a1.append(elem.tensor.tolist())
        else:
            a1.append(elem)
    return a1


# Ideally, this is the only place that talks to ndarray directly.
# Everything else goes through asarray (preferred) or array.


def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    """Create an ndarray from ``obj``.

    ``obj`` may be an ndarray, a tensor, a scalar, or a (nested) list/tuple
    that may contain ndarrays. An ndarray is returned unchanged when no copy,
    dtype change or extra dimensions are requested. ``subok``, ``like`` and
    any ``order`` other than "K" raise ``NotImplementedError``.
    """
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError()

    # a happy path
    if (
        isinstance(obj, ndarray)
        and copy is False
        and dtype is None
        and ndmin <= obj.ndim
    ):
        return obj

    if isinstance(obj, (list, tuple)):
        # FIXME and they have the same dtype, device, etc
        if obj and all(isinstance(x, torch.Tensor) for x in obj):
            # list of arrays: *under torch.Dynamo* these are FakeTensors
            obj = torch.stack(obj)
        else:
            # XXX: remove tolist
            # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
            obj = _tolist(obj)

    # is obj an ndarray already?
    if isinstance(obj, ndarray):
        obj = obj.tensor

    # is a specific dtype requested?
    torch_dtype = None
    if dtype is not None:
        torch_dtype = _dtypes.dtype(dtype).torch_dtype

    tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
    return ndarray(tensor)


def asarray(a, dtype=None, order="K", *, like=None):
    """``array(a, ...)`` without forcing a copy."""
    return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)


def ascontiguousarray(a, dtype=None, *, like=None):
    """Return ``a`` as a C-contiguous ndarray, copying only when needed."""
    arr = asarray(a, dtype=dtype, like=like)
    if not arr.tensor.is_contiguous():
        # Wrap the contiguous copy in a new ndarray: `asarray` may return `a`
        # itself, and rebinding its `tensor` would silently turn the caller's
        # view into a detached copy.
        arr = ndarray(arr.tensor.contiguous())
    return arr


def from_dlpack(x, /):
    t = torch.from_dlpack(x)
    return ndarray(t)


def _extract_dtype(entry):
    """Interpret ``entry`` as a dtype, falling back to the dtype of ``asarray(entry)``."""
    try:
        dty = _dtypes.dtype(entry)
    except Exception:
        dty = asarray(entry).dtype
    return dty


def can_cast(from_, to, casting="safe"):
    from_ = _extract_dtype(from_)
    to_ = _extract_dtype(to)

    return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)


def result_type(*arrays_and_dtypes):
    """Return the dtype resulting from combining the given arrays and dtypes."""
    tensors = []
    for entry in arrays_and_dtypes:
        try:
            t = asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            # Not array-like: treat it as a dtype and use a 1-element stand-in.
            dty = _dtypes.dtype(entry)
            t = torch.empty(1, dtype=dty.torch_dtype)
        tensors.append(t)

    torch_dtype = _dtypes_impl.result_type_impl(*tensors)
    return _dtypes.dtype(torch_dtype)