241 lines
8.0 KiB
Python
241 lines
8.0 KiB
Python
|
from typing import List
|
||
|
from dataclasses import dataclass, field
|
||
|
from numba import cuda, float32
|
||
|
from numba.cuda.compiler import compile_ptx_for_current_device, compile_ptx
|
||
|
from math import cos, sin, tan, exp, log, log10, log2, pow, tanh
|
||
|
from operator import truediv
|
||
|
import numpy as np
|
||
|
from numba.cuda.testing import (CUDATestCase, skip_on_cudasim,
|
||
|
skip_unless_cc_75)
|
||
|
import unittest
|
||
|
|
||
|
|
||
|
@dataclass
class FastMathCriterion:
    """Substring expectations for code compiled with and without fastmath.

    Each field lists instruction substrings (e.g. ``'sin.approx.ftz.f32 '``)
    that must (``*_expected``) or must not (``*_unexpected``) appear in the
    fastmath (``fast_*``) or precise (``prec_*``) compiler output.
    """

    # Substrings that must appear in the fastmath output.
    fast_expected: List[str] = field(default_factory=list)
    # Substrings that must not appear in the fastmath output.
    fast_unexpected: List[str] = field(default_factory=list)
    # Substrings that must appear in the precise (non-fastmath) output.
    prec_expected: List[str] = field(default_factory=list)
    # Substrings that must not appear in the precise output.
    prec_unexpected: List[str] = field(default_factory=list)

    def check(self, test: "CUDATestCase", fast: str, prec: str):
        """Assert this criterion against the fastmath and precise output.

        :param test: test case whose ``assertFalse`` is used to report
            failures.
        :param fast: code compiled with ``fastmath=True``.
        :param prec: code compiled without fastmath.
        :raises AssertionError: (raised through *test*) naming every
            offending substring and which variant it was found in / missing
            from.
        """
        self._check_one(test, "fastmath", fast,
                        self.fast_expected, self.fast_unexpected)
        self._check_one(test, "precise", prec,
                        self.prec_expected, self.prec_unexpected)

    @staticmethod
    def _check_one(test, label, code, expected, unexpected):
        """Report missing *expected* / present *unexpected* substrings."""
        # Collect all offenders first so a failure lists each of them, instead
        # of the uninformative "False is not true" from assertTrue(all(...)).
        missing = [s for s in expected if s not in code]
        present = [s for s in unexpected if s in code]
        test.assertFalse(
            missing,
            f"{label} code is missing expected substrings: {missing}")
        test.assertFalse(
            present,
            f"{label} code contains unexpected substrings: {present}")
|
||
|
|
||
|
|
||
|
@skip_on_cudasim('Fastmath and PTX inspection not available on cudasim')
class TestFastMathOption(CUDATestCase):
    """Check that ``fastmath=True`` changes the generated PTX as intended.

    Most tests compile the same Python function with and without
    ``fastmath=True`` and check the two outputs for the presence or absence
    of specific instruction substrings, described by a ``FastMathCriterion``.
    """

    def _test_fast_math_common(self, pyfunc, sig, device, criterion):
        """Compile *pyfunc* with and without fastmath and check *criterion*.

        Both the ``cuda.jit`` + ``inspect_asm`` path and the
        ``compile_ptx_for_current_device`` path are exercised with the same
        signature, device flag and criterion.
        """
        # Test jit code path
        fastver = cuda.jit(sig, device=device, fastmath=True)(pyfunc)
        precver = cuda.jit(sig, device=device)(pyfunc)

        criterion.check(
            self, fastver.inspect_asm(sig), precver.inspect_asm(sig)
        )

        # Test compile_ptx code path
        fastptx, _ = compile_ptx_for_current_device(
            pyfunc, sig, device=device, fastmath=True
        )
        precptx, _ = compile_ptx_for_current_device(
            pyfunc, sig, device=device
        )

        criterion.check(self, fastptx, precptx)

    def _test_fast_math_unary(self, op, criterion: FastMathCriterion):
        """Check a one-argument float32 *op* as a kernel and a device func."""
        def kernel(r, x):
            r[0] = op(x)

        def device_function(x):
            return op(x)

        self._test_fast_math_common(
            kernel, (float32[::1], float32), device=False, criterion=criterion
        )
        self._test_fast_math_common(
            device_function, (float32,), device=True, criterion=criterion
        )

    def _test_fast_math_binary(self, op, criterion: FastMathCriterion):
        """Check a two-argument float32 *op* as a kernel and a device func."""
        def kernel(r, x, y):
            r[0] = op(x, y)

        # Named like the unary helper's device function; the previous name
        # ``device`` read confusingly next to the ``device=`` keyword below.
        def device_function(x, y):
            return op(x, y)

        self._test_fast_math_common(
            kernel,
            (float32[::1], float32, float32), device=False, criterion=criterion
        )
        self._test_fast_math_common(
            device_function, (float32, float32), device=True,
            criterion=criterion
        )

    def test_cosf(self):
        self._test_fast_math_unary(
            cos,
            FastMathCriterion(
                fast_expected=['cos.approx.ftz.f32 '],
                prec_unexpected=['cos.approx.ftz.f32 ']
            )
        )

    def test_sinf(self):
        self._test_fast_math_unary(
            sin,
            FastMathCriterion(
                fast_expected=['sin.approx.ftz.f32 '],
                prec_unexpected=['sin.approx.ftz.f32 ']
            )
        )

    def test_tanf(self):
        # Fast tan is expected to be built from approximate sin, cos and a
        # divide.
        self._test_fast_math_unary(
            tan,
            FastMathCriterion(fast_expected=[
                'sin.approx.ftz.f32 ',
                'cos.approx.ftz.f32 ',
                'div.approx.ftz.f32 '
            ], prec_unexpected=['sin.approx.ftz.f32 '])
        )

    @skip_unless_cc_75
    def test_tanhf(self):
        # tanh.approx.f32 is only available from compute capability 7.5,
        # hence the skip decorator.
        self._test_fast_math_unary(
            tanh,
            FastMathCriterion(
                fast_expected=['tanh.approx.f32 '],
                prec_unexpected=['tanh.approx.f32 ']
            )
        )

    def test_tanhf_compile_ptx(self):
        """Check fast tanh codegen for an explicit target cc (no device use)."""
        def tanh_kernel(r, x):
            r[0] = tanh(x)

        def tanh_common_test(cc, criterion):
            fastptx, _ = compile_ptx(tanh_kernel, (float32[::1], float32),
                                     fastmath=True, cc=cc)
            precptx, _ = compile_ptx(tanh_kernel, (float32[::1], float32),
                                     cc=cc)
            criterion.check(self, fastptx, precptx)

        # cc 7.5 has the dedicated approximate tanh instruction.
        tanh_common_test(cc=(7, 5), criterion=FastMathCriterion(
            fast_expected=['tanh.approx.f32 '],
            prec_unexpected=['tanh.approx.f32 ']
        ))

        # Below 7.5, tanh.approx.f32 does not exist in PTX, so fastmath must
        # fall back to an ex2/rcp-based sequence and must NOT emit it.
        tanh_common_test(cc=(7, 0),
                         criterion=FastMathCriterion(
                             fast_expected=['ex2.approx.ftz.f32 ',
                                            'rcp.approx.ftz.f32 '],
                             fast_unexpected=['tanh.approx.f32 '],
                             prec_unexpected=['tanh.approx.f32 ']))

    def test_expf(self):
        self._test_fast_math_unary(
            exp,
            FastMathCriterion(
                fast_unexpected=['fma.rn.f32 '],
                prec_expected=['fma.rn.f32 ']
            )
        )

    def test_logf(self):
        # Look for constant used to convert from log base 2 to log base e
        # (0f3F317218 is ln(2) as a float32 bit pattern).
        self._test_fast_math_unary(
            log, FastMathCriterion(
                fast_expected=['lg2.approx.ftz.f32 ', '0f3F317218'],
                prec_unexpected=['lg2.approx.ftz.f32 '],
            )
        )

    def test_log10f(self):
        # Look for constant used to convert from log base 2 to log base 10
        # (0f3E9A209B is log10(2) as a float32 bit pattern).
        self._test_fast_math_unary(
            log10, FastMathCriterion(
                fast_expected=['lg2.approx.ftz.f32 ', '0f3E9A209B'],
                prec_unexpected=['lg2.approx.ftz.f32 ']
            )
        )

    def test_log2f(self):
        self._test_fast_math_unary(
            log2, FastMathCriterion(
                fast_expected=['lg2.approx.ftz.f32 '],
                prec_unexpected=['lg2.approx.ftz.f32 ']
            )
        )

    def test_powf(self):
        self._test_fast_math_binary(
            pow, FastMathCriterion(
                fast_expected=['lg2.approx.ftz.f32 '],
                prec_unexpected=['lg2.approx.ftz.f32 '],
            )
        )

    def test_divf(self):
        self._test_fast_math_binary(
            truediv, FastMathCriterion(
                fast_expected=['div.approx.ftz.f32 '],
                fast_unexpected=['div.rn.f32'],
                prec_expected=['div.rn.f32'],
                prec_unexpected=['div.approx.ftz.f32 '],
            )
        )

    def test_divf_exception(self):
        """Precise division by zero raises; fastmath division does not."""
        def f10(r, x, y):
            r[0] = x / y

        sig = (float32[::1], float32, float32)
        # debug=True so that exceptions raised in the kernel are reported
        # back to the host.
        fastver = cuda.jit(sig, fastmath=True, debug=True)(f10)
        precver = cuda.jit(sig, debug=True)(f10)
        nelem = 10
        ary = np.empty(nelem, dtype=np.float32)
        with self.assertRaises(ZeroDivisionError):
            precver[1, nelem](ary, 10.0, 0.0)

        try:
            fastver[1, nelem](ary, 10.0, 0.0)
        except ZeroDivisionError:
            self.fail("Divide in fastmath should not throw ZeroDivisionError")

    @unittest.expectedFailure
    def test_device_fastmath_propagation(self):
        # The fastmath option doesn't presently propagate to device functions
        # from their callers - arguably it should do, so this test is
        # presently an xfail.
        @cuda.jit("float32(float32, float32)", device=True)
        def foo(a, b):
            return a / b

        def bar(arr, val):
            i = cuda.grid(1)
            if i < arr.size:
                arr[i] = foo(i, val)

        sig = (float32[::1], float32)
        fastver = cuda.jit(sig, fastmath=True)(bar)
        precver = cuda.jit(sig)(bar)

        fastasm = fastver.inspect_asm(sig)
        precasm = precver.inspect_asm(sig)

        # Variants of the div instruction are further documented at:
        # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div

        # The fast version should use the "fast, approximate divide" variant.
        # NOTE(review): test_divf shows fastmath emitting the flush-to-zero
        # form 'div.approx.ftz.f32', which does not contain 'div.approx.f32'
        # as a substring, so this check may keep failing even once
        # propagation works - confirm before removing the xfail.
        self.assertIn('div.approx.f32', fastasm)
        # The precise version should use the "IEEE 754 compliant rounding"
        # variant, and neither of the "approximate divide" variants.
        self.assertIn('div.rn.f32', precasm)
        self.assertNotIn('div.approx.f32', precasm)
        self.assertNotIn('div.full.f32', precasm)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Allow running this module directly rather than only via a test runner.
    unittest.main()
|