ai-content-maker/.venv/Lib/site-packages/numba/tests/test_enums.py

182 lines
4.9 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
"""
Tests for enum support.
"""
import numpy as np
import unittest
from numba import jit, vectorize, int8, int16, int32
from numba.tests.support import TestCase
from numba.tests.enum_usecases import (Color, Shape, Shake,
Planet, RequestError,
IntEnumWithNegatives)
def compare_usecase(a, b):
    """Return ``(a == b, a != b, a is b, a is not b)`` for two values."""
    equal = a == b
    unequal = a != b
    identical = a is b
    distinct = a is not b
    return equal, unequal, identical, distinct
def getattr_usecase(a):
    """Return whether *a* is ``Color.red``, fetched by attribute access."""
    # The member is looked up as an attribute of its enum class.
    red = Color.red
    return a is red
def getitem_usecase(a):
    """Return whether *a* is ``Color.red``, fetched by name via ``Color['red']``."""
    # Subscripting the enum class with a member name string.
    red = Color['red']
    return a is red
def identity_usecase(a, b, c):
    """Return a 3-tuple of identity checks of *a*, *b*, *c* against fixed members.

    The members are ``Shake.mint``, ``Shape.circle`` and
    ``RequestError.internal_error``, respectively.
    """
    a_is_mint = a is Shake.mint
    b_is_circle = b is Shape.circle
    c_is_internal = c is RequestError.internal_error
    return a_is_mint, b_is_circle, c_is_internal
def make_constant_usecase(const):
    """Build and return a function testing its argument's identity with *const*.

    *const* is captured as a closure variable; numba treats closure
    variables as compile-time constants when the returned function is jitted.
    """
    def constant_usecase(a):
        return a is const

    return constant_usecase
def return_usecase(a, b, pred):
    """Return *a* when *pred* is truthy, otherwise return *b*."""
    if pred:
        return a
    return b
def int_coerce_usecase(x):
    """Mix IntEnum members into integer comparison and arithmetic on *x*.

    No explicit casts are used; the members coerce to ints implicitly.
    """
    if x > RequestError.internal_error:
        return x - RequestError.not_found
    return x + Shape.circle
def int_cast_usecase(x):
    """Like the implicit-coercion case, but cast each IntEnum member explicitly.

    The casts use numba's ``int16``, ``int32`` and ``int8`` types.
    """
    threshold = int16(RequestError.internal_error)
    if x > threshold:
        return x - int32(RequestError.not_found)
    return x + int8(Shape.circle)
def vectorize_usecase(x):
    """Map *x* to a RequestError member; used as the scalar kernel of a ufunc.

    Returns ``RequestError.dummy`` when *x* equals ``RequestError.not_found``.
    Otherwise returns ``RequestError['internal_error']``, looked up by name.
    """
    if x == RequestError.not_found:
        return RequestError.dummy
    return RequestError['internal_error']
class BaseEnumTest(object):
    """
    Enum tests shared by TestEnum and TestIntEnum.

    Subclasses provide:
    - ``values``: members used by the constant-lookup checks;
    - ``pairs``: 2-tuples of members.

    Subclasses must also mix in a TestCase that provides
    ``assertPreciseEqual``.
    """

    def test_compare(self):
        jitted = jit(nopython=True)(compare_usecase)
        for pair in self.pairs:
            expected = compare_usecase(*pair)
            self.assertPreciseEqual(expected, jitted(*pair))

    def test_return(self):
        """
        Enum members can be passed into and returned from jitted code.
        """
        jitted = jit(nopython=True)(return_usecase)
        for pair in self.pairs:
            for pred in (True, False):
                call_args = (*pair, pred)
                self.assertIs(return_usecase(*call_args), jitted(*call_args))

    def check_constant_usecase(self, pyfunc):
        """Check that *pyfunc* agrees with its jitted form on each of ``self.values``."""
        jitted = jit(nopython=True)(pyfunc)
        for member in self.values:
            self.assertPreciseEqual(pyfunc(member), jitted(member))

    def test_constant(self):
        self.check_constant_usecase(getattr_usecase)
        self.check_constant_usecase(getitem_usecase)
        first_member = self.values[0]
        self.check_constant_usecase(make_constant_usecase(first_member))
class TestEnum(BaseEnumTest, TestCase):
    """
    Tests for plain Enum classes and their members.
    """
    # Members checked by the constant-lookup tests in BaseEnumTest.
    values = [Color.red, Color.green]

    # Member pairs fed to the comparison and return tests in BaseEnumTest.
    pairs = [
        (Color.red, Color.red),
        (Color.red, Color.green),
        (Shake.mint, Shake.vanilla),
        (Planet.VENUS, Planet.MARS),
        (Planet.EARTH, Planet.EARTH),
    ]

    def test_identity(self):
        """
        Enum members with equal values must not compare identical.
        """
        jitted = jit(nopython=True)(identity_usecase)
        members = (Color.blue, Color.green, Shape.square)
        self.assertPreciseEqual(identity_usecase(*members), jitted(*members))
class TestIntEnum(BaseEnumTest, TestCase):
    """
    Tests for IntEnum classes and members.
    """
    # Members checked by the constant-lookup tests in BaseEnumTest.
    values = [Shape.circle, Shape.square]

    # Member pairs fed to the comparison and return tests in BaseEnumTest.
    pairs = [
        (Shape.circle, Shape.circle),
        (Shape.circle, Shape.square),
        (RequestError.not_found, RequestError.not_found),
        (RequestError.internal_error, RequestError.not_found),
    ]

    def test_int_coerce(self):
        """IntEnum members coerce implicitly to ints in jitted code."""
        pyfunc = int_coerce_usecase
        cfunc = jit(nopython=True)(pyfunc)
        for arg in [300, 450, 550]:
            self.assertPreciseEqual(pyfunc(arg), cfunc(arg))

    def test_int_cast(self):
        """IntEnum members can be cast explicitly to sized integer types."""
        pyfunc = int_cast_usecase
        cfunc = jit(nopython=True)(pyfunc)
        for arg in [300, 450, 550]:
            self.assertPreciseEqual(pyfunc(arg), cfunc(arg))

    def test_vectorize(self):
        """IntEnum members can be returned from a vectorized kernel."""
        cfunc = vectorize(nopython=True)(vectorize_usecase)
        arg = np.array([2, 404, 500, 404])
        sol = np.array([vectorize_usecase(i) for i in arg], dtype=arg.dtype)
        self.assertPreciseEqual(sol, cfunc(arg))

    def test_hash(self):
        """hash() of IntEnum members, including negative ones, matches Python."""
        def pyfun(x):
            return hash(x)
        cfunc = jit(nopython=True)(pyfun)
        for member in IntEnumWithNegatives:
            self.assertPreciseEqual(pyfun(member), cfunc(member))

    def test_int_shape_cast(self):
        """
        IntEnum members can be used as array dimensions in NumPy
        array constructors.
        """
        def pyfun_empty(x):
            # ndarray.fill() works in place and returns None. Return the
            # array itself so its shape and contents are actually compared.
            arr = np.empty((x, x), dtype='int64')
            arr.fill(-1)
            return arr

        def pyfun_zeros(x):
            return np.zeros((x, x), dtype='int64')

        def pyfun_ones(x):
            return np.ones((x, x), dtype='int64')

        for pyfun in [pyfun_empty, pyfun_zeros, pyfun_ones]:
            cfunc = jit(nopython=True)(pyfun)
            for member in IntEnumWithNegatives:
                # Negative dimensions are invalid for array constructors.
                if member >= 0:
                    self.assertPreciseEqual(pyfun(member), cfunc(member))
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()