52 lines
1.5 KiB
Python
52 lines
1.5 KiB
Python
from typing import Sequence, Tuple, TypeVar, Union
|
|
|
|
from ..model import Model
|
|
from ..types import ArrayXd, FloatsXd, IntsXd
|
|
|
|
# Index applied along a single axis: an integer, a slice, or a sequence of integers.
AxisIndex = Union[int, slice, Sequence[int]]
# Complete index: either a single per-axis index, or a tuple of them for
# multi-dimensional indexing.
Index = Union[AxisIndex, Tuple[AxisIndex, ...]]
# Type variable bounded by ArrayXd; used below to declare a layer whose input
# and output share the same array type.
ArrayTXd = TypeVar("ArrayTXd", bound=ArrayXd)
|
|
|
|
|
|
def array_getitem(index: Index) -> Model[ArrayTXd, ArrayTXd]:
    """Create a layer that applies a fixed index to its input array.

    index:
        Any numpy-style index. Pass a tuple to index several dimensions at
        once, and use slice objects to express slices. For example,
        X[:, :-1] is written as (slice(None, None), slice(None, -1)).
    """
    # The index is stored as a layer attribute; `forward` reads it back on every call.
    layer_attrs = {"index": index}
    return Model("array-getitem", forward, attrs=layer_attrs)
|
|
|
|
|
|
def floats_getitem(index: Index) -> Model[FloatsXd, FloatsXd]:
    """Create a layer that applies a fixed index to its input array.

    Uses the same forward function and "index" attribute as `array_getitem`,
    but the layer is named "floats-getitem" and is declared with FloatsXd
    input and output types.
    """
    # The index is stored as a layer attribute; `forward` reads it back on every call.
    layer_attrs = {"index": index}
    return Model("floats-getitem", forward, attrs=layer_attrs)
|
|
|
|
|
|
def ints_getitem(index: Index) -> Model[IntsXd, IntsXd]:
    """Create a layer that applies a fixed index to its input array.

    Uses the same forward function and "index" attribute as `array_getitem`,
    but the layer is named "ints-getitem" and is declared with IntsXd
    input and output types.
    """
    # The index is stored as a layer attribute; `forward` reads it back on every call.
    layer_attrs = {"index": index}
    return Model("ints-getitem", forward, attrs=layer_attrs)
|
|
|
|
|
|
def forward(model, X, is_train):
    """Apply the layer's stored index to X and build the backward callback.

    model:
        The layer. Its attrs["index"] is the index to apply, and its
        ops.alloc is used to allocate the input gradient.
    X:
        The input array. Its shape and dtype are captured for the backward pass.
    is_train:
        Not used by this layer.

    Returns a (Y, backprop) tuple. Y is X[index], except that an input with
    len(X) == 0 is returned as-is, without indexing. backprop(dY) allocates
    an array with X's shape and dtype through model.ops.alloc, assigns dY to
    the indexed positions and returns that array.
    """
    selector = model.attrs["index"]
    input_shape, input_dtype = X.shape, X.dtype

    def backprop_get_column(dY):
        # Positions outside the index keep whatever ops.alloc initialised them
        # to (presumably zeros -- confirm against the ops implementation).
        d_input = model.ops.alloc(input_shape, dtype=input_dtype)
        d_input[selector] = dY
        return d_input

    # Explicit length check: X is an array, so truthiness is not a safe test.
    output = X if len(X) == 0 else X[selector]
    return output, backprop_get_column
|