357 lines
12 KiB
Python
357 lines
12 KiB
Python
|
"""Utility functions to use Python Array API compatible libraries.
|
||
|
|
||
|
For the context about the Array API see:
|
||
|
https://data-apis.org/array-api/latest/purpose_and_scope.html
|
||
|
|
||
|
The SciPy use case of the Array API is described on the following page:
|
||
|
https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy
|
||
|
"""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import os
|
||
|
import warnings
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from scipy._lib import array_api_compat
|
||
|
from scipy._lib.array_api_compat import (
|
||
|
is_array_api_obj,
|
||
|
size,
|
||
|
numpy as np_compat,
|
||
|
)
|
||
|
|
||
|
__all__ = ['array_namespace', '_asarray', 'size']
|
||
|
|
||
|
|
||
|
# Opt-in switch for array API support and strict array-like input validation.
# When the environment variable is set, its raw string value is kept as-is;
# otherwise the value is ``False``.
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# Default device - for use in the test suite only.
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")

# Mapping that exposes the two settings above for lookup at call time.
_GLOBAL_CONFIG = dict(
    SCIPY_ARRAY_API=SCIPY_ARRAY_API,
    SCIPY_DEVICE=SCIPY_DEVICE,
)
|
||
|
|
||
|
|
||
|
def compliance_scipy(arrays):
    """Validate inputs, raising exceptions on known-bad array types.

    The following inputs are not supported and raise an error:

    - `numpy.ma.MaskedArray`
    - `numpy.matrix`
    - NumPy arrays which do not have a boolean or numerical dtype
    - Any array-like which is neither array API compatible nor coercible by NumPy
    - Any array-like which is coerced by NumPy to an unsupported dtype

    Parameters
    ----------
    arrays : mutable sequence
        Inputs to validate. The sequence is updated in place: an entry that
        is neither a NumPy array/scalar nor an array API object is replaced
        by the result of ``numpy.asanyarray`` applied to it.

    Returns
    -------
    arrays : mutable sequence
        The same object that was passed in.

    Raises
    ------
    TypeError
        If any entry falls into one of the unsupported categories above.
    """
    for i, array in enumerate(arrays):
        if isinstance(array, np.ma.MaskedArray):
            raise TypeError("Inputs of type `numpy.ma.MaskedArray` are not supported.")
        elif isinstance(array, np.matrix):
            raise TypeError("Inputs of type `numpy.matrix` are not supported.")

        if isinstance(array, (np.ndarray, np.generic)):
            dtype = array.dtype
            if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
                raise TypeError(f"An argument has dtype `{dtype!r}`; "
                                "only boolean and numerical dtypes are supported.")
        elif not is_array_api_obj(array):
            # Last resort: let NumPy try to coerce plain array-likes
            # (lists, tuples, Python scalars, ...).
            try:
                array = np.asanyarray(array)
            except TypeError as exc:
                # Chain explicitly so the original NumPy failure stays visible
                # as the direct cause.
                raise TypeError("An argument is neither array API compatible nor "
                                "coercible by NumPy.") from exc
            dtype = array.dtype
            if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
                message = (
                    f"An argument was coerced to an unsupported dtype `{dtype!r}`; "
                    "only boolean and numerical dtypes are supported."
                )
                raise TypeError(message)
            # Writing to the slot currently being visited is safe during
            # enumeration because the length of the sequence does not change.
            arrays[i] = array
    return arrays
|
||
|
|
||
|
|
||
|
def _check_finite(array, xp):
    """Raise ``ValueError`` if `array` holds any NaN or infinite entry.

    A ``TypeError`` raised while evaluating the finiteness test with `xp`
    is reported as the same ``ValueError``.
    """
    message = "array must not contain infs or NaNs"
    try:
        all_finite = xp.all(xp.isfinite(array))
        if not all_finite:
            raise ValueError(message)
    except TypeError:
        raise ValueError(message)
|
||
|
|
||
|
|
||
|
def array_namespace(*arrays):
    """Return the array API compatible namespace shared by `arrays`.

    Parameters
    ----------
    *arrays : sequence of array_like
        Arrays from which the common namespace is inferred. Entries that
        are ``None`` are ignored.

    Returns
    -------
    namespace : module
        Common namespace.

    Notes
    -----
    This is a thin wrapper around `array_api_compat.array_namespace` that
    adds two steps:

    1. The global switch ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']`` is consulted
       at call time. While it is falsy, the NumPy compat namespace is
       returned immediately and no compliance check is run, which eases
       adoption.
    2. Otherwise the inputs go through `compliance_scipy`, which raises on
       known-bad array types (see its definition for details), before the
       namespace is looked up.
    """
    if _GLOBAL_CONFIG["SCIPY_ARRAY_API"]:
        candidates = [arr for arr in arrays if arr is not None]
        candidates = compliance_scipy(candidates)
        return array_api_compat.array_namespace(*candidates)

    # here we could wrap the namespace if needed
    return np_compat
|
||
|
|
||
|
|
||
|
def _asarray(
        array, dtype=None, order=None, copy=None, *, xp=None, check_finite=False
):
    """SciPy-specific replacement for `np.asarray` with `order` and `check_finite`.

    Parameters
    ----------
    array : array_like
        Input to convert.
    dtype : dtype, optional
        Requested dtype, forwarded to the conversion routine.
    order : optional
        Memory layout. The array API standard has no such parameter, so it
        is honoured only when `xp` is NumPy based and silently ignored for
        every other namespace.
    copy : bool, optional
        For NumPy based namespaces only ``copy is True`` has an effect (it
        selects `np.array` instead of `np.asarray`); for other namespaces
        the value is forwarded to ``xp.asarray``.
    xp : array_namespace, optional
        Namespace to use; inferred from `array` when omitted.
    check_finite : bool, optional
        If true, run `_check_finite` on the result. This is not part of the
        array API standard either; it is offered here so that SciPy
        functions do not need a separate call.

    Returns
    -------
    array
        The converted array.
    """
    if xp is None:
        xp = array_namespace(array)

    numpy_backed = {"numpy", "scipy._lib.array_api_compat.numpy"}
    if xp.__name__ not in numpy_backed:
        try:
            array = xp.asarray(array, dtype=dtype, copy=copy)
        except TypeError:
            # Retry through the namespace reported for an array created by xp.
            coerced_xp = array_namespace(xp.asarray(3))
            array = coerced_xp.asarray(array, dtype=dtype, copy=copy)
    else:
        # Go through the NumPy API because it is the one that supports `order`.
        converter = np.array if copy is True else np.asarray
        array = converter(array, order=order, dtype=dtype)

        # `array` is now a NumPy ndarray; hand it back to `xp` so the
        # container is consistent with the input's namespace.
        array = xp.asarray(array)

    if check_finite:
        _check_finite(array, xp)

    return array
|
||
|
|
||
|
|
||
|
def atleast_nd(x, *, ndim, xp=None):
    """Prepend length-1 axes to `x` until it has at least `ndim` dimensions.

    `x` is first converted with ``xp.asarray``; `xp` is inferred from `x`
    when it is not given. Inputs that already have `ndim` or more
    dimensions are returned without added axes.
    """
    if xp is None:
        xp = array_namespace(x)
    x = xp.asarray(x)
    while x.ndim < ndim:
        x = xp.expand_dims(x, axis=0)
    return x
|
||
|
|
||
|
|
||
|
def copy(x, *, xp=None):
    """
    Return a copy of an array.

    Parameters
    ----------
    x : array
        Array to copy.
    xp : array_namespace, optional
        Namespace of `x`; inferred from `x` when omitted.

    Returns
    -------
    copy : array
        Copied array

    Notes
    -----
    Not every feature of `np.copy` is available here: the `subok` and
    `order` keywords are not used.
    """
    # Note: xp.asarray fails if xp is numpy.
    namespace = array_namespace(x) if xp is None else xp
    return _asarray(x, copy=True, xp=namespace)
|
||
|
|
||
|
|
||
|
def is_numpy(xp):
    """Return whether `xp.__name__` names NumPy or SciPy's bundled compat module for it."""
    name = xp.__name__
    return name == 'numpy' or name == 'scipy._lib.array_api_compat.numpy'
|
||
|
|
||
|
|
||
|
def is_cupy(xp):
    """Return whether `xp.__name__` names CuPy or SciPy's bundled compat module for it."""
    name = xp.__name__
    return name == 'cupy' or name == 'scipy._lib.array_api_compat.cupy'
|
||
|
|
||
|
|
||
|
def is_torch(xp):
    """Return whether `xp.__name__` names PyTorch or SciPy's bundled compat module for it."""
    name = xp.__name__
    return name == 'torch' or name == 'scipy._lib.array_api_compat.torch'
|
||
|
|
||
|
|
||
|
def _strict_check(actual, desired, xp,
                  check_namespace=True, check_dtype=True, check_shape=True):
    """Run the strict pre-comparison checks shared by the ``xp_assert_*`` helpers.

    Parameters
    ----------
    actual : array
        Result under test.
    desired : array_like
        Reference value; converted with ``xp.asarray`` before the dtype and
        shape checks.
    xp : array_namespace
        Namespace used for the conversion and broadcasting.
    check_namespace : bool, optional
        If true, assert that `actual` and `desired` share a namespace.
    check_dtype : bool, optional
        If true, assert that the dtypes are equal.
    check_shape : bool, optional
        If true, assert that the shapes are equal and additionally run
        `_check_scalar`.

    Returns
    -------
    desired : array
        `desired` converted by `xp` and broadcast to ``actual.shape``.

    Raises
    ------
    AssertionError
        If an enabled check fails. The message reports the actual and
        desired values of the mismatching property.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if check_namespace:
        _assert_matching_namespace(actual, desired)

    desired = xp.asarray(desired)

    if check_dtype:
        # f-string so the message shows the real dtypes rather than the
        # literal placeholder text.
        _msg = f"dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}"
        assert actual.dtype == desired.dtype, _msg

    if check_shape:
        # f-string so the message shows the real shapes rather than the
        # literal placeholder text.
        _msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}"
        assert actual.shape == desired.shape, _msg
        _check_scalar(actual, desired, xp)

    desired = xp.broadcast_to(desired, actual.shape)
    return desired
|
||
|
|
||
|
|
||
|
def _assert_matching_namespace(actual, desired):
    """Assert that each array in `actual` has the same namespace as `desired`.

    `actual` may be a single array or a tuple of arrays; namespaces are
    obtained with `array_namespace`.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if not isinstance(actual, tuple):
        actual = (actual,)
    expected_xp = array_namespace(desired)
    for candidate in actual:
        found_xp = array_namespace(candidate)
        failure_text = (f"Namespaces do not match.\n"
                        f"Actual: {found_xp.__name__}\n"
                        f"Desired: {expected_xp.__name__}")
        assert found_xp == expected_xp, failure_text
|
||
|
|
||
|
|
||
|
def _check_scalar(actual, desired, xp):
    """Assert that `actual` and `desired` agree on being a scalar vs. a 0D array.

    The check only runs when ``desired.shape == ()`` and `xp` is NumPy; in
    every other case the function returns without checking anything.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # A plain shape comparison already covers everything except the
    # shape-() case, and NumPy is the only library that tells scalars and
    # arrays apart.
    applies = desired.shape == () and is_numpy(xp)
    if not applies:
        return
    # The convention of the `xp` library is what we follow. A library such
    # as NumPy, where `np.asarray(0)[()]` yields a scalar, usually hands
    # back a scalar even where a 0D array might be more appropriate:
    #    import numpy as np
    #    np.mean([1, 2, 3])  # scalar, not 0d array
    #    np.asarray(0)*2  # scalar, not 0d array
    #    np.sin(np.asarray(0))  # scalar, not 0d array
    # A library such as CuPy, where `cp.asarray(0)[()]` yields a 0D array,
    # usually hands back a 0D array in the same situations.
    # So whether the developer passed a scalar or a 0D array as `desired`,
    # the type we normally expect for `actual` is the type of
    # `desired[()]`. Passing `check_shape=False` overrides this behavior.
    desired = desired[()]
    _msg = f"Types do not match:\n Actual: {type(actual)}\n Desired: {type(desired)}"
    assert xp.isscalar(actual) == xp.isscalar(desired), _msg
|
||
|
|
||
|
|
||
|
def xp_assert_equal(actual, desired, check_namespace=True, check_dtype=True,
                    check_shape=True, err_msg='', xp=None):
    """Assert that `actual` and `desired` are exactly equal.

    The namespace, dtype and shape checks selected by the ``check_*`` flags
    are run through `_strict_check` first; the element-wise comparison is
    then delegated to the testing utilities of CuPy or PyTorch when `xp` is
    one of those, and to `numpy.testing.assert_array_equal` otherwise.
    `xp` is inferred from `actual` when omitted.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    xp = array_namespace(actual) if xp is None else xp
    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)

    if is_cupy(xp):
        return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)

    if is_torch(xp):
        # PyTorch recommends using `rtol=0, atol=0` like this
        # to test for exact equality
        torch_msg = err_msg if err_msg != '' else None
        return xp.testing.assert_close(actual, desired, rtol=0, atol=0,
                                       equal_nan=True, check_dtype=False,
                                       msg=torch_msg)

    return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
|
||
|
|
||
|
|
||
|
def xp_assert_close(actual, desired, rtol=1e-07, atol=0, check_namespace=True,
                    check_dtype=True, check_shape=True, err_msg='', xp=None):
    """Assert that `actual` and `desired` agree within `rtol` / `atol`.

    The namespace, dtype and shape checks selected by the ``check_*`` flags
    are run through `_strict_check` first; the tolerance comparison is then
    delegated to the testing utilities of CuPy or PyTorch when `xp` is one
    of those, and to `numpy.testing.assert_allclose` otherwise. `xp` is
    inferred from `actual` when omitted.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    xp = array_namespace(actual) if xp is None else xp
    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)

    if is_cupy(xp):
        return xp.testing.assert_allclose(actual, desired, rtol=rtol,
                                          atol=atol, err_msg=err_msg)

    if is_torch(xp):
        torch_msg = err_msg if err_msg != '' else None
        return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol,
                                       equal_nan=True, check_dtype=False,
                                       msg=torch_msg)

    return np.testing.assert_allclose(actual, desired, rtol=rtol,
                                      atol=atol, err_msg=err_msg)
|
||
|
|
||
|
|
||
|
def xp_assert_less(actual, desired, check_namespace=True, check_dtype=True,
                   check_shape=True, err_msg='', verbose=True, xp=None):
    """Assert that `actual` is element-wise strictly less than `desired`.

    The namespace, dtype and shape checks selected by the ``check_*`` flags
    are run through `_strict_check` first. CuPy inputs are compared with
    CuPy's own ``assert_array_less``; everything else goes through
    `numpy.testing.assert_array_less`, with PyTorch tensors moved to the
    CPU beforehand. `xp` is inferred from `actual` when omitted.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    xp = array_namespace(actual) if xp is None else xp
    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)

    if is_cupy(xp):
        return xp.testing.assert_array_less(actual, desired,
                                            err_msg=err_msg, verbose=verbose)

    if is_torch(xp):
        # Tensors living on another device are brought to the CPU before
        # being handed to the NumPy assertion below.
        if actual.device.type != 'cpu':
            actual = actual.cpu()
        if desired.device.type != 'cpu':
            desired = desired.cpu()

    return np.testing.assert_array_less(actual, desired,
                                        err_msg=err_msg, verbose=verbose)
|
||
|
|
||
|
|
||
|
def cov(x, *, xp=None):
    """Estimate the covariance of `x` using array API operations.

    `x` is copied, promoted to at least 2-D and cast to the result type of
    its dtype and ``float64``. The mean along axis 1 is removed from each
    row, and the product of the centred data with its (conjugated, for
    complex dtypes) transpose is divided by ``x.shape[1] - 1``. Length-1
    axes are squeezed out of the result.

    A ``RuntimeWarning`` is issued, and the divisor is set to ``0.0``, when
    ``x.shape[1] - 1`` is not positive. `xp` is inferred from `x` when
    omitted.
    """
    if xp is None:
        xp = array_namespace(x)

    # Work on a private copy: the centring below is done in place.
    data = copy(x, xp=xp)
    out_dtype = xp.result_type(data, xp.float64)

    data = atleast_nd(data, ndim=2, xp=xp)
    data = xp.asarray(data, dtype=out_dtype)

    row_means = xp.mean(data, axis=1)
    dof = data.shape[1] - 1

    if dof <= 0:
        warnings.warn("Degrees of freedom <= 0 for slice",
                      RuntimeWarning, stacklevel=2)
        dof = 0.0

    data -= row_means[:, None]
    transposed = data.T
    if xp.isdtype(transposed.dtype, 'complex floating'):
        transposed = xp.conj(transposed)
    result = data @ transposed
    result /= dof
    unit_axes = tuple(
        position for position, extent in enumerate(result.shape) if extent == 1
    )
    return xp.squeeze(result, axis=unit_axes)
|
||
|
|
||
|
|
||
|
def xp_unsupported_param_msg(param):
    """Return the message stating that `param` is only supported for NumPy arrays."""
    message = f'Providing {param!r} is only supported for numpy arrays.'
    return message
|
||
|
|
||
|
|
||
|
def is_complex(x, xp):
    """Return whether ``xp.isdtype`` classifies `x.dtype` as 'complex floating'."""
    dtype = x.dtype
    return xp.isdtype(dtype, 'complex floating')
|