"""Lazily-compiled custom CUDA kernels, invoked through CuPy.

The kernel sources live in ``_custom_kernels.cu`` (and ``_murmur3.cu`` for
hashing) next to this module.  Each public function here validates its
arguments, allocates the output array, and launches the matching kernel,
dispatching on dtype (float32 vs. float64, plus int32/int64 for `pad`).
"""

import operator
import re
from collections import defaultdict
from functools import reduce
from pathlib import Path
from typing import Callable, Optional, Tuple

import numpy

from ..compat import cupy, has_cupy_gpu

PWD = Path(__file__).parent
KERNELS_SRC = (PWD / "_custom_kernels.cu").read_text(encoding="utf8")

# The kernels in _custom_kernels.cu are C++ templates.  Every explicit
# instantiation we want to call must be named here so that nvrtc emits it
# and `RawModule.get_function` can resolve the mangled symbol.  Most kernels
# come in <float>/<double> pairs; `pad` additionally supports integer data.
KERNELS_LIST = [
    "backprop_clipped_linear<double>",
    "backprop_clipped_linear<float>",
    "backprop_dish<double>",
    "backprop_dish<float>",
    "backprop_gelu<double>",
    "backprop_gelu<float>",
    "backprop_hard_swish<double>",
    "backprop_hard_swish<float>",
    "backprop_hard_swish_mobilenet<double>",
    "backprop_hard_swish_mobilenet<float>",
    "backprop_maxout<double>",
    "backprop_maxout<float>",
    "backprop_mish<double>",
    "backprop_mish<float>",
    "backprop_reduce_max<double>",
    "backprop_reduce_max<float>",
    "backprop_reduce_mean<double>",
    "backprop_reduce_mean<float>",
    "backprop_reduce_sum<double>",
    "backprop_reduce_sum<float>",
    "backprop_seq2col<double>",
    "backprop_seq2col<float>",
    "backprop_swish<double>",
    "backprop_swish<float>",
    "clipped_linear<double>",
    "clipped_linear<float>",
    "dish<double>",
    "dish<float>",
    "gather_add<double>",
    "gather_add<float>",
    "gelu<double>",
    "gelu<float>",
    "maxout<double>",
    "maxout<float>",
    "mish<double>",
    "mish<float>",
    "pad<double>",
    "pad<float>",
    "pad<int>",
    "pad<long long>",
    "reduce_max<double>",
    "reduce_max<float>",
    "reduce_sum<double>",
    "reduce_sum<float>",
    "seq2col<double>",
    "seq2col<float>",
    "swish<double>",
    "swish<float>",
]

# Only build the module when a GPU-enabled CuPy is available; the LazyKernel
# wrapper raises at call time otherwise.
KERNELS = (
    cupy.RawModule(
        code=KERNELS_SRC, options=("--std=c++11",), name_expressions=KERNELS_LIST
    )
    if has_cupy_gpu
    else None
)


class LazyKernel:
    """Wraps around `cupy.RawModule` and `cupy.RawKernel` to verify CuPy
    availability and lazily compile the latter on first invocation.

    The default CuPy behaviour triggers the compilation as soon as the
    `cupy.RawKernel` object is accessed."""

    name: str
    _kernel: Optional["cupy.RawKernel"]
    _compile_callback: Optional[Callable[[], "cupy.RawKernel"]]

    __slots__ = ["name", "_kernel", "_compile_callback"]

    def __init__(
        self,
        name: str,
        *,
        compile_callback: Optional[Callable[[], "cupy.RawKernel"]] = None,
    ) -> None:
        """Set up the lazy wrapper.

        name: the kernel's instantiated name within `KERNELS` (or within
            the module produced by `compile_callback`).
        compile_callback: optional callable that compiles and returns the
            kernel, used for kernels that don't live in `KERNELS`.
        """
        self.name = name
        self._kernel = None
        self._compile_callback = compile_callback

    def __call__(self, *args, **kwargs):
        """Compile the kernel if necessary, then launch it."""
        self._compile_kernel()
        self._kernel(*args, **kwargs)

    def _compile_kernel(self):
        # Idempotent: the first successful compilation is cached.
        if self._kernel is not None:
            return

        if self._compile_callback is not None:
            self._kernel = self._compile_callback()
        elif KERNELS is not None:
            self._kernel = KERNELS.get_function(self.name)

        if self._kernel is None:
            raise ValueError(f"couldn't compile Cupy kernel '{self.name}'")


def compile_mmh():
    """Compile the MurmurHash3 kernel, or return None without GPU support."""
    if not has_cupy_gpu:
        return None

    return cupy.RawKernel(
        (PWD / "_murmur3.cu").read_text(encoding="utf8"), "hash_data"
    )


clipped_linear_kernel_float = LazyKernel("clipped_linear<float>")
clipped_linear_kernel_double = LazyKernel("clipped_linear<double>")
dish_kernel_float = LazyKernel("dish<float>")
dish_kernel_double = LazyKernel("dish<double>")
gather_add_kernel_float = LazyKernel("gather_add<float>")
gather_add_kernel_double = LazyKernel("gather_add<double>")
gelu_kernel_float = LazyKernel("gelu<float>")
gelu_kernel_double = LazyKernel("gelu<double>")
hash_data_kernel = LazyKernel("hash_data", compile_callback=compile_mmh)
maxout_kernel_float = LazyKernel("maxout<float>")
maxout_kernel_double = LazyKernel("maxout<double>")
mish_kernel_float = LazyKernel("mish<float>")
mish_kernel_double = LazyKernel("mish<double>")
pad_kernel_float = LazyKernel("pad<float>")
pad_kernel_double = LazyKernel("pad<double>")
pad_kernel_int32 = LazyKernel("pad<int>")
pad_kernel_int64 = LazyKernel("pad<long long>")
reduce_max_kernel_float = LazyKernel("reduce_max<float>")
reduce_max_kernel_double = LazyKernel("reduce_max<double>")
reduce_sum_kernel_float = LazyKernel("reduce_sum<float>")
reduce_sum_kernel_double = LazyKernel("reduce_sum<double>")
seq2col_kernel_float = LazyKernel("seq2col<float>")
seq2col_kernel_double = LazyKernel("seq2col<double>")
swish_kernel_float = LazyKernel("swish<float>")
swish_kernel_double = LazyKernel("swish<double>")

backprop_clipped_linear_kernel_double = LazyKernel("backprop_clipped_linear<double>")
backprop_clipped_linear_kernel_float = LazyKernel("backprop_clipped_linear<float>")
backprop_dish_kernel_double = LazyKernel("backprop_dish<double>")
backprop_dish_kernel_float = LazyKernel("backprop_dish<float>")
backprop_gelu_kernel_double = LazyKernel("backprop_gelu<double>")
backprop_gelu_kernel_float = LazyKernel("backprop_gelu<float>")
backprop_hard_swish_kernel_double = LazyKernel("backprop_hard_swish<double>")
backprop_hard_swish_kernel_float = LazyKernel("backprop_hard_swish<float>")
backprop_hard_swish_mobilenet_kernel_double = LazyKernel(
    "backprop_hard_swish_mobilenet<double>"
)
backprop_hard_swish_mobilenet_kernel_float = LazyKernel(
    "backprop_hard_swish_mobilenet<float>"
)
backprop_maxout_kernel_double = LazyKernel("backprop_maxout<double>")
backprop_maxout_kernel_float = LazyKernel("backprop_maxout<float>")
backprop_mish_kernel_double = LazyKernel("backprop_mish<double>")
backprop_mish_kernel_float = LazyKernel("backprop_mish<float>")
backprop_reduce_max_kernel_double = LazyKernel("backprop_reduce_max<double>")
backprop_reduce_max_kernel_float = LazyKernel("backprop_reduce_max<float>")
backprop_reduce_mean_kernel_double = LazyKernel("backprop_reduce_mean<double>")
backprop_reduce_mean_kernel_float = LazyKernel("backprop_reduce_mean<float>")
backprop_reduce_sum_kernel_double = LazyKernel("backprop_reduce_sum<double>")
backprop_reduce_sum_kernel_float = LazyKernel("backprop_reduce_sum<float>")
backprop_seq2col_kernel_double = LazyKernel("backprop_seq2col<double>")
backprop_seq2col_kernel_float = LazyKernel("backprop_seq2col<float>")
backprop_swish_kernel_double = LazyKernel("backprop_swish<double>")
backprop_swish_kernel_float = LazyKernel("backprop_swish<float>")


def _alloc(shape, dtype, *, zeros: bool = True):
    """Allocate a device array, zero-initialized when `zeros` is True."""
    if zeros:
        return cupy.zeros(shape, dtype)
    else:
        return cupy.empty(shape, dtype)


def _alloc_like(array, zeros: bool = True):
    """Allocate a device array with the same shape/dtype as `array`."""
    if zeros:
        return cupy.zeros_like(array)
    else:
        return cupy.empty_like(array)


def pad(seqs, round_to=1, *, threads_per_block=128, num_blocks=128):
    """Pad a list of device arrays into one (len(seqs), max_len, ...) array.

    The padded length is rounded up to a multiple of `round_to` to reduce
    the number of distinct array sizes seen on the GPU.
    """
    if round_to < 1:
        raise ValueError(f"Rounding for padding must at least be 1, was: {round_to}")

    for seq in seqs:
        _is_float_or_int_array(seq)

    seq_lens = [len(seq) for seq in seqs]
    max_seq_len = max(seq_lens)
    # Round the length to nearest bucket -- helps on GPU, to make similar
    # array sizes.
    max_seq_len += -max_seq_len % round_to

    seq_lens = cupy.array(seq_lens, dtype="int32")
    final_shape = (len(seqs), max_seq_len) + seqs[0].shape[1:]
    out = cupy.empty(final_shape, dtype=seqs[0].dtype)

    # Extract pointers from CuPy arrays, so that we can address
    # them in the CUDA kernel.
    ptrs = numpy.empty(
        (
            len(
                seqs,
            )
        ),
        "int64",
    )
    for idx, seq in enumerate(seqs):
        ptrs[idx] = seq.data.ptr
    ptrs = cupy.array(ptrs)

    # Number of scalars per sequence position (product of trailing dims).
    stride = reduce(operator.mul, seqs[0].shape[1:], 1)

    if out.dtype == "float32":
        pad_kernel_float(
            (num_blocks,),
            (threads_per_block,),
            (out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
        )
    elif out.dtype == "float64":
        pad_kernel_double(
            (num_blocks,),
            (threads_per_block,),
            (out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
        )
    elif out.dtype == "int32":
        pad_kernel_int32(
            (num_blocks,),
            (threads_per_block,),
            (out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
        )
    elif out.dtype == "int64":
        pad_kernel_int64(
            (num_blocks,),
            (threads_per_block,),
            (out, ptrs, seq_lens, stride, len(seqs), max_seq_len),
        )

    return out


def clipped_linear(
    X,
    *,
    inplace=False,
    slope=1.0,
    offset=0.0,
    min_val=0.0,
    max_val=1.0,
    threads_per_block=128,
    num_blocks=128,
):
    """Elementwise clip(slope * X + offset, min_val, max_val)."""
    _is_float_array(X)

    out = X
    if not inplace:
        out = _alloc_like(X, zeros=False)

    if X.dtype == "float32":
        clipped_linear_kernel_float(
            (num_blocks,),
            (threads_per_block,),
            (out, X, slope, offset, min_val, max_val, X.size),
        )
    else:
        clipped_linear_kernel_double(
            (num_blocks,),
            (threads_per_block,),
            (out, X, slope, offset, min_val, max_val, X.size),
        )
    return out


def gather_add(table, indices, *, threads_per_block=128, num_blocks=128):
    """For each row of `indices`, sum the corresponding rows of `table`.

    table: float array of shape (T, O).
    indices: int array of shape (B, K); values must lie in [0, T).
    Returns a (B, O) array.
    """
    if table.ndim != 2:
        raise ValueError(
            f"gather_add expects table with dimensionality 2, was: {table.ndim}"
        )
    if indices.ndim != 2:
        raise ValueError(
            f"gather_add expects indices with dimensionality 2, was: {indices.ndim}"
        )

    _is_float_array(table)
    indices = indices.astype("int32")
    _check_indices(indices, table.shape[0])

    B = indices.shape[0]
    K = indices.shape[1]
    T = table.shape[0]
    O = table.shape[1]

    out = _alloc((B, O), dtype=table.dtype, zeros=True)
    if table.dtype == "float32":
        gather_add_kernel_float(
            (num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
        )
    else:
        gather_add_kernel_double(
            (num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
        )
    return out


def dish(X, *, inplace=False, threads_per_block=128, num_blocks=128):
    """Elementwise Dish activation."""
    _is_float_array(X)

    out = X
    if not inplace:
        out = _alloc_like(X, zeros=False)

    if X.dtype == "float32":
        dish_kernel_float((num_blocks,), (threads_per_block,), (out, X, X.size))
    else:
        dish_kernel_double((num_blocks,), (threads_per_block,), (out, X, X.size))
    return out


def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
    """Elementwise GELU activation; `threshold` bounds the approximation."""
    _is_float_array(X)

    out = X
    if not inplace:
        out = _alloc_like(X, zeros=False)

    if X.dtype == "float32":
        gelu_kernel_float(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    else:
        gelu_kernel_double(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    return out


def check_seq2col_lengths(lengths, B):
    """Default `lengths` to a single sequence covering the batch; validate."""
    if lengths is None:
        lengths = cupy.array([B], dtype="int32")
    else:
        _check_lengths(lengths, B)
    return lengths


def seq2col(seq, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
    """Concatenate a window of nW neighbours on each side of every row.

    Output has shape (B, I * (2 * nW + 1)); window positions that fall
    outside a sequence stay zero.
    """
    _is_float_array(seq)

    B = seq.shape[0]
    nF = nW * 2 + 1
    I = seq.shape[1]

    lengths = check_seq2col_lengths(lengths, B)
    nL = lengths.shape[0]

    out = _alloc((B, I * nF), dtype=seq.dtype, zeros=True)

    if seq.size != 0 and lengths.size != 0:
        if seq.dtype == "float32":
            seq2col_kernel_float(
                (num_blocks,), (threads_per_block,), (out, seq, lengths, nW, B, I, nL)
            )
        else:
            seq2col_kernel_double(
                (num_blocks,), (threads_per_block,), (out, seq, lengths, nW, B, I, nL)
            )

    return out


def maxout(X, *, threads_per_block=128, num_blocks=128):
    """Max over the last (pieces) axis of a (B, I, P) array.

    Returns (best, which): the maxima and the winning piece indices,
    which are needed for the backward pass.
    """
    _is_float_array(X)

    B, I, P = X.shape

    out_shape = (B, I)
    best = _alloc(out_shape, dtype=X.dtype, zeros=False)
    which = _alloc(out_shape, dtype="i", zeros=False)

    if X.dtype == "float32":
        maxout_kernel_float(
            (num_blocks,), (threads_per_block,), (best, which, X, B, I, P)
        )
    else:
        maxout_kernel_double(
            (num_blocks,), (threads_per_block,), (best, which, X, B, I, P)
        )

    return best, which


def mish(X, *, inplace=False, threshold=5, threads_per_block=128, num_blocks=128):
    """Elementwise Mish activation; `threshold` bounds the approximation."""
    _is_float_array(X)

    out = X
    if not inplace:
        out = _alloc_like(X, zeros=False)

    if X.dtype == "float32":
        mish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    else:
        mish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    return out


def reduce_sum(X, lengths, *, threads_per_block=128, num_blocks=128):
    """Sum rows of X per sequence; `lengths` must sum to X.shape[0]."""
    _is_float_array(X)

    B = len(lengths)
    T = X.shape[0]
    O = X.shape[1]
    _check_lengths(lengths, T)

    out = _alloc((B, O), dtype=X.dtype, zeros=True)

    if X.dtype == "float32":
        reduce_sum_kernel_float(
            (num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
        )
    else:
        reduce_sum_kernel_double(
            (num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
        )

    return out


def reduce_mean(X, lengths, *, threads_per_block=128, num_blocks=128):
    """Mean of rows of X per sequence, computed as sum / length."""
    _is_float_array(X)

    B = len(lengths)
    T = X.shape[0]
    O = X.shape[1]
    _check_lengths(lengths, T)

    out = _alloc((B, O), dtype=X.dtype, zeros=True)

    if X.dtype == "float32":
        reduce_sum_kernel_float(
            (num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
        )
    else:
        reduce_sum_kernel_double(
            (num_blocks,), (threads_per_block,), (out, X, lengths, B, T, O)
        )

    # Avoid divide by zero
    out /= lengths.reshape((-1, 1)) + 1e-10

    return out


def reduce_max(X, lengths, *, threads_per_block=128, num_blocks=128):
    """Per-sequence max over rows of X.

    Returns (maxes, which); `which` holds the within-sequence argmax rows
    for the backward pass.  Zero-length sequences are rejected.
    """
    _is_float_array(X)

    B = len(lengths)
    T = X.shape[0]
    O = X.shape[1]
    _check_lengths(lengths, T, min_length=1)

    out_shape = (B, O)
    maxes = _alloc(out_shape, dtype=X.dtype, zeros=False)
    which = _alloc(out_shape, dtype="i", zeros=False)

    if X.dtype == "float32":
        reduce_max_kernel_float(
            (num_blocks,), (threads_per_block,), (maxes, which, X, lengths, B, T, O)
        )
    else:
        reduce_max_kernel_double(
            (num_blocks,), (threads_per_block,), (maxes, which, X, lengths, B, T, O)
        )

    return maxes, which


def swish(X, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128):
    """Elementwise Swish activation; `threshold` bounds the approximation."""
    _is_float_array(X)

    out = X
    if not inplace:
        out = _alloc_like(X, zeros=False)

    if X.dtype == "float32":
        swish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    else:
        swish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
        )
    return out


def backprop_seq2col(dY, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
    """Backward pass of `seq2col`: scatter-add window gradients to rows."""
    _is_float_array(dY)

    B = dY.shape[0]
    nF = nW * 2 + 1
    I = dY.shape[1] // nF

    lengths = check_seq2col_lengths(lengths, B)
    nL = lengths.shape[0]

    out = _alloc((B, I), dtype=dY.dtype, zeros=True)

    if dY.size != 0 and lengths.size != 0:
        if dY.dtype == "float32":
            backprop_seq2col_kernel_float(
                (num_blocks,), (threads_per_block,), (out, dY, lengths, nW, B, I, nL)
            )
        else:
            backprop_seq2col_kernel_double(
                (num_blocks,), (threads_per_block,), (out, dY, lengths, nW, B, I, nL)
            )
    return out


def backprop_clipped_linear(
    dY,
    X,
    *,
    slope: float = 1.0,
    offset: float = 0.0,
    min_val: float = 0.0,
    max_val: float = 1.0,
    inplace: bool = False,
    threads_per_block=128,
    num_blocks=128,
):
    """Backward pass of `clipped_linear`; X must match dY's shape."""
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_clipped_linear_kernel_float(
            (num_blocks,),
            (threads_per_block,),
            (out, dY, X, slope, offset, min_val, max_val, out.size),
        )
    else:
        backprop_clipped_linear_kernel_double(
            (num_blocks,),
            (threads_per_block,),
            (out, dY, X, slope, offset, min_val, max_val, out.size),
        )

    return out


def backprop_hard_swish(
    dY, X, *, inplace: bool = False, threads_per_block=128, num_blocks=128
):
    """Backward pass of the hard-swish activation."""
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_hard_swish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )
    else:
        backprop_hard_swish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )

    return out


def backprop_hard_swish_mobilenet(
    dY, X, *, inplace: bool = False, threads_per_block=128, num_blocks=128
):
    """Backward pass of the MobileNet variant of hard-swish."""
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_hard_swish_mobilenet_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )
    else:
        backprop_hard_swish_mobilenet_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )

    return out


def backprop_dish(
    dY,
    X,
    *,
    inplace: bool = False,
    threads_per_block=128,
    num_blocks=128,
):
    """Backward pass of the Dish activation."""
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_dish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )
    else:
        backprop_dish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
        )

    return out


def backprop_gelu(
    dY,
    X,
    *,
    inplace: bool = False,
    threshold=6.0,
    threads_per_block=128,
    num_blocks=128,
):
    """Backward pass of GELU; `threshold` must match the forward pass."""
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_gelu_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, threshold, out.size)
        )
    else:
        backprop_gelu_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, threshold, out.size)
        )

    return out


def backprop_maxout(dY, which, P, *, threads_per_block=128, num_blocks=128):
    """Backward pass of `maxout`: route dY to the winning pieces in `which`."""
    _is_float_array(dY)

    B = dY.shape[0]
    I = dY.shape[1]

    out = _alloc((B, I, P), dtype=dY.dtype, zeros=True)

    _check_which_maxout(which, B, I, P)

    if dY.dtype == "float32":
        backprop_maxout_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, which, B, I, P)
        )
    else:
        backprop_maxout_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, which, B, I, P)
        )

    return out


def backprop_mish(
    dY, X, *, inplace: bool = False, threshold=5, threads_per_block=128, num_blocks=128
):
    """Backward pass of Mish; `threshold` must match the forward pass."""
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_mish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, threshold, dY.size)
        )
    else:
        backprop_mish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, threshold, dY.size)
        )

    return out


def backprop_reduce_sum(d_sums, lengths, *, threads_per_block=128, num_blocks=128):
    """Backward pass of `reduce_sum`: broadcast each row gradient."""
    _is_float_array(d_sums)

    B = len(lengths)
    T = int(lengths.sum())
    O = d_sums.shape[1]
    _check_lengths(lengths, T)

    out = _alloc((T, O), dtype=d_sums.dtype, zeros=False)

    if d_sums.dtype == "float32":
        backprop_reduce_sum_kernel_float(
            (num_blocks,), (threads_per_block,), (out, d_sums, lengths, B, T, O)
        )
    else:
        backprop_reduce_sum_kernel_double(
            (num_blocks,), (threads_per_block,), (out, d_sums, lengths, B, T, O)
        )

    return out


def backprop_reduce_mean(d_means, lengths, *, threads_per_block=128, num_blocks=128):
    """Backward pass of `reduce_mean`: broadcast gradient / length."""
    _is_float_array(d_means)

    B = len(lengths)
    T = int(lengths.sum())
    O = d_means.shape[1]
    _check_lengths(lengths, T)

    out = _alloc((T, O), dtype=d_means.dtype, zeros=False)

    if d_means.dtype == "float32":
        backprop_reduce_mean_kernel_float(
            (num_blocks,), (threads_per_block,), (out, d_means, lengths, B, T, O)
        )
    else:
        backprop_reduce_mean_kernel_double(
            (num_blocks,), (threads_per_block,), (out, d_means, lengths, B, T, O)
        )

    return out


def backprop_reduce_max(
    d_maxes, which, lengths, *, threads_per_block=128, num_blocks=128
):
    """Backward pass of `reduce_max`: route gradients to the argmax rows."""
    _is_float_array(d_maxes)

    B = len(lengths)
    T = int(lengths.sum())
    O = d_maxes.shape[1]
    _check_lengths(lengths, T, min_length=1)

    out = _alloc((T, O), dtype=d_maxes.dtype, zeros=True)

    _check_which_reduce_max(which, (B, O), lengths)

    if d_maxes.dtype == "float32":
        backprop_reduce_max_kernel_float(
            (num_blocks,), (threads_per_block,), (out, d_maxes, which, lengths, B, T, O)
        )
    else:
        backprop_reduce_max_kernel_double(
            (num_blocks,), (threads_per_block,), (out, d_maxes, which, lengths, B, T, O)
        )

    return out


def backprop_swish(
    dY, X, Y, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks=128
):
    """Backward pass of Swish; needs the forward output Y as well as X."""
    _is_float_array(dY)
    _is_float_array(X, shape=dY.shape)
    _is_float_array(Y, shape=dY.shape)

    out = dY
    if not inplace:
        out = _alloc_like(dY, zeros=False)

    if dY.dtype == "float32":
        backprop_swish_kernel_float(
            (num_blocks,), (threads_per_block,), (out, dY, X, Y, threshold, out.size)
        )
    else:
        backprop_swish_kernel_double(
            (num_blocks,), (threads_per_block,), (out, dY, X, Y, threshold, out.size)
        )

    return out


def hash(ids, seed, *, threads_per_block=128, num_blocks=128):
    """MurmurHash3 each 64-bit id into four 32-bit values."""
    out = _alloc((ids.shape[0], 4), dtype="uint32", zeros=True)

    # sizeof(uint32_t) * 4
    out_size = 4 * 4
    in_size = 8  # sizeof(uint64_t)
    # T = ids.shape[0]
    hash_data_kernel(
        (num_blocks,),
        (threads_per_block,),
        (out, ids, out_size, in_size, ids.shape[0], seed),
    )
    return out


def _is_float_array(out, *, shape: Optional[Tuple] = None):
    """Assert float32/float64 dtype; optionally check the exact shape."""
    assert out.dtype in (
        "float32",
        "float64",
    ), "CUDA kernel can only handle float32 and float64"
    if shape is not None and out.shape != shape:
        msg = f"array has incorrect shape, expected: {shape}, was: {out.shape}"
        raise ValueError(msg)


def _is_float_or_int_array(out, *, shape: Optional[Tuple] = None):
    """Assert float32/float64/int32/int64 dtype; optionally check the shape."""
    assert out.dtype in (
        "float32",
        "float64",
        "int32",
        "int64",
    ), "CUDA kernel can only handle float32, float64, int32 and int64"
    if shape is not None and out.shape != shape:
        msg = f"array has incorrect shape, expected: {shape}, was: {out.shape}"
        raise ValueError(msg)


def _check_lengths(lengths, n_elems: int, *, min_length=0):
    """Validate int32 sequence lengths: all >= min_length, summing to n_elems."""
    assert lengths.dtype == "int32", "lengths should be encoded as 32-bit integers"
    if not cupy.all(lengths >= min_length):
        raise ValueError(f"all sequence lengths must be >= {min_length}")
    if cupy.sum(lengths) != n_elems:
        raise IndexError("lengths must sum up to the batch size")


def _check_indices(indices, n: int):
    """Validate int32 indices fall within [0, n)."""
    assert indices.dtype == "int32", "indices should be encoded as 32-bit integers"
    if not _values_within_range(indices, 0, n):
        raise IndexError(f"index out of bounds, must be >= 0 && < {n}")


def _check_which_maxout(which, B: int, I: int, P: int):
    """Validate the maxout argmax array: int32, shape (B, I), values in [0, P)."""
    shape = (B, I)
    msg = "maximum index (which) should be encoded as 32-bit integers"
    assert which.dtype == "int32", msg
    if which.shape != shape:
        msg = f"maximum index (which) has incorrect shape, expected: {shape}, was: {which.shape}"
        raise ValueError(msg)
    if not _values_within_range(which, 0, P):
        raise IndexError("maximum index (which) value out of bounds")


# Reduction that checks all values lie within [lower, upper); only built when
# a GPU-enabled CuPy is present.
_values_within_range = (
    cupy.ReductionKernel(
        "T x, T lower, T upper",
        "bool r",
        "x >= lower && x < upper",
        "a & b",
        "r = a",
        "true",
        "within_range",
    )
    if has_cupy_gpu
    else None
)


def _check_which_reduce_max(which, shape: Tuple, lengths):
    """Validate the reduce_max argmax array against per-sequence lengths."""
    msg = "maximum index (which) should be encoded as 32-bit integers"
    assert which.dtype == "int32", msg
    if which.shape != shape:
        msg = f"maximum index (which) has incorrect shape, expected: {shape}, was: {which.shape}"
        raise ValueError(msg)
    if not cupy.all((which >= 0) & (which < cupy.expand_dims(lengths, -1))):
        raise IndexError("maximum index (which) value out of bounds")