ai-content-maker/.venv/Lib/site-packages/thinc/backends/_custom_kernels.py

884 lines
25 KiB
Python
Raw Permalink Normal View History

2024-05-03 04:18:51 +03:00
import operator
import re
from collections import defaultdict
from functools import reduce
from pathlib import Path
from typing import Callable, Optional, Tuple
import numpy
from ..compat import cupy, has_cupy_gpu
PWD = Path(__file__).parent
KERNELS_SRC = (PWD / "_custom_kernels.cu").read_text(encoding="utf8")
KERNELS_LIST = [
"backprop_clipped_linear<double>",
"backprop_clipped_linear<float>",
"backprop_dish<double>",
"backprop_dish<float>",
"backprop_gelu<double>",
"backprop_gelu<float>",
"backprop_hard_swish<double>",
"backprop_hard_swish<float>",
"backprop_hard_swish_mobilenet<double>",
"backprop_hard_swish_mobilenet<float>",
"backprop_maxout<double>",
"backprop_maxout<float>",
"backprop_mish<double>",
"backprop_mish<float>",
"backprop_reduce_max<double>",
"backprop_reduce_max<float>",
"backprop_reduce_mean<double>",
"backprop_reduce_mean<float>",
"backprop_reduce_sum<double>",
"backprop_reduce_sum<float>",
"backprop_seq2col<double>",
"backprop_seq2col<float>",
"backprop_swish<double>",
"backprop_swish<float>",
"clipped_linear<double>",
"clipped_linear<float>",
"dish<double>",
"dish<float>",
"gather_add<double>",
"gather_add<float>",
"gelu<double>",
"gelu<float>",
"maxout<double>",
"maxout<float>",
"mish<double>",
"mish<float>",
"pad<double>",
"pad<float>",
"pad<int>",
"pad<long long>",
"reduce_max<double>",
"reduce_max<float>",
"reduce_sum<double>",
"reduce_sum<float>",
"seq2col<double>",
"seq2col<float>",
"swish<double>",
"swish<float>",
]
KERNELS = (
cupy.RawModule(
code=KERNELS_SRC, options=("--std=c++11",), name_expressions=KERNELS_LIST
)
if has_cupy_gpu
else None
)
class LazyKernel:
"""Wraps around `cupy.RawModule` and `cupy.RawKernel` to verify CuPy availability
and lazily compile the latter on first invocation.
The default CuPy behaviour triggers the compilation as soon as the `cupy.RawKernel` object
is accessed."""
name: str
_kernel: Optional["cupy.RawKernel"]
_compile_callback: Optional[Callable[[], "cupy.RawKernel"]]
__slots__ = ["name", "_kernel", "_compile_callback"]
def __init__(
self,
name: str,
*,
compile_callback: Optional[Callable[[], "cupy.RawKernel"]] = None,
) -> None:
self.name = name
self._kernel = None
self._compile_callback = compile_callback
def __call__(self, *args, **kwargs):
self._compile_kernel()
self._kernel(*args, **kwargs)
def _compile_kernel(self):
if self._kernel is not None:
return
if self._compile_callback is not None:
self._kernel = self._compile_callback()
elif KERNELS is not None:
self._kernel = KERNELS.get_function(self.name)
if self._kernel is None:
raise ValueError(f"couldn't compile Cupy kernel '{self.name}'")
def compile_mmh():
if not has_cupy_gpu:
return None
return cupy.RawKernel((PWD / "_murmur3.cu").read_text(encoding="utf8"), "hash_data")
clipped_linear_kernel_float = LazyKernel("clipped_linear<float>")
clipped_linear_kernel_double = LazyKernel("clipped_linear<double>")
dish_kernel_float = LazyKernel("dish<float>")
dish_kernel_double = LazyKernel("dish<double>")
gather_add_kernel_float = LazyKernel("gather_add<float>")
gather_add_kernel_double = LazyKernel("gather_add<double>")
gelu_kernel_float = LazyKernel("gelu<float>")
gelu_kernel_double = LazyKernel("gelu<double>")
hash_data_kernel = LazyKernel("hash_data", compile_callback=compile_mmh)
maxout_kernel_float = LazyKernel("maxout<float>")
maxout_kernel_double = LazyKernel("maxout<double>")
mish_kernel_float = LazyKernel("mish<float>")
mish_kernel_double = LazyKernel("mish<double>")
pad_kernel_float = LazyKernel("pad<float>")
pad_kernel_double = LazyKernel("pad<double>")
pad_kernel_int32 = LazyKernel("pad<int>")
pad_kernel_int64 = LazyKernel("pad<long long>")
reduce_max_kernel_float = LazyKernel("reduce_max<float>")
reduce_max_kernel_double = LazyKernel("reduce_max<double>")
reduce_sum_kernel_float = LazyKernel("reduce_sum<float>")
reduce_sum_kernel_double = LazyKernel("reduce_sum<double>")
seq2col_kernel_float = LazyKernel("seq2col<float>")
seq2col_kernel_double = LazyKernel("seq2col<double>")
swish_kernel_float = LazyKernel("swish<float>")
swish_kernel_double = LazyKernel("swish<double>")
backprop_clipped_linear_kernel_double = LazyKernel("backprop_clipped_linear<double>")
backprop_clipped_linear_kernel_float = LazyKernel("backprop_clipped_linear<float>")
backprop_dish_kernel_double = LazyKernel("backprop_dish<double>")
backprop_dish_kernel_float = LazyKernel("backprop_dish<float>")
backprop_gelu_kernel_double = LazyKernel("backprop_gelu<double>")
backprop_gelu_kernel_float = LazyKernel("backprop_gelu<float>")
backprop_hard_swish_kernel_double = LazyKernel("backprop_hard_swish<double>")
backprop_hard_swish_kernel_float = LazyKernel("backprop_hard_swish<float>")
backprop_hard_swish_mobilenet_kernel_double = LazyKernel(
"backprop_hard_swish_mobilenet<double>"
)
backprop_hard_swish_mobilenet_kernel_float = LazyKernel(
"backprop_hard_swish_mobilenet<float>"
)
backprop_maxout_kernel_double = LazyKernel("backprop_maxout<double>")
backprop_maxout_kernel_float = LazyKernel("backprop_maxout<float>")
backprop_mish_kernel_double = LazyKernel("backprop_mish<double>")
backprop_mish_kernel_float = LazyKernel("backprop_mish<float>")
backprop_reduce_max_kernel_double = LazyKernel("backprop_reduce_max<double>")
backprop_reduce_max_kernel_float = LazyKernel("backprop_reduce_max<float>")
backprop_reduce_mean_kernel_double = LazyKernel("backprop_reduce_mean<double>")
backprop_reduce_mean_kernel_float = LazyKernel("backprop_reduce_mean<float>")
backprop_reduce_sum_kernel_double = LazyKernel("backprop_reduce_sum<double>")
backprop_reduce_sum_kernel_float = LazyKernel("backprop_reduce_sum<float>")
backprop_seq2col_kernel_double = LazyKernel("backprop_seq2col<double>")
backprop_seq2col_kernel_float = LazyKernel("backprop_seq2col<float>")
backprop_swish_kernel_double = LazyKernel("backprop_swish<double>")
backprop_swish_kernel_float = LazyKernel("backprop_swish<float>")
def _alloc(shape, dtype, *, zeros: bool = True):
if zeros:
return cupy.zeros(shape, dtype)
else:
return cupy.empty(shape, dtype)
def _alloc_like(array, zeros: bool = True):
if zeros:
return cupy.zeros_like(array)
else:
return cupy.empty_like(array)
def pad(seqs, round_to=1, *, threads_per_block=128, num_blocks=128):
if round_to < 1:
raise ValueError(f"Rounding for padding must at least be 1, was: {round_to}")
for seq in seqs:
_is_float_or_int_array(seq)
seq_lens = [len(seq) for seq in seqs]
max_seq_len = max(seq_lens)
# Round the length to nearest bucket -- helps on GPU, to make similar
# array sizes.
max_seq_len += -max_seq_len % round_to
seq_lens = cupy.array(seq_lens, dtype="int32")
final_shape = (len(seqs), max_seq_len) + seqs[0].shape[1:]
out = cupy.empty(final_shape, dtype=seqs[0].dtype)
# Extract pointers from CuPy arrays, so that we can address
# them in the CUDA kernel.
ptrs = numpy.empty(
(
len(
seqs,
)
),
"int64",
)
for idx, seq in enumerate(seqs):
ptrs[idx] = seq.data.ptr
ptrs = cupy.array(ptrs)
stride = reduce(operator.mul, seqs[0].shape[1:], 1)
if out.dtype == "float32":
pad_kernel_float(
(num_blocks,),
(threads_per_block,),
(out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
)
elif out.dtype == "float64":
pad_kernel_double(
(num_blocks,),
(threads_per_block,),
(out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
)
elif out.dtype == "int32":
pad_kernel_int32(
(num_blocks,),
(threads_per_block,),
(out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
)
elif out.dtype == "int64":
pad_kernel_int64(
(num_blocks,),
(threads_per_block,),
(out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
)
return out
def clipped_linear(
X,
*,
inplace=False,
slope=1.0,
offset=0.0,
min_val=0.0,
max_val=1.0,
threads_per_block=128,
num_blocks=128,
):
_is_float_array(X)
out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
clipped_linear_kernel_float(
(num_blocks,),
(threads_per_block,),
(out, X, slope, offset, min_val, max_val, X.size),
)
else:
clipped_linear_kernel_double(
(num_blocks,),
(threads_per_block,),
(out, X, slope, offset, min_val, max_val, X.size),
)
return out
def gather_add(table, indices, *, threads_per_block=128, num_blocks=128):
if table.ndim != 2:
raise ValueError(
f"gather_add expects table with dimensionality 2, was: {table.ndim}"
)
if indices.ndim != 2:
raise ValueError(
f"gather_add expects indices with dimensionality 2, was: {indices.ndim}"
)
_is_float_array(table)
indices = indices.astype("int32")
_check_indices(indices, table.shape[0])
B = indices.shape[0]
K = indices.shape[1]
T = table.shape[0]
O = table.shape[1]
out = _alloc((B, O), dtype=table.dtype, zeros=True)
if table.dtype == "float32":
gather_add_kernel_float(
(num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
)
else:
gather_add_kernel_double(
(num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
)
return out
def dish(X, *, inplace=False, threads_per_block=128, num_blocks=128):
_is_float_array(X)
out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
dish_kernel_float((num_blocks,), (threads_per_block,), (out, X, X.size))
else:
dish_kernel_double((num_blocks,), (threads_per_block,), (out, X, X.size))
return out
def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
_is_float_array(X)
out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
gelu_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
else:
gelu_kernel_double(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
return out
def check_seq2col_lengths(lengths, B):
if lengths is None:
lengths = cupy.array([B], dtype="int32")
else:
_check_lengths(lengths, B)
return lengths
def seq2col(seq, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
_is_float_array(seq)
B = seq.shape[0]
nF = nW * 2 + 1
I = seq.shape[1]
lengths = check_seq2col_lengths(lengths, B)
nL = lengths.shape[0]
out = _alloc((B, I * nF), dtype=seq.dtype, zeros=True)
if seq.size != 0 and lengths.size != 0:
if seq.dtype == "float32":
seq2col_kernel_float(
(num_blocks,), (threads_per_block,), (out, seq, lengths, nW, B, I, nL)
)
else:
seq2col_kernel_double(
(num_blocks,), (threads_per_block,), (out, seq, lengths, nW, B, I, nL)
)
return out
def maxout(X, *, threads_per_block=128, num_blocks=128):
_is_float_array(X)
B, I, P = X.shape
out_shape = (B, I)
best = _alloc(out_shape, dtype=X.dtype, zeros=False)
which = _alloc(out_shape, dtype="i", zeros=False)
if X.dtype == "float32":
maxout_kernel_float(
(num_blocks,), (threads_per_block,), (best, which, X, B, I, P)
)
else:
maxout_kernel_double(
(num_blocks,), (threads_per_block,), (best, which, X, B, I, P)
)
return best, which
def mish(X, *, inplace=False, threshold=5, threads_per_block=128, num_blocks=128):
_is_float_array(X)
out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
mish_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
else:
mish_kernel_double(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
return out
def reduce_sum(X, lengths, *, threads_per_block=128, num_blocks=128):
_is_float_array(X)
B = len(lengths)
T = X.shape[0]
O = X.shape[1]
_check_lengths(lengths, T)
out = _alloc((B, O), dtype=X.dtype, zeros=True)
if X.dtype == "float32":
reduce_sum_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
)
else:
reduce_sum_kernel_double(
(num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
)
return out
def reduce_mean(X, lengths, *, threads_per_block=128, num_blocks=128):
_is_float_array(X)
B = len(lengths)
T = X.shape[0]
O = X.shape[1]
_check_lengths(lengths, T)
out = _alloc((B, O), dtype=X.dtype, zeros=True)
if X.dtype == "float32":
reduce_sum_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
)
else:
reduce_sum_kernel_double(
(num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
)
# Avoid divide by zero
out /= lengths.reshape((-1, 1)) + 1e-10
return out
def reduce_max(X, lengths, *, threads_per_block=128, num_blocks=128):
_is_float_array(X)
B = len(lengths)
T = X.shape[0]
O = X.shape[1]
_check_lengths(lengths, T, min_length=1)
out_shape = (B, O)
maxes = _alloc(out_shape, dtype=X.dtype, zeros=False)
which = _alloc(out_shape, dtype="i", zeros=False)
if X.dtype == "float32":
reduce_max_kernel_float(
(num_blocks,), (threads_per_block,), (maxes, which, X, lengths, B, T, O)
)
else:
reduce_max_kernel_double(
(num_blocks,), (threads_per_block,), (maxes, which, X, lengths, B, T, O)
)
return maxes, which
def swish(X, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128):
_is_float_array(X)
out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
swish_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
else:
swish_kernel_double(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
return out
def backprop_seq2col(dY, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
_is_float_array(dY)
B = dY.shape[0]
nF = nW * 2 + 1
I = dY.shape[1] // nF
lengths = check_seq2col_lengths(lengths, B)
nL = lengths.shape[0]
out = _alloc((B, I), dtype=dY.dtype, zeros=True)
if dY.size != 0 and lengths.size != 0:
if dY.dtype == "float32":
backprop_seq2col_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, lengths, nW, B, I, nL)
)
else:
backprop_seq2col_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, lengths, nW, B, I, nL)
)
return out
def backprop_clipped_linear(
dY,
X,
*,
slope: float = 1.0,
offset: float = 0.0,
min_val: float = 0.0,
max_val: float = 1.0,
inplace: bool = False,
threads_per_block=128,
num_blocks=128,
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_clipped_linear_kernel_float(
(num_blocks,),
(threads_per_block,),
(out, dY, X, slope, offset, min_val, max_val, out.size),
)
else:
backprop_clipped_linear_kernel_double(
(num_blocks,),
(threads_per_block,),
(out, dY, X, slope, offset, min_val, max_val, out.size),
)
return out
def backprop_hard_swish(
dY, X, *, inplace: bool = False, threads_per_block=128, num_blocks=128
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_hard_swish_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
else:
backprop_hard_swish_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
return out
def backprop_hard_swish_mobilenet(
dY, X, *, inplace: bool = False, threads_per_block=128, num_blocks=128
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_hard_swish_mobilenet_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
else:
backprop_hard_swish_mobilenet_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
return out
def backprop_dish(
dY,
X,
*,
inplace: bool = False,
threads_per_block=128,
num_blocks=128,
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_dish_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
else:
backprop_dish_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
return out
def backprop_gelu(
dY,
X,
*,
inplace: bool = False,
threshold=6.0,
threads_per_block=128,
num_blocks=128,
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_gelu_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, threshold, out.size)
)
else:
backprop_gelu_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, threshold, out.size)
)
return out
def backprop_maxout(dY, which, P, *, threads_per_block=128, num_blocks=128):
_is_float_array(dY)
B = dY.shape[0]
I = dY.shape[1]
out = _alloc((B, I, P), dtype=dY.dtype, zeros=True)
_check_which_maxout(which, B, I, P)
if dY.dtype == "float32":
backprop_maxout_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, which, B, I, P)
)
else:
backprop_maxout_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, which, B, I, P)
)
return out
def backprop_mish(
dY, X, *, inplace: bool = False, threshold=5, threads_per_block=128, num_blocks=128
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_mish_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, threshold, dY.size)
)
else:
backprop_mish_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, threshold, dY.size)
)
return out
def backprop_reduce_sum(d_sums, lengths, *, threads_per_block=128, num_blocks=128):
_is_float_array(d_sums)
B = len(lengths)
T = int(lengths.sum())
O = d_sums.shape[1]
_check_lengths(lengths, T)
out = _alloc((T, O), dtype=d_sums.dtype, zeros=False)
if d_sums.dtype == "float32":
backprop_reduce_sum_kernel_float(
(num_blocks,), (threads_per_block,), (out, d_sums, lengths, B, T, O)
)
else:
backprop_reduce_sum_kernel_double(
(num_blocks,), (threads_per_block,), (out, d_sums, lengths, B, T, O)
)
return out
def backprop_reduce_mean(d_means, lengths, *, threads_per_block=128, num_blocks=128):
_is_float_array(d_means)
B = len(lengths)
T = int(lengths.sum())
O = d_means.shape[1]
_check_lengths(lengths, T)
out = _alloc((T, O), dtype=d_means.dtype, zeros=False)
if d_means.dtype == "float32":
backprop_reduce_mean_kernel_float(
(num_blocks,), (threads_per_block,), (out, d_means, lengths, B, T, O)
)
else:
backprop_reduce_mean_kernel_double(
(num_blocks,), (threads_per_block,), (out, d_means, lengths, B, T, O)
)
return out
def backprop_reduce_max(
d_maxes, which, lengths, *, threads_per_block=128, num_blocks=128
):
_is_float_array(d_maxes)
B = len(lengths)
T = int(lengths.sum())
O = d_maxes.shape[1]
_check_lengths(lengths, T, min_length=1)
out = _alloc((T, O), dtype=d_maxes.dtype, zeros=True)
_check_which_reduce_max(which, (B, O), lengths)
if d_maxes.dtype == "float32":
backprop_reduce_max_kernel_float(
(num_blocks,), (threads_per_block,), (out, d_maxes, which, lengths, B, T, O)
)
else:
backprop_reduce_max_kernel_double(
(num_blocks,), (threads_per_block,), (out, d_maxes, which, lengths, B, T, O)
)
return out
def backprop_swish(
dY, X, Y, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
_is_float_array(Y, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_swish_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, Y, threshold, out.size)
)
else:
backprop_swish_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, Y, threshold, out.size)
)
return out
def hash(ids, seed, *, threads_per_block=128, num_blocks=128):
out = _alloc((ids.shape[0], 4), dtype="uint32", zeros=True)
# sizeof(uint32_t) * 4
out_size = 4 * 4
in_size = 8 # sizeof(uint64_t)
# T = ids.shape[0]
hash_data_kernel(
(num_blocks,),
(threads_per_block,),
(out, ids, out_size, in_size, ids.shape[0], seed),
)
return out
def _is_float_array(out, *, shape: Optional[Tuple] = None):
assert out.dtype in (
"float32",
"float64",
), "CUDA kernel can only handle float32 and float64"
if shape is not None and out.shape != shape:
msg = f"array has incorrect shape, expected: {shape}, was: {out.shape}"
raise ValueError(msg)
def _is_float_or_int_array(out, *, shape: Optional[Tuple] = None):
assert out.dtype in (
"float32",
"float64",
"int32",
"int64",
), "CUDA kernel can only handle float32, float64, int32 and int64"
if shape is not None and out.shape != shape:
msg = f"array has incorrect shape, expected: {shape}, was: {out.shape}"
raise ValueError(msg)
def _check_lengths(lengths, n_elems: int, *, min_length=0):
assert lengths.dtype == "int32", "lengths should be encoded as 32-bit integers"
if not cupy.all(lengths >= min_length):
raise ValueError(f"all sequence lengths must be >= {min_length}")
if cupy.sum(lengths) != n_elems:
raise IndexError("lengths must sum up to the batch size")
def _check_indices(indices, n: int):
assert indices.dtype == "int32", "indices should be encoded as 32-bit integers"
if not _values_within_range(indices, 0, n):
raise IndexError(f"index out of bounds, must be >= 0 && < {n}")
def _check_which_maxout(which, B: int, I: int, P: int):
shape = (B, I)
msg = "maximum index (which) should be encoded as 32-bit integers"
assert which.dtype == "int32", msg
if which.shape != shape:
msg = f"maximum index (which) has incorrect shape, expected: {shape}, was: {which.shape}"
raise ValueError(msg)
if not _values_within_range(which, 0, P):
raise IndexError("maximum index (which) value out of bounds")
_values_within_range = (
cupy.ReductionKernel(
"T x, T lower, T upper",
"bool r",
"x >= lower && x < upper",
"a & b",
"r = a",
"true",
"within_range",
)
if has_cupy_gpu
else None
)
def _check_which_reduce_max(which, shape: Tuple, lengths):
msg = "maximum index (which) should be encoded as 32-bit integers"
assert which.dtype == "int32", msg
if which.shape != shape:
msg = f"maximum index (which) has incorrect shape, expected: {shape}, was: {which.shape}"
raise ValueError(msg)
if not cupy.all((which >= 0) & (which < cupy.expand_dims(lengths, -1))):
raise IndexError("maximum index (which) value out of bounds")