ai-content-maker/.venv/Lib/site-packages/torch/_numpy/_ndarray.py

592 lines
16 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
# mypy: ignore-errors
from __future__ import annotations
import builtins
import math
import operator
from typing import Sequence
import torch
from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
from ._normalizations import (
ArrayLike,
normalize_array_like,
normalizer,
NotImplementedType,
)
# ``newaxis`` is an alias for ``None`` when used in indexing expressions.
newaxis = None

# Full names of every flag that ``Flags`` recognises.
FLAGS = (
    "C_CONTIGUOUS F_CONTIGUOUS OWNDATA WRITEABLE ALIGNED WRITEBACKIFCOPY "
    "FNC FORC BEHAVED CARRAY FARRAY"
).split()

# Short aliases accepted by ``Flags.__getitem__``, mapped to their full names.
SHORTHAND_TO_FLAGS = dict(
    C="C_CONTIGUOUS",
    F="F_CONTIGUOUS",
    O="OWNDATA",
    W="WRITEABLE",
    A="ALIGNED",
    X="WRITEBACKIFCOPY",
    B="BEHAVED",
    CA="CARRAY",
    FA="FARRAY",
)
class Flags:
    """Read-only container for an ndarray's memory-layout flags.

    Values can be read in three ways:

    * by full name: ``flags["C_CONTIGUOUS"]``;
    * by shorthand listed in ``SHORTHAND_TO_FLAGS``: ``flags["C"]``;
    * as a lowercase attribute: ``flags.c_contiguous``.

    A name that appears in ``FLAGS`` but was not supplied to the constructor
    raises ``NotImplementedError``. An unknown key raises ``KeyError``, and an
    unknown lowercase attribute raises ``AttributeError``. Any attempt to
    write a flag raises ``NotImplementedError``.
    """

    def __init__(self, flag_to_value: dict):
        # Sanity check: only full flag names from FLAGS may be supplied.
        assert all(flag in FLAGS for flag in flag_to_value.keys())
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        # Reached only when normal attribute lookup fails; maps e.g.
        # ``flags.owndata`` onto ``flags["OWNDATA"]``.
        full_name = attr.upper()
        if not (attr.islower() and full_name in FLAGS):
            raise AttributeError(f"No flag attribute '{attr}'")
        return self[full_name]

    def __getitem__(self, key):
        # Expand shorthand aliases ("C", "CA", ...) to full flag names.
        key = SHORTHAND_TO_FLAGS.get(key, key)
        if key not in FLAGS:
            raise KeyError(f"No flag key '{key}'")
        try:
            return self._flag_to_value[key]
        except KeyError as e:
            # A recognised flag that this instance was not given a value for.
            raise NotImplementedError(f"{key=}") from e

    def __setattr__(self, attr, value):
        # Lowercase flag names go through __setitem__ (which rejects them);
        # every other attribute is stored normally.
        full_name = attr.upper()
        if attr.islower() and full_name in FLAGS:
            self[full_name] = value
            return
        super().__setattr__(attr, value)

    def __setitem__(self, key, value):
        is_known_flag = key in FLAGS or key in SHORTHAND_TO_FLAGS
        if is_known_flag:
            raise NotImplementedError("Modifying flags is not implemented")
        raise KeyError(f"No flag key '{key}'")
def create_method(fn, name=None):
    """Wrap *fn* in a fresh function suitable for installing on ``ndarray``.

    The wrapper forwards every positional and keyword argument to *fn*
    unchanged. Its ``__name__`` is *name*, or ``fn.__name__`` when *name* is
    falsy, and its ``__qualname__`` is ``"ndarray.<name>"``.
    """
    method_name = name if name else fn.__name__

    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper.__name__ = method_name
    wrapper.__qualname__ = f"ndarray.{method_name}"
    return wrapper
# Map ndarray.<method name> -> name of the implementing function in _funcs.
# A value of None means the function has the same name as the method; every
# entry here currently does, so the table is built with dict.fromkeys.
methods = dict.fromkeys(
    (
        "clip",
        "nonzero",
        "repeat",
        "round",
        "squeeze",
        "swapaxes",
        "ravel",
        # linalg
        "diagonal",
        "dot",
        "trace",
        # sorting
        "argsort",
        "searchsorted",
        # reductions
        "argmax",
        "argmin",
        "any",
        "all",
        "max",
        "min",
        "ptp",
        "sum",
        "prod",
        "mean",
        "var",
        "std",
        # scans
        "cumsum",
        "cumprod",
        # advanced indexing
        "take",
        "choose",
    )
)

# Unary and comparison dunders: __<key>__ -> _ufuncs.<value> (None: same name).
dunder = dict(
    abs="absolute",
    invert=None,
    pos="positive",
    neg="negative",
    gt="greater",
    lt="less",
    ge="greater_equal",
    le="less_equal",
)

# Binary dunders that also get reflected (__r<key>__) and in-place
# (__i<key>__) variants: __<key>__ -> _ufuncs.<value> (None: same name).
ri_dunder = {
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    "matmul": None,
}
def _upcast_int_indices(index):
if isinstance(index, torch.Tensor):
if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8):
return index.to(torch.int64)
elif isinstance(index, tuple):
return tuple(_upcast_int_indices(i) for i in index)
return index
# Used to indicate that a parameter is unspecified (as opposed to explicitly
# `None`)
class _Unspecified:
pass
_Unspecified.unspecified = _Unspecified()
###############################################################
# ndarray class #
###############################################################
class ndarray:
    """NumPy-style array that wraps a ``torch.Tensor``.

    The wrapped tensor is stored in ``self.tensor``. Much of the NumPy method
    surface is generated when the class is created, from the ``methods``,
    ``dunder`` and ``ri_dunder`` tables. Each generated method forwards to
    the function of the mapped name in ``_funcs`` or ``_ufuncs``.
    """

    def __init__(self, t=None):
        """Wrap tensor *t*, or an empty ``torch.Tensor()`` when *t* is None.

        Raises:
            ValueError: if *t* is neither ``None`` nor a ``torch.Tensor``.
        """
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            # Adjacent literals are concatenated; the trailing space keeps
            # "prefer" and "either" from running together.
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods. Writing to vars() in a class body
    # adds entries to the class namespace being built.
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)
        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)
        # `fn=fn` binds the current ufunc now, avoiding late binding of the
        # loop variable inside the lambdas.
        rvar = f"__r{method}__"
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        """Dimensions of the array as a tuple of ints."""
        return tuple(self.tensor.shape)

    @property
    def size(self):
        """Total number of elements."""
        return self.tensor.numel()

    @property
    def ndim(self):
        """Number of dimensions."""
        return self.tensor.ndim

    @property
    def dtype(self):
        """The array's dtype, wrapped in ``_dtypes.dtype``."""
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        """Strides in bytes (torch strides count elements, so scale them)."""
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        """Size of one element in bytes."""
        return self.tensor.element_size()

    @property
    def flags(self):
        """Memory-layout flags as a ``Flags`` object (only some are known)."""
        # Note contiguous in torch is assumed C-style
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        """Address of the first element (``tensor.data_ptr()``)."""
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        """Bytes consumed by the array's elements, i.e. ``size * itemsize``.

        This is computed from this array's own element count, not from the
        size of the underlying storage. Views and slices therefore report
        only the bytes they cover.
        """
        return self.tensor.numel() * self.tensor.element_size()

    @property
    def T(self):
        """The transposed array (``self.transpose()``)."""
        return self.transpose()

    @property
    def real(self):
        """Real part of the array."""
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        """Imaginary part of the array."""
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors

    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        """Return a new ndarray whose tensor is converted to *dtype*.

        Only the default values of *order*, *casting*, *subok* and *copy* are
        supported; any other value raises ``NotImplementedError``.
        """
        if order != "K":
            raise NotImplementedError(f"astype(..., order={order}) is not implemented.")
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting}) is not implemented."
            )
        if not subok:
            raise NotImplementedError(f"astype(..., subok={subok}) is not implemented.")
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        t = self.tensor.to(torch_dtype)
        return ndarray(t)

    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        """Return a copy of the array (the normalized tensor is cloned)."""
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        """Return the array collapsed to one dimension via ``torch.flatten``."""
        return torch.flatten(self)

    def resize(self, *new_shape, refcheck=False):
        """Resize the array in place; elements added by growth are zero-filled.

        Accepts both ``x.resize((2, 2))`` and ``x.resize(2, 2)``. Returns
        None, and does nothing for ``x.resize()`` or ``x.resize(None)``.

        Raises:
            NotImplementedError: if *refcheck* is truthy.
            ValueError: if any requested dimension is negative.
        """
        # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck}) is not implemented."
            )
        if new_shape in [(), (None,)]:
            return

        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)

        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")

        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()

        self.tensor.resize_(new_shape)

        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        """Reinterpret the array's data as *dtype* (default: the current dtype).

        Raises:
            NotImplementedError: if *type* is passed.
        """
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)

    @normalizer
    def fill(self, value: ArrayLike):
        """Fill the array in place with *value*."""
        # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on D > 0 arrays
        self.tensor.fill_(value)

    def tolist(self):
        """Return the data as (nested) Python lists of scalars."""
        return self.tensor.tolist()

    def __iter__(self):
        # Iterate over the first axis, wrapping each sub-tensor.
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    __repr__ = create_method(__str__, "__repr__")

    def __eq__(self, other):
        """Element-wise equality.

        If the comparison raises RuntimeError or TypeError (e.g. *other*
        cannot be converted to an array), an all-False array of this array's
        shape is returned instead.
        """
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        """Convert a single-element integer array to a Python int for indexing.

        Raises:
            TypeError: if the array does not hold exactly one integer value.
        """
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        """True if the array holds one element whose value is integral."""
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        """Length of the first axis.

        Raises:
            TypeError: for 0-d arrays, matching NumPy's "unsized object" error.
        """
        if self.tensor.ndim == 0:
            raise TypeError("len() of unsized object")
        return self.tensor.shape[0]

    def __contains__(self, x):
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        """Permute the axes; ``arr.transpose(*axes)`` form of ``np.transpose``."""
        # np.transpose(arr, axis=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        """Return the array with a new shape; ``*shape`` may also be one tuple."""
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        """Sort the array in place along *axis*."""
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        """Return one element of the array as a Python scalar.

        * ``a.item()``: the only element of a size-1 array.
        * ``a.item(i)``: element ``i`` of the C-order flattened array.
        * ``a.item(i, j, ...)`` or ``a.item((i, j, ...))``: the element at
          that multi-index.
        """
        # Mimic NumPy's implementation with three special cases (no arguments,
        # a flat index and a multi-index):
        # https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702
        if args == ():
            return self.tensor.item()
        if len(args) == 1 and isinstance(args[0], tuple):
            # a.item((i, j)) is equivalent to a.item(i, j)
            args = args[0]
        if len(args) == 1:
            # int argument: flat index
            return self.ravel()[args[0]].tensor.item()
        return self.__getitem__(tuple(args)).tensor.item()

    def __getitem__(self, index):
        """Index the array and return the result as a new ndarray.

        Slices with a negative step are not handed to torch directly. Instead
        the tensor is flipped along that axis and the slice is rewritten as
        an equivalent positive-step slice. ``torch.flip`` copies, so such
        results do not share memory with ``self``.
        """
        tensor = self.tensor

        def neg_step(i, s):
            if not (isinstance(s, slice) and s.step is not None and s.step < 0):
                return s

            nonlocal tensor
            tensor = torch.flip(tensor, (i,))

            assert isinstance(s.start, int) or s.start is None
            assert isinstance(s.stop, int) or s.stop is None
            # Resolve defaulted and negative bounds against the axis length.
            # For a negative step, ``slice.indices`` returns start >= stop,
            # where stop == -1 means "run through index 0".
            size = tensor.shape[i]
            start, stop, step = s.indices(size)
            # Original position k is at size - 1 - k after the flip. The
            # visited positions start, start + step, ... (all > stop) become
            # a forward slice.
            return slice(size - 1 - start, size - 1 - stop, -step)

        if isinstance(index, Sequence):
            index = type(index)(neg_step(i, s) for i, s in enumerate(index))
        else:
            index = neg_step(0, index)
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        return ndarray(tensor.__getitem__(index))

    def __setitem__(self, index, value):
        # Non-scalar values are normalized to tensors and cast to this
        # array's dtype before assignment.
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)

        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, value)

    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()
def _tolist(obj):
    """Recursively turn ndarrays inside a (nested) list/tuple into plain lists.

    Nested lists and tuples become lists. Each ``ndarray`` element is
    replaced with ``elem.tensor.tolist()``. Every other element is kept
    unchanged. The result is always a new list.
    """

    def convert(item):
        if isinstance(item, (list, tuple)):
            item = _tolist(item)
        return item.tensor.tolist() if isinstance(item, ndarray) else item

    return [convert(item) for item in obj]
# This is ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.


def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    """Create an ndarray from *obj*.

    Args:
        obj: an ndarray, a tensor, a (nested) list/tuple, or any other input
            accepted by ``_util._coerce_to_tensor``.
        dtype: optional requested dtype.
        copy: when False and *obj* is an ndarray that already satisfies
            *dtype* (None) and *ndmin*, *obj* itself is returned.
        order: only ``"K"`` is supported.
        subok: only ``False`` is supported.
        ndmin: minimum number of dimensions of the result.
        like: only ``None`` is supported.

    Raises:
        NotImplementedError: for unsupported *subok*, *like* or *order*.
    """
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError(
            f"'order' parameter is not supported (got {order!r}; only 'K' is)."
        )

    # a happy path
    if (
        isinstance(obj, ndarray)
        and copy is False
        and dtype is None
        and ndmin <= obj.ndim
    ):
        return obj

    if isinstance(obj, (list, tuple)):
        # FIXME and they have the same dtype, device, etc
        if obj and all(isinstance(x, torch.Tensor) for x in obj):
            # list of arrays: *under torch.Dynamo* these are FakeTensors
            obj = torch.stack(obj)
        else:
            # XXX: remove tolist
            # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
            obj = _tolist(obj)

    # is obj an ndarray already?
    if isinstance(obj, ndarray):
        obj = obj.tensor

    # is a specific dtype requested?
    torch_dtype = None
    if dtype is not None:
        torch_dtype = _dtypes.dtype(dtype).torch_dtype

    tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
    return ndarray(tensor)
def asarray(a, dtype=None, order="K", *, like=None):
    """Convert *a* to an ndarray without forcing a copy.

    This is ``array(a, ...)`` with ``copy=False`` and ``ndmin=0``.
    """
    return array(a, dtype=dtype, copy=False, order=order, ndmin=0, like=like)
def ascontiguousarray(a, dtype=None, *, like=None):
    """Return an ndarray with the data of *a* laid out C-contiguously.

    If ``asarray(a, ...)`` is already contiguous, it is returned as is.
    Otherwise a new ndarray wrapping a contiguous copy is returned. The
    input is never modified.
    """
    arr = asarray(a, dtype=dtype, like=like)
    if arr.tensor.is_contiguous():
        return arr
    # Wrap the copy in a new ndarray rather than rebinding ``arr.tensor``.
    # With dtype=None, ``asarray`` can return the caller's own ndarray
    # object, and rebinding its tensor would silently detach the caller's
    # array from any storage it shares with other arrays.
    return ndarray(arr.tensor.contiguous())
def from_dlpack(x, /):
    """Create an ndarray from *x* via ``torch.from_dlpack``."""
    return ndarray(torch.from_dlpack(x))
def _extract_dtype(entry):
    """Return the dtype described by *entry*.

    *entry* is first tried as a dtype specifier. If ``_dtypes.dtype`` raises
    any exception, *entry* is treated as array-like and the dtype of
    ``asarray(entry)`` is returned instead.
    """
    try:
        return _dtypes.dtype(entry)
    except Exception:
        return asarray(entry).dtype
def can_cast(from_, to, casting="safe"):
    """Report whether *from_* can be cast to *to* under the *casting* rule.

    Both arguments may be dtype specifiers or array-likes (see
    ``_extract_dtype``). The decision is delegated to
    ``_dtypes_impl.can_cast_impl``.
    """
    source, target = _extract_dtype(from_), _extract_dtype(to)
    return _dtypes_impl.can_cast_impl(source.torch_dtype, target.torch_dtype, casting)
def result_type(*arrays_and_dtypes):
    """Return the dtype that results from combining the given arrays/dtypes.

    Each argument is converted with ``asarray``. If that raises RuntimeError,
    ValueError or TypeError, the argument is instead treated as a dtype and
    represented by a one-element empty tensor of that dtype. The promotion
    itself is done by ``_dtypes_impl.result_type_impl``.
    """

    def as_tensor(entry):
        try:
            return asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            return torch.empty(1, dtype=_dtypes.dtype(entry).torch_dtype)

    tensors = [as_tensor(entry) for entry in arrays_and_dtypes]
    return _dtypes.dtype(_dtypes_impl.result_type_impl(*tensors))