182 lines
4.9 KiB
Python
182 lines
4.9 KiB
Python
"""
|
|
Tests for enum support.
|
|
"""
|
|
|
|
|
|
import numpy as np
|
|
import unittest
|
|
from numba import jit, vectorize, int8, int16, int32
|
|
|
|
from numba.tests.support import TestCase
|
|
from numba.tests.enum_usecases import (Color, Shape, Shake,
|
|
Planet, RequestError,
|
|
IntEnumWithNegatives)
|
|
|
|
|
|
def compare_usecase(a, b):
|
|
return a == b, a != b, a is b, a is not b
|
|
|
|
|
|
def getattr_usecase(a):
|
|
# Lookup of a enum member on its class
|
|
return a is Color.red
|
|
|
|
|
|
def getitem_usecase(a):
|
|
"""Lookup enum member by string name"""
|
|
return a is Color['red']
|
|
|
|
|
|
def identity_usecase(a, b, c):
|
|
return (a is Shake.mint,
|
|
b is Shape.circle,
|
|
c is RequestError.internal_error,
|
|
)
|
|
|
|
|
|
def make_constant_usecase(const):
|
|
def constant_usecase(a):
|
|
return a is const
|
|
return constant_usecase
|
|
|
|
|
|
def return_usecase(a, b, pred):
|
|
return a if pred else b
|
|
|
|
|
|
def int_coerce_usecase(x):
|
|
# Implicit coercion of intenums to ints
|
|
if x > RequestError.internal_error:
|
|
return x - RequestError.not_found
|
|
else:
|
|
return x + Shape.circle
|
|
|
|
def int_cast_usecase(x):
|
|
# Explicit coercion of intenums to ints
|
|
if x > int16(RequestError.internal_error):
|
|
return x - int32(RequestError.not_found)
|
|
else:
|
|
return x + int8(Shape.circle)
|
|
|
|
|
|
def vectorize_usecase(x):
|
|
if x != RequestError.not_found:
|
|
return RequestError['internal_error']
|
|
else:
|
|
return RequestError.dummy
|
|
|
|
|
|
class BaseEnumTest(object):
|
|
|
|
def test_compare(self):
|
|
pyfunc = compare_usecase
|
|
cfunc = jit(nopython=True)(pyfunc)
|
|
|
|
for args in self.pairs:
|
|
self.assertPreciseEqual(pyfunc(*args), cfunc(*args))
|
|
|
|
def test_return(self):
|
|
"""
|
|
Passing and returning enum members.
|
|
"""
|
|
pyfunc = return_usecase
|
|
cfunc = jit(nopython=True)(pyfunc)
|
|
|
|
for pair in self.pairs:
|
|
for pred in (True, False):
|
|
args = pair + (pred,)
|
|
self.assertIs(pyfunc(*args), cfunc(*args))
|
|
|
|
def check_constant_usecase(self, pyfunc):
|
|
cfunc = jit(nopython=True)(pyfunc)
|
|
|
|
for arg in self.values:
|
|
self.assertPreciseEqual(pyfunc(arg), cfunc(arg))
|
|
|
|
def test_constant(self):
|
|
self.check_constant_usecase(getattr_usecase)
|
|
self.check_constant_usecase(getitem_usecase)
|
|
self.check_constant_usecase(make_constant_usecase(self.values[0]))
|
|
|
|
|
|
class TestEnum(BaseEnumTest, TestCase):
|
|
"""
|
|
Tests for Enum classes and members.
|
|
"""
|
|
values = [Color.red, Color.green]
|
|
|
|
pairs = [
|
|
(Color.red, Color.red),
|
|
(Color.red, Color.green),
|
|
(Shake.mint, Shake.vanilla),
|
|
(Planet.VENUS, Planet.MARS),
|
|
(Planet.EARTH, Planet.EARTH),
|
|
]
|
|
|
|
def test_identity(self):
|
|
"""
|
|
Enum with equal values should not compare identical
|
|
"""
|
|
pyfunc = identity_usecase
|
|
cfunc = jit(nopython=True)(pyfunc)
|
|
args = (Color.blue, Color.green, Shape.square)
|
|
self.assertPreciseEqual(pyfunc(*args), cfunc(*args))
|
|
|
|
|
|
class TestIntEnum(BaseEnumTest, TestCase):
|
|
"""
|
|
Tests for IntEnum classes and members.
|
|
"""
|
|
values = [Shape.circle, Shape.square]
|
|
|
|
pairs = [
|
|
(Shape.circle, Shape.circle),
|
|
(Shape.circle, Shape.square),
|
|
(RequestError.not_found, RequestError.not_found),
|
|
(RequestError.internal_error, RequestError.not_found),
|
|
]
|
|
|
|
def test_int_coerce(self):
|
|
pyfunc = int_coerce_usecase
|
|
cfunc = jit(nopython=True)(pyfunc)
|
|
|
|
for arg in [300, 450, 550]:
|
|
self.assertPreciseEqual(pyfunc(arg), cfunc(arg))
|
|
|
|
def test_int_cast(self):
|
|
pyfunc = int_cast_usecase
|
|
cfunc = jit(nopython=True)(pyfunc)
|
|
|
|
for arg in [300, 450, 550]:
|
|
self.assertPreciseEqual(pyfunc(arg), cfunc(arg))
|
|
|
|
def test_vectorize(self):
|
|
cfunc = vectorize(nopython=True)(vectorize_usecase)
|
|
arg = np.array([2, 404, 500, 404])
|
|
sol = np.array([vectorize_usecase(i) for i in arg], dtype=arg.dtype)
|
|
self.assertPreciseEqual(sol, cfunc(arg))
|
|
|
|
def test_hash(self):
|
|
def pyfun(x):
|
|
return hash(x)
|
|
cfunc = jit(nopython=True)(pyfun)
|
|
for member in IntEnumWithNegatives:
|
|
self.assertPreciseEqual(pyfun(member), cfunc(member))
|
|
|
|
def test_int_shape_cast(self):
|
|
def pyfun_empty(x):
|
|
return np.empty((x, x), dtype='int64').fill(-1)
|
|
def pyfun_zeros(x):
|
|
return np.zeros((x, x), dtype='int64')
|
|
def pyfun_ones(x):
|
|
return np.ones((x, x), dtype='int64')
|
|
for pyfun in [pyfun_empty, pyfun_zeros, pyfun_ones]:
|
|
cfunc = jit(nopython=True)(pyfun)
|
|
for member in IntEnumWithNegatives:
|
|
if member >= 0:
|
|
self.assertPreciseEqual(pyfun(member), cfunc(member))
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|