ai-content-maker/.venv/Lib/site-packages/numba/cuda/tests/cudapy/test_extending.py

156 lines
4.0 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from numba.cuda.testing import skip_on_cudasim, unittest, CUDATestCase
import numpy as np
from numba import config, cuda, njit, types
class Interval:
    """
    A half-open interval on the real number line.

    Attributes are stored exactly as passed in; no ordering or type
    validation is performed on ``lo`` and ``hi``.
    """

    def __init__(self, lo, hi):
        """Store the lower bound ``lo`` and the upper bound ``hi``."""
        self.lo = lo
        self.hi = hi

    def __repr__(self):
        """Return ``Interval(lo, hi)`` with both bounds formatted via ``%f``."""
        return 'Interval(%f, %f)' % (self.lo, self.hi)

    @property
    def width(self):
        """Difference between the upper and lower bound (``hi - lo``)."""
        return self.hi - self.lo
@njit
def interval_width(interval):
    """Return the ``width`` attribute of ``interval``.

    Compiled with ``njit``; the tests in this file call it from inside a
    ``cuda.jit`` kernel with an ``Interval`` as the argument.
    """
    return interval.width
@njit
def sum_intervals(i, j):
    """Return a new ``Interval`` whose bounds are the element-wise sums.

    The result's ``lo`` is ``i.lo + j.lo`` and its ``hi`` is
    ``i.hi + j.hi``. Compiled with ``njit``; the tests in this file call it
    from inside a ``cuda.jit`` kernel.
    """
    return Interval(i.lo + j.lo, i.hi + j.hi)
if not config.ENABLE_CUDASIM:
    # Everything below is only imported and registered when the CUDA
    # simulator is disabled; the test class is skipped on the simulator.
    from numba.core import cgutils
    from numba.core.extending import (lower_builtin, make_attribute_wrapper,
                                      models, register_model, type_callable,
                                      typeof_impl)
    from numba.core.typing.templates import AttributeTemplate
    from numba.cuda.cudadecl import registry as cuda_registry
    from numba.cuda.cudaimpl import lower_attr as cuda_lower_attr

    class IntervalType(types.Type):
        """Numba type used to represent ``Interval`` values."""

        def __init__(self):
            super().__init__(name='Interval')

    # Single shared instance returned by every typing hook below.
    interval_type = IntervalType()

    @typeof_impl.register(Interval)
    def typeof_interval(val, c):
        """Map any Python ``Interval`` instance to ``interval_type``."""
        return interval_type

    @type_callable(Interval)
    def type_interval(context):
        """Provide the typer for calls to ``Interval(lo, hi)``."""
        def typer(lo, hi):
            # Only float arguments are accepted; for anything else the
            # typer falls through and implicitly returns None.
            if isinstance(lo, types.Float) and isinstance(hi, types.Float):
                return interval_type
        return typer

    @register_model(IntervalType)
    class IntervalModel(models.StructModel):
        """Data model for ``IntervalType``: a struct of two float64 members."""

        def __init__(self, dmm, fe_type):
            members = [
                ('lo', types.float64),
                ('hi', types.float64),
            ]
            models.StructModel.__init__(self, dmm, fe_type, members)

    # Expose the struct members 'lo' and 'hi' as attributes of the same name.
    make_attribute_wrapper(IntervalType, 'lo', 'lo')
    make_attribute_wrapper(IntervalType, 'hi', 'hi')

    @lower_builtin(Interval, types.Float, types.Float)
    def impl_interval(context, builder, sig, args):
        """Lower ``Interval(lo, hi)`` by filling a struct proxy with both bounds."""
        typ = sig.return_type
        lo, hi = args
        interval = cgutils.create_struct_proxy(typ)(context, builder)
        interval.lo = lo
        interval.hi = hi
        return interval._getvalue()

    @cuda_registry.register_attr
    class Interval_attrs(AttributeTemplate):
        """Attribute typing for ``IntervalType`` registered with the CUDA registry."""
        key = IntervalType

        def resolve_width(self, mod):
            # ``width`` is typed as float64, matching the struct members.
            return types.float64

    @cuda_lower_attr(IntervalType, 'width')
    def cuda_Interval_width(context, builder, sig, arg):
        """Lower the ``width`` attribute as ``hi - lo`` using a float subtract."""
        # Indices 0 and 1 presumably follow the member order declared in
        # IntervalModel ('lo' first, then 'hi') -- confirm against StructModel.
        lo = builder.extract_value(arg, 0)
        hi = builder.extract_value(arg, 1)
        return builder.fsub(hi, lo)
@skip_on_cudasim('Extensions not supported in the simulator')
class TestExtending(CUDATestCase):
    """Tests that use the ``Interval`` extension type inside CUDA kernels.

    Every kernel is launched with a ``[1, 1]`` configuration and writes its
    results into the output array ``r``.
    """

    def test_attributes(self):
        """``lo`` and ``hi`` read back from an Interval built in a kernel."""
        @cuda.jit
        def f(r, x):
            iv = Interval(x[0], x[1])
            r[0] = iv.lo
            r[1] = iv.hi

        x = np.asarray((1.5, 2.5))
        r = np.zeros_like(x)

        f[1, 1](r, x)

        np.testing.assert_equal(r, x)

    def test_property(self):
        """``width`` evaluated inside a kernel equals ``x[1] - x[0]``."""
        @cuda.jit
        def f(r, x):
            iv = Interval(x[0], x[1])
            r[0] = iv.width

        x = np.asarray((1.5, 2.5))
        r = np.zeros(1)

        f[1, 1](r, x)

        np.testing.assert_allclose(r[0], x[1] - x[0])

    def test_extension_type_as_arg(self):
        """An Interval can be passed from a kernel to an ``njit`` function."""
        @cuda.jit
        def f(r, x):
            iv = Interval(x[0], x[1])
            r[0] = interval_width(iv)

        x = np.asarray((1.5, 2.5))
        r = np.zeros(1)

        f[1, 1](r, x)

        np.testing.assert_allclose(r[0], x[1] - x[0])

    def test_extension_type_as_retvalue(self):
        """An Interval returned by an ``njit`` function is usable in a kernel."""
        @cuda.jit
        def f(r, x):
            iv1 = Interval(x[0], x[1])
            iv2 = Interval(x[2], x[3])
            iv_sum = sum_intervals(iv1, iv2)
            r[0] = iv_sum.lo
            r[1] = iv_sum.hi

        x = np.asarray((1.5, 2.5, 3.0, 4.0))
        r = np.zeros(2)

        f[1, 1](r, x)

        # Bounds are summed pairwise: lo with lo, hi with hi.
        expected = np.asarray((x[0] + x[2], x[1] + x[3]))
        np.testing.assert_allclose(r, expected)
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()