ai-content-maker/.venv/Lib/site-packages/numba/cuda/tests/cudapy/test_overload.py

301 lines
6.9 KiB
Python

from numba import cuda, njit
from numba.core.extending import overload
from numba.cuda.testing import CUDATestCase, skip_on_cudasim, unittest
import numpy as np
# Dummy function definitions to overload
def generic_func_1():
    """Placeholder to be overloaded.

    Its implementation is registered with ``@overload`` for the 'generic'
    target only.
    """
    pass
def cuda_func_1():
    """Placeholder to be overloaded.

    Its implementation is registered with ``@overload`` for the 'cuda'
    target only.
    """
    pass
def generic_func_2():
    """Placeholder to be overloaded.

    Its implementation is registered with ``@overload`` for the 'generic'
    target only.
    """
    pass
def cuda_func_2():
    """Placeholder to be overloaded.

    Its implementation is registered with ``@overload`` for the 'cuda'
    target only.
    """
    pass
def generic_calls_generic():
    """Placeholder to be overloaded.

    Overloaded for the 'generic' target; that implementation in turn calls
    ``generic_func_1``.
    """
    pass
def generic_calls_cuda():
    """Placeholder to be overloaded.

    Overloaded for the 'generic' target; that implementation in turn calls
    ``cuda_func_1``.
    """
    pass
def cuda_calls_generic():
    """Placeholder to be overloaded.

    Overloaded for the 'cuda' target; that implementation in turn calls
    ``generic_func_1``.
    """
    pass
def cuda_calls_cuda():
    """Placeholder to be overloaded.

    Overloaded for the 'cuda' target; that implementation in turn calls
    ``cuda_func_1``.
    """
    pass
def target_overloaded():
    """Placeholder to be overloaded.

    Unlike the single-target stubs, this one has two separate ``@overload``
    implementations: one for the 'generic' target and one for 'cuda'.
    """
    pass
def generic_calls_target_overloaded():
    """Placeholder to be overloaded.

    Overloaded for the 'generic' target; that implementation in turn calls
    ``target_overloaded``.
    """
    pass
def cuda_calls_target_overloaded():
    """Placeholder to be overloaded.

    Overloaded for the 'cuda' target; that implementation in turn calls
    ``target_overloaded``.
    """
    pass
def target_overloaded_calls_target_overloaded():
    """Placeholder to be overloaded.

    Has separate ``@overload`` implementations for the 'generic' and 'cuda'
    targets; each of them in turn calls ``target_overloaded``.
    """
    pass
# To recognise which functions are resolved for a call, we identify each with a
# prime number. Each function called multiplies a value by its prime (starting
# with the value 1), and we can check that the result is as expected based on
# the final value after all multiplications.
# Leaf overloads that only multiply by their own prime.
GENERIC_FUNCTION_1 = 2
CUDA_FUNCTION_1 = 3
GENERIC_FUNCTION_2 = 5
CUDA_FUNCTION_2 = 7
# Overloads that multiply and then call one of the leaf overloads above.
GENERIC_CALLS_GENERIC = 11
GENERIC_CALLS_CUDA = 13
CUDA_CALLS_GENERIC = 17
CUDA_CALLS_CUDA = 19
# Per-target implementations of target_overloaded.
GENERIC_TARGET_OL = 23
CUDA_TARGET_OL = 29
# Single-target overloads that call target_overloaded.
GENERIC_CALLS_TARGET_OL = 31
CUDA_CALLS_TARGET_OL = 37
# Per-target implementations of target_overloaded_calls_target_overloaded.
GENERIC_TARGET_OL_CALLS_TARGET_OL = 41
CUDA_TARGET_OL_CALLS_TARGET_OL = 43
# Overload implementations
@overload(generic_func_1, target='generic')
def ol_generic_func_1(x):
    """Return the 'generic'-target implementation of ``generic_func_1``."""
    def impl(x):
        # Record that this implementation ran by folding in its prime.
        x[0] *= GENERIC_FUNCTION_1
    return impl
@overload(cuda_func_1, target='cuda')
def ol_cuda_func_1(x):
    """Return the 'cuda'-target implementation of ``cuda_func_1``."""
    def impl(x):
        # Record that this implementation ran by folding in its prime.
        x[0] *= CUDA_FUNCTION_1
    return impl
@overload(generic_func_2, target='generic')
def ol_generic_func_2(x):
    """Return the 'generic'-target implementation of ``generic_func_2``."""
    def impl(x):
        # Record that this implementation ran by folding in its prime.
        x[0] *= GENERIC_FUNCTION_2
    return impl
@overload(cuda_func_2, target='cuda')
def ol_cuda_func(x):
    """Return the 'cuda'-target implementation of ``cuda_func_2``."""
    def impl(x):
        # Record that this implementation ran by folding in its prime.
        x[0] *= CUDA_FUNCTION_2
    return impl
@overload(generic_calls_generic, target='generic')
def ol_generic_calls_generic(x):
    """Return the 'generic' implementation of ``generic_calls_generic``.

    The implementation folds in its own prime and then calls
    ``generic_func_1``, so both primes end up in ``x[0]``.
    """
    def impl(x):
        x[0] *= GENERIC_CALLS_GENERIC
        # Nested call into another overloaded function.
        generic_func_1(x)
    return impl
@overload(generic_calls_cuda, target='generic')
def ol_generic_calls_cuda(x):
    """Return the 'generic' implementation of ``generic_calls_cuda``.

    The implementation folds in its own prime and then calls
    ``cuda_func_1``, so both primes end up in ``x[0]``.
    """
    def impl(x):
        x[0] *= GENERIC_CALLS_CUDA
        # Nested call into a function overloaded for the 'cuda' target.
        cuda_func_1(x)
    return impl
@overload(cuda_calls_generic, target='cuda')
def ol_cuda_calls_generic(x):
    """Return the 'cuda' implementation of ``cuda_calls_generic``.

    The implementation folds in its own prime and then calls
    ``generic_func_1``, so both primes end up in ``x[0]``.
    """
    def impl(x):
        x[0] *= CUDA_CALLS_GENERIC
        # Nested call into a function overloaded for the 'generic' target.
        generic_func_1(x)
    return impl
@overload(cuda_calls_cuda, target='cuda')
def ol_cuda_calls_cuda(x):
    """Return the 'cuda' implementation of ``cuda_calls_cuda``.

    The implementation folds in its own prime and then calls
    ``cuda_func_1``, so both primes end up in ``x[0]``.
    """
    def impl(x):
        x[0] *= CUDA_CALLS_CUDA
        # Nested call into another function overloaded for 'cuda'.
        cuda_func_1(x)
    return impl
@overload(target_overloaded, target='generic')
def ol_target_overloaded_generic(x):
    """Return the 'generic'-target implementation of ``target_overloaded``.

    ``target_overloaded`` also has a separate 'cuda' overload in this file;
    the two use different primes so the chosen one can be identified.
    """
    def impl(x):
        x[0] *= GENERIC_TARGET_OL
    return impl
@overload(target_overloaded, target='cuda')
def ol_target_overloaded_cuda(x):
    """Return the 'cuda'-target implementation of ``target_overloaded``.

    ``target_overloaded`` also has a separate 'generic' overload in this
    file; the two use different primes so the chosen one can be identified.
    """
    def impl(x):
        x[0] *= CUDA_TARGET_OL
    return impl
@overload(generic_calls_target_overloaded, target='generic')
def ol_generic_calls_target_overloaded(x):
    """Return the 'generic' impl of ``generic_calls_target_overloaded``.

    The implementation folds in its own prime and then calls
    ``target_overloaded``, which has both 'generic' and 'cuda' overloads.
    """
    def impl(x):
        x[0] *= GENERIC_CALLS_TARGET_OL
        # Which overload this resolves to is what the tests check.
        target_overloaded(x)
    return impl
@overload(cuda_calls_target_overloaded, target='cuda')
def ol_cuda_calls_target_overloaded(x):
    """Return the 'cuda' impl of ``cuda_calls_target_overloaded``.

    The implementation folds in its own prime and then calls
    ``target_overloaded``, which has both 'generic' and 'cuda' overloads.
    """
    def impl(x):
        x[0] *= CUDA_CALLS_TARGET_OL
        # Which overload this resolves to is what the tests check.
        target_overloaded(x)
    return impl
@overload(target_overloaded_calls_target_overloaded, target='generic')
def ol_generic_calls_target_overloaded_generic(x):
    """Return the 'generic' impl of the doubly target-overloaded function.

    Registered for ``target_overloaded_calls_target_overloaded``; the
    implementation folds in its own prime and then calls
    ``target_overloaded``.
    """
    def impl(x):
        x[0] *= GENERIC_TARGET_OL_CALLS_TARGET_OL
        target_overloaded(x)
    return impl
@overload(target_overloaded_calls_target_overloaded, target='cuda')
def ol_generic_calls_target_overloaded_cuda(x):
    """Return the 'cuda' impl of the doubly target-overloaded function.

    Registered for ``target_overloaded_calls_target_overloaded``; the
    implementation folds in its own prime and then calls
    ``target_overloaded``.
    """
    def impl(x):
        x[0] *= CUDA_TARGET_OL_CALLS_TARGET_OL
        target_overloaded(x)
    return impl
@skip_on_cudasim('Overloading not supported in cudasim')
class TestOverload(CUDATestCase):
    """Check which ``@overload`` implementation is resolved for each call.

    Every overload implementation in this module multiplies ``x[0]`` by its
    own prime, starting from 1, so the final product identifies exactly
    which implementations were executed.
    """

    def check_overload(self, kernel, expected):
        """Compile ``kernel`` with ``cuda.jit``, launch it and check ``x[0]``.

        ``kernel`` is a plain Python function taking the one-element array;
        ``expected`` is the product of primes it should leave in ``x[0]``.
        """
        x = np.ones(1, dtype=np.int32)
        # Launch configuration [1, 1]: one block containing one thread.
        cuda.jit(kernel)[1, 1](x)
        self.assertEqual(x[0], expected)

    def check_overload_cpu(self, kernel, expected):
        """Compile ``kernel`` with ``njit``, call it and check ``x[0]``."""
        x = np.ones(1, dtype=np.int32)
        njit(kernel)(x)
        self.assertEqual(x[0], expected)

    def test_generic(self):
        """A 'generic' overload is usable from a CUDA kernel."""
        def kernel(x):
            generic_func_1(x)

        expected = GENERIC_FUNCTION_1
        self.check_overload(kernel, expected)

    def test_cuda(self):
        """A 'cuda' overload is usable from a CUDA kernel."""
        def kernel(x):
            cuda_func_1(x)

        expected = CUDA_FUNCTION_1
        self.check_overload(kernel, expected)

    def test_generic_and_cuda(self):
        """One kernel can call both a 'generic' and a 'cuda' overload."""
        def kernel(x):
            generic_func_1(x)
            cuda_func_1(x)

        expected = GENERIC_FUNCTION_1 * CUDA_FUNCTION_1
        self.check_overload(kernel, expected)

    def test_call_two_generic_calls(self):
        """One kernel can call two distinct 'generic' overloads."""
        def kernel(x):
            generic_func_1(x)
            generic_func_2(x)

        expected = GENERIC_FUNCTION_1 * GENERIC_FUNCTION_2
        self.check_overload(kernel, expected)

    def test_call_two_cuda_calls(self):
        """One kernel can call two distinct 'cuda' overloads."""
        def kernel(x):
            cuda_func_1(x)
            cuda_func_2(x)

        expected = CUDA_FUNCTION_1 * CUDA_FUNCTION_2
        self.check_overload(kernel, expected)

    def test_generic_calls_generic(self):
        """A 'generic' overload can call another 'generic' overload."""
        def kernel(x):
            generic_calls_generic(x)

        expected = GENERIC_CALLS_GENERIC * GENERIC_FUNCTION_1
        self.check_overload(kernel, expected)

    def test_generic_calls_cuda(self):
        """A 'generic' overload can call a 'cuda' overload on CUDA."""
        def kernel(x):
            generic_calls_cuda(x)

        expected = GENERIC_CALLS_CUDA * CUDA_FUNCTION_1
        self.check_overload(kernel, expected)

    def test_cuda_calls_generic(self):
        """A 'cuda' overload can call a 'generic' overload."""
        def kernel(x):
            cuda_calls_generic(x)

        expected = CUDA_CALLS_GENERIC * GENERIC_FUNCTION_1
        self.check_overload(kernel, expected)

    def test_cuda_calls_cuda(self):
        """A 'cuda' overload can call another 'cuda' overload."""
        def kernel(x):
            cuda_calls_cuda(x)

        expected = CUDA_CALLS_CUDA * CUDA_FUNCTION_1
        self.check_overload(kernel, expected)

    def test_call_target_overloaded(self):
        """With both targets overloaded, CUDA picks the 'cuda' version."""
        def kernel(x):
            target_overloaded(x)

        expected = CUDA_TARGET_OL
        self.check_overload(kernel, expected)

    def test_generic_calls_target_overloaded(self):
        """A 'generic' caller still resolves the 'cuda' callee on CUDA."""
        def kernel(x):
            generic_calls_target_overloaded(x)

        expected = GENERIC_CALLS_TARGET_OL * CUDA_TARGET_OL
        self.check_overload(kernel, expected)

    def test_cuda_calls_target_overloaded(self):
        """A 'cuda' caller resolves the 'cuda' callee on CUDA."""
        def kernel(x):
            cuda_calls_target_overloaded(x)

        expected = CUDA_CALLS_TARGET_OL * CUDA_TARGET_OL
        self.check_overload(kernel, expected)

    def test_target_overloaded_calls_target_overloaded(self):
        """Caller and callee both resolve per target, on CUDA and on CPU."""
        def kernel(x):
            target_overloaded_calls_target_overloaded(x)

        # Check the CUDA overloads are used on CUDA
        expected = CUDA_TARGET_OL_CALLS_TARGET_OL * CUDA_TARGET_OL
        self.check_overload(kernel, expected)

        # Also check that the CPU overloads are used on the CPU
        expected = GENERIC_TARGET_OL_CALLS_TARGET_OL * GENERIC_TARGET_OL
        self.check_overload_cpu(kernel, expected)
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()