ai-content-maker/.venv/Lib/site-packages/numba/cuda/tests/cudapy/test_inspect.py

166 lines
5.1 KiB
Python

import numpy as np
from io import StringIO
from numba import cuda, float32, float64, int32, intp
from numba.cuda.testing import unittest, CUDATestCase
from numba.cuda.testing import (skip_on_cudasim, skip_with_nvdisasm,
skip_without_nvdisasm)
@skip_on_cudasim('Simulator does not generate code to be inspected')
class TestInspect(CUDATestCase):
    """Tests for the code-inspection methods of CUDA-jitted functions.

    Covers ``inspect_types``, ``inspect_llvm``, ``inspect_asm``,
    ``inspect_sass`` and ``inspect_sass_cfg``. It checks both functions
    compiled eagerly from an explicit signature and functions compiled
    lazily when they are first launched.
    """

    @property
    def cc(self):
        """Compute capability of the current context's device."""
        return cuda.current_context().device.compute_capability

    def test_monotyped(self):
        """Inspect a function compiled eagerly for a single signature."""
        sig = (float32, int32)

        @cuda.jit(sig)
        def foo(x, y):
            pass

        # Use a context manager so the buffer is released even if
        # inspect_types raises.
        with StringIO() as buf:
            foo.inspect_types(file=buf)
            typeanno = buf.getvalue()

        # Function name in annotation
        self.assertIn("foo", typeanno)
        # Signature in annotation
        self.assertIn("(float32, int32)", typeanno)

        # Function name in LLVM
        llvm = foo.inspect_llvm(sig)
        self.assertIn("foo", llvm)
        # Kernel in LLVM
        self.assertIn('cuda.kernel.wrapper', llvm)
        # Wrapped device function body in LLVM
        self.assertIn("define linkonce_odr i32", llvm)

        asm = foo.inspect_asm(sig)
        # Function name in PTX
        self.assertIn("foo", asm)
        # NVVM inserted comments in PTX
        self.assertIn("Generated by NVIDIA NVVM Compiler", asm)

    def test_polytyped(self):
        """Inspect a lazily compiled function with two specializations.

        With no signature argument, ``inspect_llvm`` and ``inspect_asm``
        return dicts keyed by argument-type tuples.
        """
        @cuda.jit
        def foo(x, y):
            pass

        # Launch with integer and then float arguments. Each launch
        # compiles a separate specialization.
        foo[1, 1](1, 1)
        foo[1, 1](1.2, 2.4)

        with StringIO() as buf:
            foo.inspect_types(file=buf)
            typeanno = buf.getvalue()

        # Signature in annotation
        self.assertIn("({0}, {0})".format(intp), typeanno)
        self.assertIn("(float64, float64)", typeanno)

        # Signature in LLVM dict
        llvmirs = foo.inspect_llvm()
        self.assertEqual(2, len(llvmirs))
        self.assertIn((intp, intp), llvmirs)
        self.assertIn((float64, float64), llvmirs)

        # Function name in LLVM
        self.assertIn("foo", llvmirs[intp, intp])
        self.assertIn("foo", llvmirs[float64, float64])

        # Kernels in LLVM
        self.assertIn('cuda.kernel.wrapper', llvmirs[intp, intp])
        self.assertIn('cuda.kernel.wrapper', llvmirs[float64, float64])

        # Wrapped device function bodies in LLVM
        self.assertIn("define linkonce_odr i32", llvmirs[intp, intp])
        self.assertIn("define linkonce_odr i32", llvmirs[float64, float64])

        asmdict = foo.inspect_asm()

        # Signature in assembly dict
        self.assertEqual(2, len(asmdict))
        self.assertIn((intp, intp), asmdict)
        self.assertIn((float64, float64), asmdict)

        # NVVM inserted in PTX
        self.assertIn("foo", asmdict[intp, intp])
        self.assertIn("foo", asmdict[float64, float64])

    def _test_inspect_sass(self, kernel, name, sass):
        """Assert that ``sass`` looks like disassembly of function ``name``.

        :param kernel: the compiled kernel. It is accepted for interface
            compatibility but is not used by this check.
        :param name: Python name of the kernel function.
        :param sass: SASS text returned by ``inspect_sass``.
        """
        # Ensure function appears in output. split() yields
        # whitespace-separated tokens. A token mentioning both '.text' and
        # the function name is presumably the kernel's code-section header.
        seen_function = any('.text' in token and name in token
                            for token in sass.split())
        self.assertTrue(seen_function)

        # Both callers compile with lineinfo=True, so the output must refer
        # back to this file. Accept '/' or '\' before the file name so the
        # check also works with Windows-style paths.
        self.assertRegex(sass,
                         r'//## File ".*[/\\]test_inspect\.py", line [0-9]')

        # Some instructions common to all supported architectures that should
        # appear in the output
        self.assertIn('S2R', sass)   # Special register to register
        self.assertIn('BRA', sass)   # Branch
        self.assertIn('EXIT', sass)  # Exit program

    @skip_without_nvdisasm('nvdisasm needed for inspect_sass()')
    def test_inspect_sass_eager(self):
        """SASS inspection of a kernel compiled from an explicit signature."""
        sig = (float32[::1], int32[::1])

        @cuda.jit(sig, lineinfo=True)
        def add(x, y):
            i = cuda.grid(1)
            if i < len(x):
                x[i] += y[i]

        self._test_inspect_sass(add, 'add', add.inspect_sass(sig))

    @skip_without_nvdisasm('nvdisasm needed for inspect_sass()')
    def test_inspect_sass_lazy(self):
        """SASS inspection of a kernel compiled on first launch."""
        @cuda.jit(lineinfo=True)
        def add(x, y):
            i = cuda.grid(1)
            if i < len(x):
                x[i] += y[i]

        x = np.arange(10).astype(np.int32)
        y = np.arange(10).astype(np.float32)
        add[1, 10](x, y)

        # Must match the argument types of the launch above.
        signature = (int32[::1], float32[::1])
        self._test_inspect_sass(add, 'add', add.inspect_sass(signature))

    @skip_with_nvdisasm('Missing nvdisasm exception only generated when it is '
                        'not present')
    def test_inspect_sass_nvdisasm_missing(self):
        """inspect_sass raises RuntimeError when nvdisasm is unavailable."""
        @cuda.jit((float32[::1],))
        def f(x):
            x[0] = 0

        with self.assertRaises(RuntimeError) as raises:
            f.inspect_sass()

        self.assertIn('nvdisasm has not been found', str(raises.exception))

    @skip_without_nvdisasm('nvdisasm needed for inspect_sass_cfg()')
    def test_inspect_sass_cfg(self):
        """inspect_sass_cfg returns a DOT ``digraph``."""
        sig = (float32[::1], int32[::1])

        @cuda.jit(sig)
        def add(x, y):
            i = cuda.grid(1)
            if i < len(x):
                x[i] += y[i]

        self.assertRegex(
            add.inspect_sass_cfg(signature=sig),
            r'digraph\s*\w\s*{(.|\n)*\n}'
        )
if __name__ == "__main__":
    # Executed as a script rather than imported: hand control to the
    # unittest runner so the tests in this module are discovered and run.
    unittest.main()