ai-content-maker/.venv/Lib/site-packages/numba/cuda/tests/cudadrv/test_ptds.py

150 lines
4.8 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
import multiprocessing as mp
import logging
import traceback
from numba.cuda.testing import unittest, CUDATestCase
from numba.cuda.testing import (skip_on_cudasim, skip_with_cuda_python,
skip_under_cuda_memcheck)
from numba.tests.support import linux_only
def child_test():
    """Launch a kernel from several threads with PTDS enabled.

    Intended to run in a fresh process: it switches on the per-thread
    default stream (PTDS) setting before any CUDA driver call is made,
    captures the debug log of ``numba.cuda.cudadrv.driver``, launches the
    same kernel from ``N_THREADS`` Python threads (each on its own input and
    output arrays), and verifies the results.

    Returns:
        str: everything logged by the driver logger during the run, so the
        caller can inspect which driver API functions were used.

    Raises:
        AssertionError: if any output array differs from the expected
        ``x * N_ADDITIONS``.
    """
    from numba import cuda, int32, void
    from numba.core import config
    import io
    import numpy as np
    import threading

    # Enable PTDS before we make any CUDA driver calls. Enabling it first
    # ensures that PTDS APIs are used because the CUDA driver looks up API
    # functions on first use and memoizes them.
    config.CUDA_PER_THREAD_DEFAULT_STREAM = 1

    # Set up log capture for the Driver API so we can see what API calls were
    # used.
    logbuf = io.StringIO()
    handler = logging.StreamHandler(logbuf)
    cudadrv_logger = logging.getLogger('numba.cuda.cudadrv.driver')
    cudadrv_logger.addHandler(handler)
    cudadrv_logger.setLevel(logging.DEBUG)

    # Set up data for our test, and copy over to the device
    N = 2 ** 16
    N_THREADS = 10
    N_ADDITIONS = 4096

    # Seed the RNG for repeatability
    np.random.seed(1)
    x = np.random.randint(low=0, high=1000, size=N, dtype=np.int32)
    r = np.zeros_like(x)

    # One input and output array for each thread
    xs = [cuda.to_device(x) for _ in range(N_THREADS)]
    rs = [cuda.to_device(r) for _ in range(N_THREADS)]

    # Compute the grid size and get the [per-thread] default stream. The
    # block count is rounded up so that every element is covered even if N
    # is not a multiple of the block size; the bounds check in the kernel
    # discards the surplus threads.
    n_threads = 256
    n_blocks = (N + n_threads - 1) // n_threads
    stream = cuda.default_stream()

    # A simple multiplication-by-addition kernel. What it does exactly is not
    # too important; only that we have a kernel that does something.
    @cuda.jit(void(int32[::1], int32[::1]))
    def f(r, x):
        i = cuda.grid(1)

        # Valid indices are 0 .. len(r) - 1, so i == len(r) is already out
        # of bounds and must be rejected too.
        if i >= len(r):
            return

        # Accumulate x into r
        for j in range(N_ADDITIONS):
            r[i] += x[i]

    # This function will be used to launch the kernel from each thread on its
    # own unique data.
    def kernel_thread(n):
        f[n_blocks, n_threads, stream](rs[n], xs[n])

    # Create threads
    threads = [threading.Thread(target=kernel_thread, args=(i,))
               for i in range(N_THREADS)]

    # Start all threads
    for thread in threads:
        thread.start()

    # Wait for all threads to finish, to ensure that we don't synchronize with
    # the device until all kernels are scheduled.
    for thread in threads:
        thread.join()

    # Synchronize with the device
    cuda.synchronize()

    # Check output is as expected
    expected = x * N_ADDITIONS
    for i in range(N_THREADS):
        np.testing.assert_equal(rs[i].copy_to_host(), expected)

    # Return the driver log output to the calling process for checking
    handler.flush()
    return logbuf.getvalue()
def child_test_wrapper(result_queue):
    """Run ``child_test`` and report its outcome on *result_queue*.

    Exactly one ``(success, output)`` tuple is put on the queue. On success
    ``output`` is the value returned by ``child_test``; otherwise it is the
    formatted traceback of whatever was raised.
    """
    try:
        outcome = (True, child_test())
    except BaseException:
        # Intentionally as broad as a bare ``except``: anything raised is
        # captured as text so it can be propagated through the queue.
        outcome = (False, traceback.format_exc())

    result_queue.put(outcome)
# Run on Linux only until the reason for test hangs on Windows (Issue #8635,
# https://github.com/numba/numba/issues/8635) is diagnosed
@linux_only
@skip_under_cuda_memcheck('Hangs cuda-memcheck')
@skip_on_cudasim('Streams not supported on the simulator')
class TestPTDS(CUDATestCase):
    """Check that enabling PTDS makes the driver use the PTDS API variants."""

    @skip_with_cuda_python('Function names unchanged for PTDS with NV Binding')
    def test_ptds(self):
        """Run ``child_test`` in a spawned process and inspect its driver log.

        The per-thread-default-stream function names must appear in the
        log, and their legacy counterparts must not.
        """
        # Run a test with PTDS enabled in a child process
        ctx = mp.get_context('spawn')
        result_queue = ctx.Queue()
        proc = ctx.Process(target=child_test_wrapper, args=(result_queue,))
        proc.start()

        # Fetch the result *before* joining. A process that has put items on
        # a multiprocessing queue does not terminate until they have been
        # flushed to the underlying pipe, so joining first can deadlock when
        # the captured log is larger than the pipe buffer.
        success, output = result_queue.get()
        proc.join()

        # Ensure the child process ran to completion before checking its output
        if not success:
            self.fail(output)

        # Functions with a per-thread default stream variant that we expect to
        # see in the output
        ptds_functions = ('cuMemcpyHtoD_v2_ptds', 'cuLaunchKernel_ptsz',
                          'cuMemcpyDtoH_v2_ptds')

        for fn in ptds_functions:
            with self.subTest(fn=fn, expected=True):
                self.assertIn(fn, output)

        # Non-PTDS versions of the functions that we should not see in the
        # output:
        legacy_functions = ('cuMemcpyHtoD_v2', 'cuLaunchKernel',
                            'cuMemcpyDtoH_v2')

        for fn in legacy_functions:
            with self.subTest(fn=fn, expected=False):
                # Ensure we only spot these function names appearing without a
                # _ptds or _ptsz suffix by checking including the end of the
                # line in the log
                fn_at_end = f'{fn}\n'
                self.assertNotIn(fn_at_end, output)
# Allow this test module to be executed directly as a script.
if __name__ == '__main__':
    unittest.main()