# thinc/backends/_custom_kernels.py

import operator
from functools import reduce
from pathlib import Path
from typing import Callable, Optional, Tuple
import numpy
from ..compat import cupy, has_cupy_gpu
PWD = Path(__file__).parent
KERNELS_SRC = (PWD / "_custom_kernels.cu").read_text(encoding="utf8")
KERNELS_LIST = [
"backprop_clipped_linear<double>",
"backprop_clipped_linear<float>",
"backprop_dish<double>",
"backprop_dish<float>",
"backprop_gelu<double>",
"backprop_gelu<float>",
"backprop_hard_swish<double>",
"backprop_hard_swish<float>",
"backprop_hard_swish_mobilenet<double>",
"backprop_hard_swish_mobilenet<float>",
"backprop_maxout<double>",
"backprop_maxout<float>",
"backprop_mish<double>",
"backprop_mish<float>",
"backprop_reduce_max<double>",
"backprop_reduce_max<float>",
"backprop_reduce_mean<double>",
"backprop_reduce_mean<float>",
"backprop_reduce_sum<double>",
"backprop_reduce_sum<float>",
"backprop_seq2col<double>",
"backprop_seq2col<float>",
"backprop_swish<double>",
"backprop_swish<float>",
"clipped_linear<double>",
"clipped_linear<float>",
"dish<double>",
"dish<float>",
"gather_add<double>",
"gather_add<float>",
"gelu<double>",
"gelu<float>",
"maxout<double>",
"maxout<float>",
"mish<double>",
"mish<float>",
"pad<double>",
"pad<float>",
"pad<int>",
"pad<long long>",
"reduce_max<double>",
"reduce_max<float>",
"reduce_sum<double>",
"reduce_sum<float>",
"seq2col<double>",
"seq2col<float>",
"swish<double>",
"swish<float>",
]
KERNELS = (
cupy.RawModule(
code=KERNELS_SRC, options=("--std=c++11",), name_expressions=KERNELS_LIST
)
if has_cupy_gpu
else None
)
class LazyKernel:
"""Wraps around `cupy.RawModule` and `cupy.RawKernel` to verify CuPy availability
and lazily compile the latter on first invocation.
The default CuPy behaviour triggers the compilation as soon as the `cupy.RawKernel` object
is accessed."""
name: str
_kernel: Optional["cupy.RawKernel"]
_compile_callback: Optional[Callable[[], "cupy.RawKernel"]]
__slots__ = ["name", "_kernel", "_compile_callback"]
def __init__(
self,
name: str,
*,
compile_callback: Optional[Callable[[], "cupy.RawKernel"]] = None,
) -> None:
self.name = name
self._kernel = None
self._compile_callback = compile_callback
def __call__(self, *args, **kwargs):
self._compile_kernel()
self._kernel(*args, **kwargs)
def _compile_kernel(self):
if self._kernel is not None:
return
if self._compile_callback is not None:
self._kernel = self._compile_callback()
elif KERNELS is not None:
self._kernel = KERNELS.get_function(self.name)
if self._kernel is None:
raise ValueError(f"couldn't compile Cupy kernel '{self.name}'")
def compile_mmh():
if not has_cupy_gpu:
return None
return cupy.RawKernel((PWD / "_murmur3.cu").read_text(encoding="utf8"), "hash_data")
clipped_linear_kernel_float = LazyKernel("clipped_linear<float>")
clipped_linear_kernel_double = LazyKernel("clipped_linear<double>")
dish_kernel_float = LazyKernel("dish<float>")
dish_kernel_double = LazyKernel("dish<double>")
gather_add_kernel_float = LazyKernel("gather_add<float>")
gather_add_kernel_double = LazyKernel("gather_add<double>")
gelu_kernel_float = LazyKernel("gelu<float>")
gelu_kernel_double = LazyKernel("gelu<double>")
hash_data_kernel = LazyKernel("hash_data", compile_callback=compile_mmh)
maxout_kernel_float = LazyKernel("maxout<float>")
maxout_kernel_double = LazyKernel("maxout<double>")
mish_kernel_float = LazyKernel("mish<float>")
mish_kernel_double = LazyKernel("mish<double>")
pad_kernel_float = LazyKernel("pad<float>")
pad_kernel_double = LazyKernel("pad<double>")
pad_kernel_int32 = LazyKernel("pad<int>")
pad_kernel_int64 = LazyKernel("pad<long long>")
reduce_max_kernel_float = LazyKernel("reduce_max<float>")
reduce_max_kernel_double = LazyKernel("reduce_max<double>")
reduce_sum_kernel_float = LazyKernel("reduce_sum<float>")
reduce_sum_kernel_double = LazyKernel("reduce_sum<double>")
seq2col_kernel_float = LazyKernel("seq2col<float>")
seq2col_kernel_double = LazyKernel("seq2col<double>")
swish_kernel_float = LazyKernel("swish<float>")
swish_kernel_double = LazyKernel("swish<double>")
backprop_clipped_linear_kernel_double = LazyKernel("backprop_clipped_linear<double>")
backprop_clipped_linear_kernel_float = LazyKernel("backprop_clipped_linear<float>")
backprop_dish_kernel_double = LazyKernel("backprop_dish<double>")
backprop_dish_kernel_float = LazyKernel("backprop_dish<float>")
backprop_gelu_kernel_double = LazyKernel("backprop_gelu<double>")
backprop_gelu_kernel_float = LazyKernel("backprop_gelu<float>")
backprop_hard_swish_kernel_double = LazyKernel("backprop_hard_swish<double>")
backprop_hard_swish_kernel_float = LazyKernel("backprop_hard_swish<float>")
backprop_hard_swish_mobilenet_kernel_double = LazyKernel(
"backprop_hard_swish_mobilenet<double>"
)
backprop_hard_swish_mobilenet_kernel_float = LazyKernel(
"backprop_hard_swish_mobilenet<float>"
)
backprop_maxout_kernel_double = LazyKernel("backprop_maxout<double>")
backprop_maxout_kernel_float = LazyKernel("backprop_maxout<float>")
backprop_mish_kernel_double = LazyKernel("backprop_mish<double>")
backprop_mish_kernel_float = LazyKernel("backprop_mish<float>")
backprop_reduce_max_kernel_double = LazyKernel("backprop_reduce_max<double>")
backprop_reduce_max_kernel_float = LazyKernel("backprop_reduce_max<float>")
backprop_reduce_mean_kernel_double = LazyKernel("backprop_reduce_mean<double>")
backprop_reduce_mean_kernel_float = LazyKernel("backprop_reduce_mean<float>")
backprop_reduce_sum_kernel_double = LazyKernel("backprop_reduce_sum<double>")
backprop_reduce_sum_kernel_float = LazyKernel("backprop_reduce_sum<float>")
backprop_seq2col_kernel_double = LazyKernel("backprop_seq2col<double>")
backprop_seq2col_kernel_float = LazyKernel("backprop_seq2col<float>")
backprop_swish_kernel_double = LazyKernel("backprop_swish<double>")
backprop_swish_kernel_float = LazyKernel("backprop_swish<float>")
def _alloc(shape, dtype, *, zeros: bool = True):
if zeros:
return cupy.zeros(shape, dtype)
else:
return cupy.empty(shape, dtype)
def _alloc_like(array, zeros: bool = True):
if zeros:
return cupy.zeros_like(array)
else:
return cupy.empty_like(array)
def pad(seqs, round_to=1, *, threads_per_block=128, num_blocks=128):
if round_to < 1:
raise ValueError(f"Rounding for padding must at least be 1, was: {round_to}")
for seq in seqs:
_is_float_or_int_array(seq)
seq_lens = [len(seq) for seq in seqs]
max_seq_len = max(seq_lens)
    # Round the length up to the nearest multiple of `round_to` -- this
    # helps on GPU by producing more uniform array sizes.
max_seq_len += -max_seq_len % round_to
seq_lens = cupy.array(seq_lens, dtype="int32")
final_shape = (len(seqs), max_seq_len) + seqs[0].shape[1:]
out = cupy.empty(final_shape, dtype=seqs[0].dtype)
# Extract pointers from CuPy arrays, so that we can address
# them in the CUDA kernel.
    ptrs = numpy.empty((len(seqs),), dtype="int64")
for idx, seq in enumerate(seqs):
ptrs[idx] = seq.data.ptr
ptrs = cupy.array(ptrs)
stride = reduce(operator.mul, seqs[0].shape[1:], 1)
if out.dtype == "float32":
pad_kernel_float(
(num_blocks,),
(threads_per_block,),
(out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
)
elif out.dtype == "float64":
pad_kernel_double(
(num_blocks,),
(threads_per_block,),
(out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
)
elif out.dtype == "int32":
pad_kernel_int32(
(num_blocks,),
(threads_per_block,),
(out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
)
elif out.dtype == "int64":
pad_kernel_int64(
(num_blocks,),
(threads_per_block,),
(out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
)
return out
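

# Illustrative sketch (not part of the library): pad two ragged float32
# sequences to a common, rounded length. Assumes a CUDA device with CuPy;
# the helper name is hypothetical.
def _example_pad():
    seqs = [
        cupy.ones((3, 4), dtype="float32"),
        cupy.ones((5, 4), dtype="float32"),
    ]
    # round_to=8 buckets the padded length: max(3, 5) rounded up to 8,
    # so the result has shape (2, 8, 4).
    return pad(seqs, round_to=8)
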
def clipped_linear(
X,
*,
inplace=False,
slope=1.0,
offset=0.0,
min_val=0.0,
max_val=1.0,
threads_per_block=128,
num_blocks=128,
):
_is_float_array(X)
out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
clipped_linear_kernel_float(
(num_blocks,),
(threads_per_block,),
(out, X, slope, offset, min_val, max_val, X.size),
)
else:
clipped_linear_kernel_double(
(num_blocks,),
(threads_per_block,),
(out, X, slope, offset, min_val, max_val, X.size),
)
return out
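

# Illustrative sketch (not part of the library): with slope=0.2 and
# offset=0.5 the clipped linear unit computes the "hard sigmoid"
# max(0, min(1, 0.2 * x + 0.5)). Hypothetical name; assumes X is a
# float CuPy array on a CUDA device.
def _example_hard_sigmoid(X):
    return clipped_linear(X, slope=0.2, offset=0.5, min_val=0.0, max_val=1.0)
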
def gather_add(table, indices, *, threads_per_block=128, num_blocks=128):
if table.ndim != 2:
raise ValueError(
f"gather_add expects table with dimensionality 2, was: {table.ndim}"
)
if indices.ndim != 2:
raise ValueError(
f"gather_add expects indices with dimensionality 2, was: {indices.ndim}"
)
_is_float_array(table)
indices = indices.astype("int32")
_check_indices(indices, table.shape[0])
B = indices.shape[0]
K = indices.shape[1]
T = table.shape[0]
O = table.shape[1]
out = _alloc((B, O), dtype=table.dtype, zeros=True)
if table.dtype == "float32":
gather_add_kernel_float(
(num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
)
else:
gather_add_kernel_double(
(num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
)
return out
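

# Illustrative sketch (not part of the library): for each row of `indices`,
# gather the corresponding rows of `table` and sum them, e.g. to pool a bag
# of embeddings. Hypothetical names; assumes a GPU.
def _example_gather_add():
    table = cupy.random.uniform(size=(10, 4)).astype("float32")  # 10 vectors
    indices = cupy.array([[0, 1], [2, 2]], dtype="int32")  # 2 bags of 2 ids
    # Row 0 of the result is table[0] + table[1]; row 1 is 2 * table[2].
    return gather_add(table, indices)  # shape (2, 4)
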
def dish(X, *, inplace=False, threads_per_block=128, num_blocks=128):
_is_float_array(X)
out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
dish_kernel_float((num_blocks,), (threads_per_block,), (out, X, X.size))
else:
dish_kernel_double((num_blocks,), (threads_per_block,), (out, X, X.size))
return out
def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
_is_float_array(X)
out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
gelu_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
else:
gelu_kernel_double(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
return out
def check_seq2col_lengths(lengths, B):
if lengths is None:
lengths = cupy.array([B], dtype="int32")
else:
_check_lengths(lengths, B)
return lengths
def seq2col(seq, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
_is_float_array(seq)
B = seq.shape[0]
nF = nW * 2 + 1
I = seq.shape[1]
lengths = check_seq2col_lengths(lengths, B)
nL = lengths.shape[0]
out = _alloc((B, I * nF), dtype=seq.dtype, zeros=True)
if seq.size != 0 and lengths.size != 0:
if seq.dtype == "float32":
seq2col_kernel_float(
(num_blocks,), (threads_per_block,), (out, seq, lengths, nW, B, I, nL)
)
else:
seq2col_kernel_double(
(num_blocks,), (threads_per_block,), (out, seq, lengths, nW, B, I, nL)
)
return out
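

# Illustrative sketch (not part of the library): seq2col with nW=1
# concatenates each row with its immediate neighbours, so an (N, I) input
# becomes (N, 3 * I); positions outside the sequence stay zero.
# Hypothetical names; assumes a GPU.
def _example_seq2col():
    seq = cupy.arange(8, dtype="float32").reshape(4, 2)  # 4 rows, width 2
    return seq2col(seq, nW=1)  # shape (4, 6): [prev, self, next] per row
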
def maxout(X, *, threads_per_block=128, num_blocks=128):
_is_float_array(X)
B, I, P = X.shape
out_shape = (B, I)
best = _alloc(out_shape, dtype=X.dtype, zeros=False)
which = _alloc(out_shape, dtype="i", zeros=False)
if X.dtype == "float32":
maxout_kernel_float(
(num_blocks,), (threads_per_block,), (best, which, X, B, I, P)
)
else:
maxout_kernel_double(
(num_blocks,), (threads_per_block,), (best, which, X, B, I, P)
)
return best, which
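

# Illustrative sketch (not part of the library): maxout reduces the pieces
# axis P, returning the best value per (batch, feature) cell plus the index
# of the winning piece, which the backward pass needs. Hypothetical names.
def _example_maxout():
    X = cupy.random.uniform(size=(2, 3, 4)).astype("float32")  # B=2, I=3, P=4
    best, which = maxout(X)  # best: (2, 3) values, which: (2, 3) piece ids
    return best, which
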
def mish(X, *, inplace=False, threshold=5, threads_per_block=128, num_blocks=128):
_is_float_array(X)
out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
mish_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
else:
mish_kernel_double(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
return out
def reduce_sum(X, lengths, *, threads_per_block=128, num_blocks=128):
_is_float_array(X)
B = len(lengths)
T = X.shape[0]
O = X.shape[1]
_check_lengths(lengths, T)
out = _alloc((B, O), dtype=X.dtype, zeros=True)
if X.dtype == "float32":
reduce_sum_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
)
else:
reduce_sum_kernel_double(
(num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
)
return out
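

# Illustrative sketch (not part of the library): reduce_sum pools a
# concatenated batch of variable-length sequences, one output row per
# sequence. Hypothetical names; assumes a GPU.
def _example_reduce_sum():
    X = cupy.ones((5, 3), dtype="float32")  # 5 rows total, width 3
    lengths = cupy.array([2, 3], dtype="int32")  # two sequences: 2 + 3 rows
    return reduce_sum(X, lengths)  # shape (2, 3): rows [2, 2, 2], [3, 3, 3]
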
def reduce_mean(X, lengths, *, threads_per_block=128, num_blocks=128):
_is_float_array(X)
B = len(lengths)
T = X.shape[0]
O = X.shape[1]
_check_lengths(lengths, T)
out = _alloc((B, O), dtype=X.dtype, zeros=True)
if X.dtype == "float32":
reduce_sum_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
)
else:
reduce_sum_kernel_double(
(num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
)
# Avoid divide by zero
out /= lengths.reshape((-1, 1)) + 1e-10
return out
def reduce_max(X, lengths, *, threads_per_block=128, num_blocks=128):
_is_float_array(X)
B = len(lengths)
T = X.shape[0]
O = X.shape[1]
_check_lengths(lengths, T, min_length=1)
out_shape = (B, O)
maxes = _alloc(out_shape, dtype=X.dtype, zeros=False)
which = _alloc(out_shape, dtype="i", zeros=False)
if X.dtype == "float32":
reduce_max_kernel_float(
(num_blocks,), (threads_per_block,), (maxes, which, X, lengths, B, T, O)
)
else:
reduce_max_kernel_double(
(num_blocks,), (threads_per_block,), (maxes, which, X, lengths, B, T, O)
)
return maxes, which
def swish(X, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128):
_is_float_array(X)
out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
swish_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
else:
swish_kernel_double(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
)
return out
def backprop_seq2col(dY, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
_is_float_array(dY)
B = dY.shape[0]
nF = nW * 2 + 1
I = dY.shape[1] // nF
lengths = check_seq2col_lengths(lengths, B)
nL = lengths.shape[0]
out = _alloc((B, I), dtype=dY.dtype, zeros=True)
if dY.size != 0 and lengths.size != 0:
if dY.dtype == "float32":
backprop_seq2col_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, lengths, nW, B, I, nL)
)
else:
backprop_seq2col_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, lengths, nW, B, I, nL)
)
return out
def backprop_clipped_linear(
dY,
X,
*,
slope: float = 1.0,
offset: float = 0.0,
min_val: float = 0.0,
max_val: float = 1.0,
inplace: bool = False,
threads_per_block=128,
num_blocks=128,
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_clipped_linear_kernel_float(
(num_blocks,),
(threads_per_block,),
(out, dY, X, slope, offset, min_val, max_val, out.size),
)
else:
backprop_clipped_linear_kernel_double(
(num_blocks,),
(threads_per_block,),
(out, dY, X, slope, offset, min_val, max_val, out.size),
)
return out
def backprop_hard_swish(
dY, X, *, inplace: bool = False, threads_per_block=128, num_blocks=128
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_hard_swish_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
else:
backprop_hard_swish_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
return out
def backprop_hard_swish_mobilenet(
dY, X, *, inplace: bool = False, threads_per_block=128, num_blocks=128
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_hard_swish_mobilenet_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
else:
backprop_hard_swish_mobilenet_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
return out
def backprop_dish(
dY,
X,
*,
inplace: bool = False,
threads_per_block=128,
num_blocks=128,
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_dish_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
else:
backprop_dish_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
return out
def backprop_gelu(
dY,
X,
*,
inplace: bool = False,
threshold=6.0,
threads_per_block=128,
num_blocks=128,
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_gelu_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, threshold, out.size)
)
else:
backprop_gelu_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, threshold, out.size)
)
return out
def backprop_maxout(dY, which, P, *, threads_per_block=128, num_blocks=128):
_is_float_array(dY)
B = dY.shape[0]
I = dY.shape[1]
out = _alloc((B, I, P), dtype=dY.dtype, zeros=True)
_check_which_maxout(which, B, I, P)
if dY.dtype == "float32":
backprop_maxout_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, which, B, I, P)
)
else:
backprop_maxout_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, which, B, I, P)
)
return out
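

# Illustrative sketch (not part of the library): the backward pass routes
# each gradient cell to the piece that won the forward maxout, using the
# `which` indices; all other pieces receive zero. Hypothetical names.
def _example_backprop_maxout():
    X = cupy.random.uniform(size=(2, 3, 4)).astype("float32")
    best, which = maxout(X)
    dY = cupy.ones_like(best)  # upstream gradient, shape (2, 3)
    return backprop_maxout(dY, which, 4)  # (2, 3, 4), one 1.0 per (b, i) cell
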
def backprop_mish(
dY, X, *, inplace: bool = False, threshold=5, threads_per_block=128, num_blocks=128
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_mish_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, threshold, dY.size)
)
else:
backprop_mish_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, threshold, dY.size)
)
return out
def backprop_reduce_sum(d_sums, lengths, *, threads_per_block=128, num_blocks=128):
_is_float_array(d_sums)
B = len(lengths)
T = int(lengths.sum())
O = d_sums.shape[1]
_check_lengths(lengths, T)
out = _alloc((T, O), dtype=d_sums.dtype, zeros=False)
if d_sums.dtype == "float32":
backprop_reduce_sum_kernel_float(
(num_blocks,), (threads_per_block,), (out, d_sums, lengths, B, T, O)
)
else:
backprop_reduce_sum_kernel_double(
(num_blocks,), (threads_per_block,), (out, d_sums, lengths, B, T, O)
)
return out
def backprop_reduce_mean(d_means, lengths, *, threads_per_block=128, num_blocks=128):
_is_float_array(d_means)
B = len(lengths)
T = int(lengths.sum())
O = d_means.shape[1]
_check_lengths(lengths, T)
out = _alloc((T, O), dtype=d_means.dtype, zeros=False)
if d_means.dtype == "float32":
backprop_reduce_mean_kernel_float(
(num_blocks,), (threads_per_block,), (out, d_means, lengths, B, T, O)
)
else:
backprop_reduce_mean_kernel_double(
(num_blocks,), (threads_per_block,), (out, d_means, lengths, B, T, O)
)
return out
def backprop_reduce_max(
d_maxes, which, lengths, *, threads_per_block=128, num_blocks=128
):
_is_float_array(d_maxes)
B = len(lengths)
T = int(lengths.sum())
O = d_maxes.shape[1]
_check_lengths(lengths, T, min_length=1)
out = _alloc((T, O), dtype=d_maxes.dtype, zeros=True)
_check_which_reduce_max(which, (B, O), lengths)
if d_maxes.dtype == "float32":
backprop_reduce_max_kernel_float(
(num_blocks,), (threads_per_block,), (out, d_maxes, which, lengths, B, T, O)
)
else:
backprop_reduce_max_kernel_double(
(num_blocks,), (threads_per_block,), (out, d_maxes, which, lengths, B, T, O)
)
return out
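

# Illustrative sketch (not part of the library): reduce_max and its backward
# pass form a pair; the `which` indices from the forward pass scatter each
# gradient row back to the position of the per-sequence maximum.
# Hypothetical names; assumes a GPU.
def _example_backprop_reduce_max():
    X = cupy.random.uniform(size=(5, 3)).astype("float32")
    lengths = cupy.array([2, 3], dtype="int32")
    maxes, which = reduce_max(X, lengths)  # forward pass: (2, 3) each
    d_maxes = cupy.ones_like(maxes)
    return backprop_reduce_max(d_maxes, which, lengths)  # shape (5, 3)
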
def backprop_swish(
dY, X, Y, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)
_is_float_array(Y, shape=dY.shape)
out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)
if dY.dtype == "float32":
backprop_swish_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, Y, threshold, out.size)
)
else:
backprop_swish_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, Y, threshold, out.size)
)
return out
def hash(ids, seed, *, threads_per_block=128, num_blocks=128):
out = _alloc((ids.shape[0], 4), dtype="uint32", zeros=True)
    out_size = 4 * 4  # sizeof(uint32_t) * 4
    in_size = 8  # sizeof(uint64_t)
hash_data_kernel(
(num_blocks,),
(threads_per_block,),
(out, ids, out_size, in_size, ids.shape[0], seed),
)
return out
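

# Illustrative sketch (not part of the library): hash a batch of 64-bit ids
# with MurmurHash3, yielding four 32-bit hash values per id (handy for
# Bloom-filter style embeddings). Hypothetical names; assumes a GPU.
def _example_hash():
    ids = cupy.arange(4, dtype="uint64")
    return hash(ids, seed=0)  # shape (4, 4), dtype uint32
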
def _is_float_array(out, *, shape: Optional[Tuple] = None):
assert out.dtype in (
"float32",
"float64",
), "CUDA kernel can only handle float32 and float64"
if shape is not None and out.shape != shape:
msg = f"array has incorrect shape, expected: {shape}, was: {out.shape}"
raise ValueError(msg)
def _is_float_or_int_array(out, *, shape: Optional[Tuple] = None):
assert out.dtype in (
"float32",
"float64",
"int32",
"int64",
), "CUDA kernel can only handle float32, float64, int32 and int64"
if shape is not None and out.shape != shape:
msg = f"array has incorrect shape, expected: {shape}, was: {out.shape}"
raise ValueError(msg)
def _check_lengths(lengths, n_elems: int, *, min_length=0):
assert lengths.dtype == "int32", "lengths should be encoded as 32-bit integers"
if not cupy.all(lengths >= min_length):
raise ValueError(f"all sequence lengths must be >= {min_length}")
if cupy.sum(lengths) != n_elems:
raise IndexError("lengths must sum up to the batch size")
def _check_indices(indices, n: int):
assert indices.dtype == "int32", "indices should be encoded as 32-bit integers"
if not _values_within_range(indices, 0, n):
raise IndexError(f"index out of bounds, must be >= 0 && < {n}")
def _check_which_maxout(which, B: int, I: int, P: int):
shape = (B, I)
msg = "maximum index (which) should be encoded as 32-bit integers"
assert which.dtype == "int32", msg
if which.shape != shape:
msg = f"maximum index (which) has incorrect shape, expected: {shape}, was: {which.shape}"
raise ValueError(msg)
if not _values_within_range(which, 0, P):
raise IndexError("maximum index (which) value out of bounds")
_values_within_range = (
cupy.ReductionKernel(
"T x, T lower, T upper",
"bool r",
"x >= lower && x < upper",
"a & b",
"r = a",
"true",
"within_range",
)
if has_cupy_gpu
else None
)
def _check_which_reduce_max(which, shape: Tuple, lengths):
msg = "maximum index (which) should be encoded as 32-bit integers"
assert which.dtype == "int32", msg
if which.shape != shape:
msg = f"maximum index (which) has incorrect shape, expected: {shape}, was: {which.shape}"
raise ValueError(msg)
if not cupy.all((which >= 0) & (which < cupy.expand_dims(lengths, -1))):
raise IndexError("maximum index (which) value out of bounds")