"""Indexing mixin for sparse array/matrix classes.
"""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from ._sputils import isintlike

if TYPE_CHECKING:
    import numpy.typing as npt

INT_TYPES = (int, np.integer)


def _broadcast_arrays(a, b):
    """
    Same as np.broadcast_arrays(a, b) but old writeability rules.

    NumPy >= 1.17.0 transitions broadcast_arrays to return
    read-only arrays. Set writeability explicitly to avoid warnings.
    Retain the old writeability rules, as our Cython code assumes
    the old behavior.
    """
    x, y = np.broadcast_arrays(a, b)
    # Each broadcast result inherits the writeable flag of its own input.
    x.flags.writeable = a.flags.writeable
    y.flags.writeable = b.flags.writeable
    return x, y


class IndexMixin:
    """
    This class provides common dispatching and validation logic for indexing.

    Subclasses are expected to implement the ``_get_*`` and ``_set_*``
    hooks, which raise ``NotImplementedError`` here.
    """
    def _raise_on_1d_array_slice(self):
        """We do not currently support 1D sparse arrays.

        This function is called each time that a 1D array would
        result, raising an error instead.

        Once 1D sparse arrays are implemented, it should be removed.
        """
        from scipy.sparse import sparray

        if isinstance(self, sparray):
            raise NotImplementedError(
                'We have not yet implemented 1D sparse slices; '
                'please index using explicit indices, e.g. `x[:, [0]]`'
            )

    def __getitem__(self, key):
        """Validate `key` and dispatch to the matching ``_get_*`` hook.

        The hook is chosen from the kinds of the normalized (row, col)
        pair: integer, slice, or index array. Combinations that would give
        a 1-D result first call `_raise_on_1d_array_slice`.

        Raises
        ------
        IndexError
            If the index is invalid, would produce more than two
            dimensions, or the row and column index arrays do not
            broadcast to the same shape.
        """
        row, col = self._validate_indices(key)

        # Dispatch to specialized methods.
        if isinstance(row, INT_TYPES):
            if isinstance(col, INT_TYPES):
                return self._get_intXint(row, col)
            elif isinstance(col, slice):
                self._raise_on_1d_array_slice()
                return self._get_intXslice(row, col)
            elif col.ndim == 1:
                self._raise_on_1d_array_slice()
                return self._get_intXarray(row, col)
            elif col.ndim == 2:
                return self._get_intXarray(row, col)
            raise IndexError('index results in >2 dimensions')
        elif isinstance(row, slice):
            if isinstance(col, INT_TYPES):
                self._raise_on_1d_array_slice()
                return self._get_sliceXint(row, col)
            elif isinstance(col, slice):
                # A[:, :] is a plain copy of the whole container.
                if row == slice(None) and row == col:
                    return self.copy()
                return self._get_sliceXslice(row, col)
            elif col.ndim == 1:
                return self._get_sliceXarray(row, col)
            raise IndexError('index results in >2 dimensions')
        elif row.ndim == 1:
            if isinstance(col, INT_TYPES):
                self._raise_on_1d_array_slice()
                return self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                return self._get_arrayXslice(row, col)
        else:  # row.ndim == 2
            if isinstance(col, INT_TYPES):
                return self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                raise IndexError('index results in >2 dimensions')
            elif row.shape[1] == 1 and (col.ndim == 1 or col.shape[0] == 1):
                # special case for outer indexing
                return self._get_columnXarray(row[:, 0], col.ravel())

        # The only remaining case is inner (fancy) indexing
        row, col = _broadcast_arrays(row, col)
        if row.shape != col.shape:
            raise IndexError('number of row and column indices differ')
        if row.size == 0:
            # Nothing selected: return an empty container of matching shape.
            return self.__class__(np.atleast_2d(row).shape, dtype=self.dtype)
        return self._get_arrayXarray(row, col)

    def __setitem__(self, key, x):
        """Validate `key` and assign `x` to the selected entries.

        An integer/integer key goes to `_set_intXint`. Every other key is
        expanded to a pair of broadcast index arrays and handed to
        `_set_arrayXarray` (dense `x`) or `_set_arrayXarray_sparse`
        (sparse `x`).

        Raises
        ------
        ValueError
            If a non-scalar is assigned to a single item, or a sparse `x`
            has a shape that cannot be broadcast to the index shape.
        IndexError
            If the index is invalid or the row and column index arrays do
            not broadcast to the same shape.
        """
        row, col = self._validate_indices(key)

        if isinstance(row, INT_TYPES) and isinstance(col, INT_TYPES):
            x = np.asarray(x, dtype=self.dtype)
            if x.size != 1:
                raise ValueError('Trying to assign a sequence to an item')
            self._set_intXint(row, col, x.flat[0])
            return

        # Slices become explicit index vectors: rows as a column vector,
        # columns as a row vector, so that they broadcast to an outer grid.
        if isinstance(row, slice):
            row = np.arange(*row.indices(self.shape[0]))[:, None]
        else:
            row = np.atleast_1d(row)

        if isinstance(col, slice):
            col = np.arange(*col.indices(self.shape[1]))[None, :]
            if row.ndim == 1:
                row = row[:, None]
        else:
            col = np.atleast_1d(col)

        i, j = _broadcast_arrays(row, col)
        if i.shape != j.shape:
            raise IndexError('number of row and column indices differ')

        from ._base import issparse
        if issparse(x):
            if i.ndim == 1:
                # Inner indexing, so treat them like row vectors.
                i = i[None]
                j = j[None]
            broadcast_row = x.shape[0] == 1 and i.shape[0] != 1
            broadcast_col = x.shape[1] == 1 and i.shape[1] != 1
            if not ((broadcast_row or x.shape[0] == i.shape[0]) and
                    (broadcast_col or x.shape[1] == i.shape[1])):
                raise ValueError('shape mismatch in assignment')
            if x.shape[0] == 0 or x.shape[1] == 0:
                return
            x = x.tocoo(copy=True)
            x.sum_duplicates()
            self._set_arrayXarray_sparse(i, j, x)
        else:
            # Make x and i into the same shape
            x = np.asarray(x, dtype=self.dtype)
            if x.squeeze().shape != i.squeeze().shape:
                x = np.broadcast_to(x, i.shape)
            if x.size == 0:
                return
            x = x.reshape(i.shape)
            self._set_arrayXarray(i, j, x)

    def _validate_indices(self, key):
        """Normalize `key` into a validated ``(row, col)`` pair.

        Integer indices are bounds-checked and negative values wrapped,
        boolean indices are converted to integer index arrays, slices are
        passed through unchanged, and anything else goes through
        `_asindices`.

        Raises
        ------
        IndexError
            If an index is out of range or a boolean index does not match
            the corresponding dimension.
        """
        # First, check if indexing with single boolean matrix.
        from ._base import _spbase
        if (isinstance(key, (_spbase, np.ndarray)) and
                key.ndim == 2 and key.dtype.kind == 'b'):
            if key.shape != self.shape:
                raise IndexError('boolean index shape does not match array shape')
            row, col = key.nonzero()
        else:
            row, col = _unpack_index(key)
        M, N = self.shape

        def _validate_bool_idx(
            idx: npt.NDArray[np.bool_],
            axis_size: int,
            axis_name: str
        ) -> npt.NDArray[np.int_]:
            # A boolean mask must have exactly one entry per row/column.
            if len(idx) != axis_size:
                raise IndexError(
                    f"boolean {axis_name} index has incorrect length: {len(idx)} "
                    f"instead of {axis_size}"
                )
            return _boolean_index_to_array(idx)

        if isintlike(row):
            row = int(row)
            if row < -M or row >= M:
                raise IndexError('row index (%d) out of range' % row)
            if row < 0:
                row += M
        elif (bool_row := _compatible_boolean_index(row)) is not None:
            row = _validate_bool_idx(bool_row, M, "row")
        elif not isinstance(row, slice):
            row = self._asindices(row, M)

        if isintlike(col):
            col = int(col)
            if col < -N or col >= N:
                raise IndexError('column index (%d) out of range' % col)
            if col < 0:
                col += N
        elif (bool_col := _compatible_boolean_index(col)) is not None:
            col = _validate_bool_idx(bool_col, N, "column")
        elif not isinstance(col, slice):
            col = self._asindices(col, N)

        return row, col

    def _asindices(self, idx, length):
        """Convert `idx` to a valid index for an axis with a given length.

        Subclasses that need special validation can override this method.

        Raises
        ------
        IndexError
            If `idx` cannot be converted to an array, is not 1-D or 2-D,
            or contains a value outside ``[-length, length)``.
        """
        try:
            x = np.asarray(idx)
        except (ValueError, TypeError, MemoryError) as e:
            raise IndexError('invalid index') from e

        if x.ndim not in (1, 2):
            raise IndexError('Index dimension must be 1 or 2')

        if x.size == 0:
            return x

        # Check bounds
        max_indx = x.max()
        if max_indx >= length:
            raise IndexError('index (%d) out of range' % max_indx)

        min_indx = x.min()
        if min_indx < 0:
            if min_indx < -length:
                raise IndexError('index (%d) out of range' % min_indx)
            # Copy before wrapping negatives so the caller's array (or the
            # buffer a view refers to) is never modified in place.
            if x is idx or not x.flags.owndata:
                x = x.copy()
            x[x < 0] += length
        return x

    def _getrow(self, i):
        """Return a copy of row i of the matrix, as a (1 x n) row vector.
        """
        M, N = self.shape
        i = int(i)
        if i < -M or i >= M:
            raise IndexError('index (%d) out of range' % i)
        if i < 0:
            i += M
        return self._get_intXslice(i, slice(None))

    def _getcol(self, i):
        """Return a copy of column i of the matrix, as a (m x 1) column vector.
        """
        M, N = self.shape
        i = int(i)
        if i < -N or i >= N:
            raise IndexError('index (%d) out of range' % i)
        if i < 0:
            i += N
        return self._get_sliceXint(slice(None), i)

    # The hooks below are the targets of the dispatch in __getitem__ and
    # __setitem__; this mixin only provides placeholders.

    def _get_intXint(self, row, col):
        raise NotImplementedError()

    def _get_intXarray(self, row, col):
        raise NotImplementedError()

    def _get_intXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXint(self, row, col):
        raise NotImplementedError()

    def _get_sliceXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXint(self, row, col):
        raise NotImplementedError()

    def _get_arrayXslice(self, row, col):
        raise NotImplementedError()

    def _get_columnXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXarray(self, row, col):
        raise NotImplementedError()

    def _set_intXint(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray_sparse(self, row, col, x):
        """Assign sparse `x` by densifying it and using `_set_arrayXarray`."""
        # Fall back to densifying x
        x = np.asarray(x.toarray(), dtype=self.dtype)
        x, _ = _broadcast_arrays(x, row)
        self._set_arrayXarray(row, col, x)


def _unpack_index(index) -> tuple[
    int | slice | npt.NDArray[np.bool_ | np.int_],
    int | slice | npt.NDArray[np.bool_ | np.int_]
]:
    """ Parse index. Always return a tuple of the form (row, col).
    Valid type for row/col is integer, slice, array of bool, or array of integers.

    Raises
    ------
    IndexError
        If a tuple index does not have one or two entries, if a boolean
        index has more than two dimensions, or if a sparse matrix is used
        as a row or column index.
    """
    # Parse any ellipses.
    index = _check_ellipsis(index)

    # Next, parse the tuple or object
    if isinstance(index, tuple):
        if len(index) == 2:
            row, col = index
        elif len(index) == 1:
            row, col = index[0], slice(None)
        else:
            raise IndexError('invalid number of indices')
    else:
        idx = _compatible_boolean_index(index)
        if idx is None:
            row, col = index, slice(None)
        elif idx.ndim < 2:
            return idx, slice(None)
        elif idx.ndim == 2:
            return idx.nonzero()
        else:
            # Without this branch `row`/`col` stay unbound and the check
            # below fails with UnboundLocalError instead of IndexError.
            raise IndexError('invalid index shape')

    # Next, check for validity and transform the index as needed.
    from ._base import issparse
    if issparse(row) or issparse(col):
        # Supporting sparse boolean indexing with both row and col does
        # not work because spmatrix.ndim is always 2.
        raise IndexError(
            'Indexing with sparse matrices is not supported '
            'except boolean indexing where matrix and index '
            'are equal shapes.')
    return row, col


def _check_ellipsis(index):
    """Process indices with Ellipsis. Returns modified index."""
    if index is Ellipsis:
        return (slice(None), slice(None))

    if not isinstance(index, tuple):
        return index

    # Find any Ellipsis objects.
    ellipsis_indices = [i for i, v in enumerate(index) if v is Ellipsis]
    if not ellipsis_indices:
        return index
    if len(ellipsis_indices) > 1:
        raise IndexError("an index can only have a single ellipsis ('...')")

    # Replace the Ellipsis object with 0, 1, or 2 null-slices as needed.
    i, = ellipsis_indices
    num_slices = max(0, 3 - len(index))
    return index[:i] + (slice(None),) * num_slices + index[i + 1:]


def _maybe_bool_ndarray(idx):
    """Returns a compatible array if elements are boolean, else None.
    """
    idx = np.asanyarray(idx)
    if idx.dtype.kind == 'b':
        return idx
    return None


def _first_element_bool(idx, max_dim=2):
    """Returns True if first element of the incompatible
    array type is boolean.

    Descends into the first element at most `max_dim` levels deep and
    returns None when no Python ``bool`` is found or `idx` is not iterable.
    """
    if max_dim < 1:
        return None
    try:
        first = next(iter(idx), None)
    except TypeError:
        return None
    if isinstance(first, bool):
        return True
    return _first_element_bool(first, max_dim-1)


def _compatible_boolean_index(idx):
    """Returns a boolean index array that can be converted to
    integer array. Returns None if no such array exists.
    """
    # Presence of attribute `ndim` indicates a compatible array type.
    if hasattr(idx, 'ndim') or _first_element_bool(idx):
        return _maybe_bool_ndarray(idx)
    return None


def _boolean_index_to_array(idx):
    """Return the positions of the True entries of a boolean index.

    Raises IndexError if `idx` has more than one dimension.
    """
    if idx.ndim > 1:
        raise IndexError('invalid index shape')
    return np.where(idx)[0]