ai-content-maker/.venv/Lib/site-packages/numba/cuda/tests/cudapy/test_fastmath.py

241 lines
8.0 KiB
Python

from typing import List
from dataclasses import dataclass, field
from numba import cuda, float32
from numba.cuda.compiler import compile_ptx_for_current_device, compile_ptx
from math import cos, sin, tan, exp, log, log10, log2, pow, tanh
from operator import truediv
import numpy as np
from numba.cuda.testing import (CUDATestCase, skip_on_cudasim,
skip_unless_cc_75)
import unittest
@dataclass
class FastMathCriterion:
    """Substring expectations for a pair of fastmath / precise PTX listings.

    Every field is a list of substrings that are searched for in PTX text:

    * ``fast_expected`` / ``fast_unexpected`` must / must not occur in the
      PTX produced with ``fastmath=True``.
    * ``prec_expected`` / ``prec_unexpected`` must / must not occur in the
      PTX produced without fastmath.

    Empty lists (the default) impose no constraint.
    """

    fast_expected: List[str] = field(default_factory=list)
    fast_unexpected: List[str] = field(default_factory=list)
    prec_expected: List[str] = field(default_factory=list)
    prec_unexpected: List[str] = field(default_factory=list)

    def check(self, test: unittest.TestCase, fast: str, prec: str):
        """Assert every expectation against *fast* and *prec* using *test*.

        Each substring gets its own ``assertIn`` / ``assertNotIn`` call, so a
        failure names the offending instruction and the PTX variant. A single
        ``assertTrue(all(...))`` would only report ``False is not true``.

        :param test: test case whose assertion methods are used; only
            ``assertIn`` and ``assertNotIn`` are required.
        :param fast: PTX text compiled with ``fastmath=True``.
        :param prec: PTX text compiled without fastmath.
        :raises AssertionError: on the first violated expectation, checked in
            the order fast-expected, fast-unexpected, prec-expected,
            prec-unexpected.
        """
        groups = (
            ('fastmath', fast, self.fast_expected, self.fast_unexpected),
            ('precise', prec, self.prec_expected, self.prec_unexpected),
        )
        for label, ptx, expected, unexpected in groups:
            for needle in expected:
                test.assertIn(needle, ptx,
                              msg=f'expected {needle!r} in {label} PTX')
            for needle in unexpected:
                test.assertNotIn(needle, ptx,
                                 msg=f'unexpected {needle!r} in {label} PTX')
@skip_on_cudasim('Fastmath and PTX inspection not available on cudasim')
class TestFastMathOption(CUDATestCase):
    """Check that ``fastmath=True`` selects approximate PTX instructions.

    Each test compiles one Python function with and without fastmath and
    passes both PTX listings to a ``FastMathCriterion``. The criterion asserts
    that specific instruction substrings are present or absent.
    """

    @staticmethod
    def _fast_only(insn):
        # Criterion: *insn* must appear with fastmath and must not without it.
        return FastMathCriterion(fast_expected=[insn], prec_unexpected=[insn])

    def _test_fast_math_common(self, pyfunc, sig, device, criterion):
        """Compile *pyfunc* both ways through two APIs and apply *criterion*.

        The comparison runs once for the ``cuda.jit`` dispatcher (through
        ``inspect_asm``) and once for ``compile_ptx_for_current_device``.
        """
        # First API: the cuda.jit dispatcher.
        jit_fast = cuda.jit(sig, device=device, fastmath=True)(pyfunc)
        jit_prec = cuda.jit(sig, device=device)(pyfunc)
        criterion.check(self, jit_fast.inspect_asm(sig),
                        jit_prec.inspect_asm(sig))

        # Second API: compile_ptx_for_current_device returns a pair whose
        # first element is the PTX text; the other element is unused here.
        def ptx(**options):
            text, _unused = compile_ptx_for_current_device(
                pyfunc, sig, device=device, **options
            )
            return text

        criterion.check(self, ptx(fastmath=True), ptx())

    def _test_fast_math_unary(self, op, criterion: FastMathCriterion):
        """Check ``op(x)`` on a float32, as a kernel and as a device function."""
        def kernel(r, x):
            r[0] = op(x)

        def device_function(x):
            return op(x)

        variants = (
            (kernel, (float32[::1], float32), False),
            (device_function, (float32,), True),
        )
        for pyfunc, sig, is_device in variants:
            self._test_fast_math_common(pyfunc, sig, device=is_device,
                                        criterion=criterion)

    def _test_fast_math_binary(self, op, criterion: FastMathCriterion):
        """Check ``op(x, y)`` on float32s, as a kernel and a device function."""
        def kernel(r, x, y):
            r[0] = op(x, y)

        def device(x, y):
            return op(x, y)

        variants = (
            (kernel, (float32[::1], float32, float32), False),
            (device, (float32, float32), True),
        )
        for pyfunc, sig, is_device in variants:
            self._test_fast_math_common(pyfunc, sig, device=is_device,
                                        criterion=criterion)

    def test_cosf(self):
        self._test_fast_math_unary(cos, self._fast_only('cos.approx.ftz.f32 '))

    def test_sinf(self):
        self._test_fast_math_unary(sin, self._fast_only('sin.approx.ftz.f32 '))

    def test_tanf(self):
        # Fast tan is expected to be built from approximate sin, cos and div.
        sin_insn = 'sin.approx.ftz.f32 '
        criterion = FastMathCriterion(
            fast_expected=[sin_insn,
                           'cos.approx.ftz.f32 ',
                           'div.approx.ftz.f32 '],
            prec_unexpected=[sin_insn],
        )
        self._test_fast_math_unary(tan, criterion)

    @skip_unless_cc_75
    def test_tanhf(self):
        self._test_fast_math_unary(tanh, self._fast_only('tanh.approx.f32 '))

    def test_tanhf_compile_ptx(self):
        # Compile for explicit compute capabilities. For cc 7.5, tanh.approx
        # is expected with fastmath. For cc 7.0, the fast version is expected
        # to use ex2/rcp approximations instead.
        def tanh_kernel(r, x):
            r[0] = tanh(x)

        sig = (float32[::1], float32)
        cases = (
            ((7, 5), self._fast_only('tanh.approx.f32 ')),
            ((7, 0), FastMathCriterion(
                fast_expected=['ex2.approx.ftz.f32 ', 'rcp.approx.ftz.f32 '],
                prec_unexpected=['tanh.approx.f32 '],
            )),
        )
        for cc, criterion in cases:
            fastptx, _ = compile_ptx(tanh_kernel, sig, fastmath=True, cc=cc)
            precptx, _ = compile_ptx(tanh_kernel, sig, cc=cc)
            criterion.check(self, fastptx, precptx)

    def test_expf(self):
        # The precise lowering must contain an fma.rn.f32; the fast one not.
        fma = 'fma.rn.f32 '
        criterion = FastMathCriterion(fast_unexpected=[fma],
                                      prec_expected=[fma])
        self._test_fast_math_unary(exp, criterion)

    def test_logf(self):
        # 0f3F317218 is ln(2) as an f32 hex literal: it converts the lg2
        # result from log base 2 to log base e.
        lg2 = 'lg2.approx.ftz.f32 '
        criterion = FastMathCriterion(fast_expected=[lg2, '0f3F317218'],
                                      prec_unexpected=[lg2])
        self._test_fast_math_unary(log, criterion)

    def test_log10f(self):
        # 0f3E9A209B is log10(2) as an f32 hex literal: it converts the lg2
        # result from log base 2 to log base 10.
        lg2 = 'lg2.approx.ftz.f32 '
        criterion = FastMathCriterion(fast_expected=[lg2, '0f3E9A209B'],
                                      prec_unexpected=[lg2])
        self._test_fast_math_unary(log10, criterion)

    def test_log2f(self):
        self._test_fast_math_unary(log2, self._fast_only('lg2.approx.ftz.f32 '))

    def test_powf(self):
        self._test_fast_math_binary(pow, self._fast_only('lg2.approx.ftz.f32 '))

    def test_divf(self):
        # Fastmath swaps the IEEE-rounded divide for the approximate one.
        approx_div = 'div.approx.ftz.f32 '
        rounded_div = 'div.rn.f32'
        criterion = FastMathCriterion(
            fast_expected=[approx_div],
            fast_unexpected=[rounded_div],
            prec_expected=[rounded_div],
            prec_unexpected=[approx_div],
        )
        self._test_fast_math_binary(truediv, criterion)

    def test_divf_exception(self):
        def f10(r, x, y):
            r[0] = x / y

        sig = (float32[::1], float32, float32)
        fastver = cuda.jit(sig, fastmath=True, debug=True)(f10)
        precver = cuda.jit(sig, debug=True)(f10)
        n = 10
        out = np.empty(n, dtype=np.float32)

        # With debug=True, the precise kernel reports division by zero ...
        with self.assertRaises(ZeroDivisionError):
            precver[1, n](out, 10.0, 0.0)

        # ... whereas the fastmath kernel must not.
        try:
            fastver[1, n](out, 10.0, 0.0)
        except ZeroDivisionError:
            self.fail("Divide in fastmath should not throw ZeroDivisionError")

    @unittest.expectedFailure
    def test_device_fastmath_propagation(self):
        # Known limitation, hence the xfail: a fastmath kernel does not
        # currently pass the option on to the device functions it calls.
        # Arguably it should.
        @cuda.jit("float32(float32, float32)", device=True)
        def foo(a, b):
            return a / b

        def bar(arr, val):
            i = cuda.grid(1)
            if i < arr.size:
                arr[i] = foo(i, val)

        sig = (float32[::1], float32)
        fastver = cuda.jit(sig, fastmath=True)(bar)
        precver = cuda.jit(sig)(bar)

        # The div instruction variants are documented at:
        # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
        # Fast kernel: expect the "fast, approximate divide" variant.
        self.assertIn('div.approx.f32', fastver.inspect_asm(sig))

        # Precise kernel: expect the "IEEE 754 compliant rounding" variant
        # and neither of the approximate-divide variants.
        prec_asm = precver.inspect_asm(sig)
        self.assertIn('div.rn.f32', prec_asm)
        for approx_div in ('div.approx.f32', 'div.full.f32'):
            self.assertNotIn(approx_div, prec_asm)
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()