import operator
import re
from collections import defaultdict
from functools import reduce
from pathlib import Path
from typing import Callable, Optional, Tuple

import numpy

from ..compat import cupy, has_cupy_gpu

PWD = Path(__file__).parent
KERNELS_SRC = (PWD / "_custom_kernels.cu").read_text(encoding="utf8")
KERNELS_LIST = [
    "backprop_clipped_linear<double>",
    "backprop_clipped_linear<float>",
    "backprop_dish<double>",
    "backprop_dish<float>",
    "backprop_gelu<double>",
    "backprop_gelu<float>",
    "backprop_hard_swish<double>",
    "backprop_hard_swish<float>",
    "backprop_hard_swish_mobilenet<double>",
    "backprop_hard_swish_mobilenet<float>",
    "backprop_maxout<double>",
    "backprop_maxout<float>",
    "backprop_mish<double>",
    "backprop_mish<float>",
    "backprop_reduce_max<double>",
    "backprop_reduce_max<float>",
    "backprop_reduce_mean<double>",
    "backprop_reduce_mean<float>",
    "backprop_reduce_sum<double>",
    "backprop_reduce_sum<float>",
    "backprop_seq2col<double>",
    "backprop_seq2col<float>",
    "backprop_swish<double>",
    "backprop_swish<float>",
    "clipped_linear<double>",
    "clipped_linear<float>",
    "dish<double>",
    "dish<float>",
    "gather_add<double>",
    "gather_add<float>",
    "gelu<double>",
    "gelu<float>",
    "maxout<double>",
    "maxout<float>",
    "mish<double>",
    "mish<float>",
    "pad<double>",
    "pad<float>",
    "pad<int>",
    "pad<long long>",
    "reduce_max<double>",
    "reduce_max<float>",
    "reduce_sum<double>",
    "reduce_sum<float>",
    "seq2col<double>",
    "seq2col<float>",
    "swish<double>",
    "swish<float>",
]
KERNELS = (
    cupy.RawModule(
        code=KERNELS_SRC, options=("--std=c++11",), name_expressions=KERNELS_LIST
    )
    if has_cupy_gpu
    else None
)


class LazyKernel:
    """Wraps around `cupy.RawModule` and `cupy.RawKernel` to verify CuPy availability
    and lazily compile the latter on first invocation.

    The default CuPy behaviour triggers the compilation as soon as the `cupy.RawKernel` object
    is accessed."""

    name: str
    _kernel: Optional["cupy.RawKernel"]
    _compile_callback: Optional[Callable[[], "cupy.RawKernel"]]

    __slots__ = ["name", "_kernel", "_compile_callback"]

    def __init__(
        self,
        name: str,
        *,
        compile_callback: Optional[Callable[[], "cupy.RawKernel"]] = None,
    ) -> None:
        self.name = name
        self._kernel = None
        self._compile_callback = compile_callback

    def __call__(self, *args, **kwargs):
        self._compile_kernel()
        self._kernel(*args, **kwargs)

    def _compile_kernel(self):
        if self._kernel is not None:
            return

        if self._compile_callback is not None:
            self._kernel = self._compile_callback()
        elif KERNELS is not None:
            self._kernel = KERNELS.get_function(self.name)

        if self._kernel is None:
            raise ValueError(f"couldn't compile CuPy kernel '{self.name}'")
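

# A minimal usage sketch, kept as a comment so that importing this module has
# no side effects; it assumes a CUDA-capable GPU and shows the launch
# convention shared by the wrappers below: a grid tuple, a block tuple, then
# the kernel argument tuple. The array names are illustrative only.
#
#     X = cupy.random.uniform(size=(16, 128), dtype="float32")
#     out = cupy.empty_like(X)
#     kernel = LazyKernel("gelu<float>")  # nothing is compiled yet
#     kernel((128,), (128,), (out, X, 6.0, X.size))  # compiles on first call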


def compile_mmh():
    if not has_cupy_gpu:
        return None
    return cupy.RawKernel(
        (PWD / "_murmur3.cu").read_text(encoding="utf8"), "hash_data"
    )


clipped_linear_kernel_float = LazyKernel("clipped_linear<float>")
clipped_linear_kernel_double = LazyKernel("clipped_linear<double>")
dish_kernel_float = LazyKernel("dish<float>")
dish_kernel_double = LazyKernel("dish<double>")
gather_add_kernel_float = LazyKernel("gather_add<float>")
gather_add_kernel_double = LazyKernel("gather_add<double>")
gelu_kernel_float = LazyKernel("gelu<float>")
gelu_kernel_double = LazyKernel("gelu<double>")
hash_data_kernel = LazyKernel("hash_data", compile_callback=compile_mmh)
maxout_kernel_float = LazyKernel("maxout<float>")
maxout_kernel_double = LazyKernel("maxout<double>")
mish_kernel_float = LazyKernel("mish<float>")
mish_kernel_double = LazyKernel("mish<double>")
pad_kernel_float = LazyKernel("pad<float>")
pad_kernel_double = LazyKernel("pad<double>")
pad_kernel_int32 = LazyKernel("pad<int>")
pad_kernel_int64 = LazyKernel("pad<long long>")
reduce_max_kernel_float = LazyKernel("reduce_max<float>")
reduce_max_kernel_double = LazyKernel("reduce_max<double>")
reduce_sum_kernel_float = LazyKernel("reduce_sum<float>")
reduce_sum_kernel_double = LazyKernel("reduce_sum<double>")
seq2col_kernel_float = LazyKernel("seq2col<float>")
seq2col_kernel_double = LazyKernel("seq2col<double>")
swish_kernel_float = LazyKernel("swish<float>")
swish_kernel_double = LazyKernel("swish<double>")

backprop_clipped_linear_kernel_double = LazyKernel("backprop_clipped_linear<double>")
backprop_clipped_linear_kernel_float = LazyKernel("backprop_clipped_linear<float>")
backprop_dish_kernel_double = LazyKernel("backprop_dish<double>")
backprop_dish_kernel_float = LazyKernel("backprop_dish<float>")
backprop_gelu_kernel_double = LazyKernel("backprop_gelu<double>")
backprop_gelu_kernel_float = LazyKernel("backprop_gelu<float>")
backprop_hard_swish_kernel_double = LazyKernel("backprop_hard_swish<double>")
backprop_hard_swish_kernel_float = LazyKernel("backprop_hard_swish<float>")
backprop_hard_swish_mobilenet_kernel_double = LazyKernel(
    "backprop_hard_swish_mobilenet<double>"
)
backprop_hard_swish_mobilenet_kernel_float = LazyKernel(
    "backprop_hard_swish_mobilenet<float>"
)
backprop_maxout_kernel_double = LazyKernel("backprop_maxout<double>")
backprop_maxout_kernel_float = LazyKernel("backprop_maxout<float>")
backprop_mish_kernel_double = LazyKernel("backprop_mish<double>")
backprop_mish_kernel_float = LazyKernel("backprop_mish<float>")
backprop_reduce_max_kernel_double = LazyKernel("backprop_reduce_max<double>")
backprop_reduce_max_kernel_float = LazyKernel("backprop_reduce_max<float>")
backprop_reduce_mean_kernel_double = LazyKernel("backprop_reduce_mean<double>")
backprop_reduce_mean_kernel_float = LazyKernel("backprop_reduce_mean<float>")
backprop_reduce_sum_kernel_double = LazyKernel("backprop_reduce_sum<double>")
backprop_reduce_sum_kernel_float = LazyKernel("backprop_reduce_sum<float>")
backprop_seq2col_kernel_double = LazyKernel("backprop_seq2col<double>")
backprop_seq2col_kernel_float = LazyKernel("backprop_seq2col<float>")
backprop_swish_kernel_double = LazyKernel("backprop_swish<double>")
backprop_swish_kernel_float = LazyKernel("backprop_swish<float>")


def _alloc(shape, dtype, *, zeros: bool = True):
    if zeros:
        return cupy.zeros(shape, dtype)
    else:
        return cupy.empty(shape, dtype)


def _alloc_like(array, zeros: bool = True):
    if zeros:
        return cupy.zeros_like(array)
    else:
        return cupy.empty_like(array)


def pad(seqs, round_to=1, *, threads_per_block=128, num_blocks=128):
    if round_to < 1:
        raise ValueError(f"Rounding for padding must at least be 1, was: {round_to}")
    for seq in seqs:
        _is_float_or_int_array(seq)

    seq_lens = [len(seq) for seq in seqs]
    max_seq_len = max(seq_lens)
    # Round the length to nearest bucket -- helps on GPU, to make similar
    # array sizes.
    max_seq_len += -max_seq_len % round_to
    seq_lens = cupy.array(seq_lens, dtype="int32")
    final_shape = (len(seqs), max_seq_len) + seqs[0].shape[1:]
    out = cupy.empty(final_shape, dtype=seqs[0].dtype)

    # Extract pointers from CuPy arrays, so that we can address
    # them in the CUDA kernel.
    ptrs = numpy.empty((len(seqs),), "int64")
    for idx, seq in enumerate(seqs):
        ptrs[idx] = seq.data.ptr
    ptrs = cupy.array(ptrs)

    stride = reduce(operator.mul, seqs[0].shape[1:], 1)

    if out.dtype == "float32":
        pad_kernel_float(
            (num_blocks,),
            (threads_per_block,),
            (out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
        )
    elif out.dtype == "float64":
        pad_kernel_double(
            (num_blocks,),
            (threads_per_block,),
            (out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
        )
    elif out.dtype == "int32":
        pad_kernel_int32(
            (num_blocks,),
            (threads_per_block,),
            (out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
        )
    elif out.dtype == "int64":
        pad_kernel_int64(
            (num_blocks,),
            (threads_per_block,),
            (out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
        )

    return out
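

# A usage sketch, kept as a comment so that importing this module stays free
# of GPU work; the sequence contents are hypothetical. Two ragged float32
# sequences are padded into a single batch whose length is rounded up to a
# multiple of `round_to`.
#
#     seqs = [
#         cupy.ones((3, 5), dtype="float32"),
#         cupy.ones((1, 5), dtype="float32"),
#     ]
#     padded = pad(seqs, round_to=4)  # shape (2, 4, 5)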


def clipped_linear(
    X,
    *,
    inplace=False,
    slope=1.0,
    offset=0.0,
    min_val=0.0,
    max_val=1.0,
    threads_per_block=128,
    num_blocks=128,
):
    _is_float_array(X)

    out = X
    if not inplace:
        out = _alloc_like(X, zeros=False)
    if X.dtype == "float32":
        clipped_linear_kernel_float(
            (num_blocks,),
            (threads_per_block,),
            (out, X, slope, offset, min_val, max_val, X.size),
        )
    else:
        clipped_linear_kernel_double(
            (num_blocks,),
            (threads_per_block,),
            (out, X, slope, offset, min_val, max_val, X.size),
        )
    return out


def gather_add(table, indices, *, threads_per_block=128, num_blocks=128):
    if table.ndim != 2:
        raise ValueError(
            f"gather_add expects table with dimensionality 2, was: {table.ndim}"
        )
    if indices.ndim != 2:
        raise ValueError(
            f"gather_add expects indices with dimensionality 2, was: {indices.ndim}"
        )
    _is_float_array(table)
    indices = indices.astype("int32")
    _check_indices(indices, table.shape[0])

    B = indices.shape[0]
    K = indices.shape[1]
    T = table.shape[0]
    O = table.shape[1]

    out = _alloc((B, O), dtype=table.dtype, zeros=True)
    if table.dtype == "float32":
        gather_add_kernel_float(
            (num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
        )
    else:
        gather_add_kernel_double(
            (num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
        )
    return out
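

# Shape sketch (comment only, hypothetical values): for a (T, O) table and
# (B, K) indices, each output row is the sum of the K gathered table rows,
# i.e. out[b] = table[indices[b, 0]] + ... + table[indices[b, K - 1]],
# giving a (B, O) result.
#
#     table = cupy.arange(12, dtype="float32").reshape(4, 3)  # T=4, O=3
#     indices = cupy.array([[0, 1], [2, 2]], dtype="int32")   # B=2, K=2
#     summed = gather_add(table, indices)                     # shape (2, 3)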


def dish(X, *, inplace=False, threads_per_block=128, num_blocks=128):
    _is_float_array(X)

    out = X
    if not inplace:
        out = _alloc_like(X, zeros=False)
    if X.dtype == "float32":
        dish_kernel_float((num_blocks,), (threads_per_block,), (out, X, X.size))
    else:
        dish_kernel_double((num_blocks,), (threads_per_block,), (out, X, X.size))
    return out


def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
    _is_float_array(X)

    out = X
    if not inplace:
        out = _alloc_like(X, zeros=False)
    if X.dtype == "float32":
        gelu_kernel_float(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    else:
        gelu_kernel_double(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    return out


def check_seq2col_lengths(lengths, B):
    if lengths is None:
        lengths = cupy.array([B], dtype="int32")
    else:
        _check_lengths(lengths, B)
    return lengths


def seq2col(seq, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
    _is_float_array(seq)

    B = seq.shape[0]
    nF = nW * 2 + 1
    I = seq.shape[1]

    lengths = check_seq2col_lengths(lengths, B)
    nL = lengths.shape[0]

    out = _alloc((B, I * nF), dtype=seq.dtype, zeros=True)

    if seq.size != 0 and lengths.size != 0:
        if seq.dtype == "float32":
            seq2col_kernel_float(
                (num_blocks,), (threads_per_block,), (out, seq, lengths, nW, B, I, nL)
            )
        else:
            seq2col_kernel_double(
                (num_blocks,), (threads_per_block,), (out, seq, lengths, nW, B, I, nL)
            )

    return out
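

# Shape sketch (comment only): with a window of nW=1, each output row is the
# concatenation of the previous, current and next input rows (positions that
# fall outside a sequence, as delimited by `lengths`, contribute zeros), so a
# (B, I) input becomes a (B, I * 3) output. The values are hypothetical.
#
#     seq = cupy.random.uniform(size=(5, 8), dtype="float32")
#     lengths = cupy.array([2, 3], dtype="int32")  # two sequences: rows 0-1, 2-4
#     cols = seq2col(seq, 1, lengths=lengths)      # shape (5, 24)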


def maxout(X, *, threads_per_block=128, num_blocks=128):
    _is_float_array(X)

    B, I, P = X.shape

    out_shape = (B, I)
    best = _alloc(out_shape, dtype=X.dtype, zeros=False)
    which = _alloc(out_shape, dtype="i", zeros=False)

    if X.dtype == "float32":
        maxout_kernel_float(
            (num_blocks,), (threads_per_block,), (best, which, X, B, I, P)
        )
    else:
        maxout_kernel_double(
            (num_blocks,), (threads_per_block,), (best, which, X, B, I, P)
        )

    return best, which


def mish(X, *, inplace=False, threshold=5, threads_per_block=128, num_blocks=128):
    _is_float_array(X)

    out = X
    if not inplace:
        out = _alloc_like(X, zeros=False)

    if X.dtype == "float32":
        mish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    else:
        mish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )

    return out


def reduce_sum(X, lengths, *, threads_per_block=128, num_blocks=128):
    _is_float_array(X)

    B = len(lengths)
    T = X.shape[0]
    O = X.shape[1]

    _check_lengths(lengths, T)

    out = _alloc((B, O), dtype=X.dtype, zeros=True)

    if X.dtype == "float32":
        reduce_sum_kernel_float(
            (num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
        )
    else:
        reduce_sum_kernel_double(
            (num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
        )

    return out


def reduce_mean(X, lengths, *, threads_per_block=128, num_blocks=128):
    _is_float_array(X)

    B = len(lengths)
    T = X.shape[0]
    O = X.shape[1]

    _check_lengths(lengths, T)

    out = _alloc((B, O), dtype=X.dtype, zeros=True)

    if X.dtype == "float32":
        reduce_sum_kernel_float(
            (num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
        )
    else:
        reduce_sum_kernel_double(
            (num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
        )

    # Avoid divide by zero
    out /= lengths.reshape((-1, 1)) + 1e-10
    return out


def reduce_max(X, lengths, *, threads_per_block=128, num_blocks=128):
    _is_float_array(X)

    B = len(lengths)
    T = X.shape[0]
    O = X.shape[1]

    _check_lengths(lengths, T, min_length=1)

    out_shape = (B, O)
    maxes = _alloc(out_shape, dtype=X.dtype, zeros=False)
    which = _alloc(out_shape, dtype="i", zeros=False)

    if X.dtype == "float32":
        reduce_max_kernel_float(
            (num_blocks,), (threads_per_block,), (maxes, which, X, lengths, B, T, O)
        )
    else:
        reduce_max_kernel_double(
            (num_blocks,), (threads_per_block,), (maxes, which, X, lengths, B, T, O)
        )

    return maxes, which
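

# Shape sketch for the ragged reductions above (comment only, hypothetical
# values): X stacks all sequences along the first axis and `lengths` marks
# where each sequence ends, so three sequences of lengths 2, 3 and 1 reduce
# to one output row each.
#
#     X = cupy.random.uniform(size=(6, 4), dtype="float32")
#     lengths = cupy.array([2, 3, 1], dtype="int32")
#     sums = reduce_sum(X, lengths)           # shape (3, 4)
#     means = reduce_mean(X, lengths)         # shape (3, 4)
#     maxes, which = reduce_max(X, lengths)   # both shape (3, 4)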


def swish(X, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128):
    _is_float_array(X)

    out = X
    if not inplace:
        out = _alloc_like(X, zeros=False)
    if X.dtype == "float32":
        swish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    else:
        swish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    return out


def backprop_seq2col(dY, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
    _is_float_array(dY)

    B = dY.shape[0]
    nF = nW * 2 + 1
    I = dY.shape[1] // nF

    lengths = check_seq2col_lengths(lengths, B)
    nL = lengths.shape[0]

    out = _alloc((B, I), dtype=dY.dtype, zeros=True)

    if dY.size != 0 and lengths.size != 0:
        if dY.dtype == "float32":
            backprop_seq2col_kernel_float(
                (num_blocks,), (threads_per_block,), (out, dY, lengths, nW, B, I, nL)
            )
        else:
            backprop_seq2col_kernel_double(
                (num_blocks,), (threads_per_block,), (out, dY, lengths, nW, B, I, nL)
            )

    return out


def backprop_clipped_linear(
    dY,
    X,
    *,
    slope: float = 1.0,
    offset: float = 0.0,
    min_val: float = 0.0,
    max_val: float = 1.0,
    inplace: bool = False,
    threads_per_block=128,
    num_blocks=128,
):
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_clipped_linear_kernel_float(
            (num_blocks,),
            (threads_per_block,),
            (out, dY, X, slope, offset, min_val, max_val, out.size),
        )
    else:
        backprop_clipped_linear_kernel_double(
            (num_blocks,),
            (threads_per_block,),
            (out, dY, X, slope, offset, min_val, max_val, out.size),
        )

    return out


def backprop_hard_swish(
    dY, X, *, inplace: bool = False, threads_per_block=128, num_blocks=128
):
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_hard_swish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )
    else:
        backprop_hard_swish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )

    return out


def backprop_hard_swish_mobilenet(
    dY, X, *, inplace: bool = False, threads_per_block=128, num_blocks=128
):
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_hard_swish_mobilenet_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )
    else:
        backprop_hard_swish_mobilenet_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )

    return out


def backprop_dish(
    dY,
    X,
    *,
    inplace: bool = False,
    threads_per_block=128,
    num_blocks=128,
):
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_dish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )
    else:
        backprop_dish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )

    return out


def backprop_gelu(
    dY,
    X,
    *,
    inplace: bool = False,
    threshold=6.0,
    threads_per_block=128,
    num_blocks=128,
):
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_gelu_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, threshold, out.size)
        )
    else:
        backprop_gelu_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, threshold, out.size)
        )

    return out


def backprop_maxout(dY, which, P, *, threads_per_block=128, num_blocks=128):
    _is_float_array(dY)

    B = dY.shape[0]
    I = dY.shape[1]

    out = _alloc((B, I, P), dtype=dY.dtype, zeros=True)

    _check_which_maxout(which, B, I, P)

    if dY.dtype == "float32":
        backprop_maxout_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, which, B, I, P)
        )
    else:
        backprop_maxout_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, which, B, I, P)
        )

    return out


def backprop_mish(
    dY, X, *, inplace: bool = False, threshold=5, threads_per_block=128, num_blocks=128
):
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_mish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, threshold, dY.size)
        )
    else:
        backprop_mish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, threshold, dY.size)
        )

    return out


def backprop_reduce_sum(d_sums, lengths, *, threads_per_block=128, num_blocks=128):
    _is_float_array(d_sums)

    B = len(lengths)
    T = int(lengths.sum())
    O = d_sums.shape[1]
    _check_lengths(lengths, T)

    out = _alloc((T, O), dtype=d_sums.dtype, zeros=False)

    if d_sums.dtype == "float32":
        backprop_reduce_sum_kernel_float(
            (num_blocks,), (threads_per_block,), (out, d_sums, lengths, B, T, O)
        )
    else:
        backprop_reduce_sum_kernel_double(
            (num_blocks,), (threads_per_block,), (out, d_sums, lengths, B, T, O)
        )

    return out


def backprop_reduce_mean(d_means, lengths, *, threads_per_block=128, num_blocks=128):
    _is_float_array(d_means)

    B = len(lengths)
    T = int(lengths.sum())
    O = d_means.shape[1]
    _check_lengths(lengths, T)

    out = _alloc((T, O), dtype=d_means.dtype, zeros=False)

    if d_means.dtype == "float32":
        backprop_reduce_mean_kernel_float(
            (num_blocks,), (threads_per_block,), (out, d_means, lengths, B, T, O)
        )
    else:
        backprop_reduce_mean_kernel_double(
            (num_blocks,), (threads_per_block,), (out, d_means, lengths, B, T, O)
        )

    return out


def backprop_reduce_max(
    d_maxes, which, lengths, *, threads_per_block=128, num_blocks=128
):
    _is_float_array(d_maxes)

    B = len(lengths)
    T = int(lengths.sum())
    O = d_maxes.shape[1]
    _check_lengths(lengths, T, min_length=1)

    out = _alloc((T, O), dtype=d_maxes.dtype, zeros=True)

    _check_which_reduce_max(which, (B, O), lengths)

    if d_maxes.dtype == "float32":
        backprop_reduce_max_kernel_float(
            (num_blocks,), (threads_per_block,), (out, d_maxes, which, lengths, B, T, O)
        )
    else:
        backprop_reduce_max_kernel_double(
            (num_blocks,), (threads_per_block,), (out, d_maxes, which, lengths, B, T, O)
        )

    return out


def backprop_swish(
    dY, X, Y, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128
):
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)
    _is_float_array(Y, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_swish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, Y, threshold, out.size)
        )
    else:
        backprop_swish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, Y, threshold, out.size)
        )

    return out


def hash(ids, seed, *, threads_per_block=128, num_blocks=128):
    out = _alloc((ids.shape[0], 4), dtype="uint32", zeros=True)

    # sizeof(uint32_t) * 4
    out_size = 4 * 4
    in_size = 8  # sizeof(uint64_t)
    # T = ids.shape[0]
    hash_data_kernel(
        (num_blocks,),
        (threads_per_block,),
        (out, ids, out_size, in_size, ids.shape[0], seed),
    )
    return out
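

# Shape sketch (comment only, hypothetical values): `hash` runs the MurmurHash3
# kernel over an array of 64-bit ids and returns four 32-bit hash values per
# id, which callers can use as independent hash keys.
#
#     ids = cupy.array([3, 7, 2], dtype="uint64")
#     keys = hash(ids, seed=0)   # shape (3, 4), dtype uint32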


def _is_float_array(out, *, shape: Optional[Tuple] = None):
    assert out.dtype in (
        "float32",
        "float64",
    ), "CUDA kernel can only handle float32 and float64"
    if shape is not None and out.shape != shape:
        msg = f"array has incorrect shape, expected: {shape}, was: {out.shape}"
        raise ValueError(msg)


def _is_float_or_int_array(out, *, shape: Optional[Tuple] = None):
    assert out.dtype in (
        "float32",
        "float64",
        "int32",
        "int64",
    ), "CUDA kernel can only handle float32, float64, int32 and int64"
    if shape is not None and out.shape != shape:
        msg = f"array has incorrect shape, expected: {shape}, was: {out.shape}"
        raise ValueError(msg)


def _check_lengths(lengths, n_elems: int, *, min_length=0):
    assert lengths.dtype == "int32", "lengths should be encoded as 32-bit integers"
    if not cupy.all(lengths >= min_length):
        raise ValueError(f"all sequence lengths must be >= {min_length}")
    if cupy.sum(lengths) != n_elems:
        raise IndexError("lengths must sum up to the batch size")


def _check_indices(indices, n: int):
    assert indices.dtype == "int32", "indices should be encoded as 32-bit integers"

    if not _values_within_range(indices, 0, n):
        raise IndexError(f"index out of bounds, must be >= 0 && < {n}")


def _check_which_maxout(which, B: int, I: int, P: int):
    shape = (B, I)
    msg = "maximum index (which) should be encoded as 32-bit integers"
    assert which.dtype == "int32", msg
    if which.shape != shape:
        msg = f"maximum index (which) has incorrect shape, expected: {shape}, was: {which.shape}"
        raise ValueError(msg)
    if not _values_within_range(which, 0, P):
        raise IndexError("maximum index (which) value out of bounds")


_values_within_range = (
    cupy.ReductionKernel(
        "T x, T lower, T upper",
        "bool r",
        "x >= lower && x < upper",
        "a & b",
        "r = a",
        "true",
        "within_range",
    )
    if has_cupy_gpu
    else None
)


def _check_which_reduce_max(which, shape: Tuple, lengths):
    msg = "maximum index (which) should be encoded as 32-bit integers"
    assert which.dtype == "int32", msg
    if which.shape != shape:
        msg = f"maximum index (which) has incorrect shape, expected: {shape}, was: {which.shape}"
        raise ValueError(msg)
    if not cupy.all((which >= 0) & (which < cupy.expand_dims(lengths, -1))):
        raise IndexError("maximum index (which) value out of bounds")