# Source: ai-content-maker/.venv/Lib/site-packages/numba/tests/test_conditions_as_predicat...
# (194 lines, 4.8 KiB, Python)

import warnings

import numpy as np

from numba.tests.support import TestCase
from numba import njit, types
from numba.typed import List, Dict
class TestConditionsAsPredicates(TestCase):
    """Objects of various types used directly as an ``if`` predicate.

    Every test compiles a small ``@njit`` function that returns 10 when the
    predicate is truthy and 20 otherwise, and checks that the compiled
    result matches both the pure-Python result (``foo.py_func``) and the
    expected constant.
    """

    def test_scalars(self):
        """Scalar values (NumPy/builtin numbers, ``str``, ``bool``)."""
        # checks that scalar types can be used as predicates
        dts = [np.int8, np.uint16, np.int64, np.float32, np.float64,
               np.complex128, int, float, complex, str, bool]
        for dt in dts:
            for c in 1, 0:
                x = dt(c)

                # ``foo`` reads ``x`` from the enclosing scope; it is
                # redefined on every iteration so it sees the current value.
                @njit
                def foo():
                    if x:
                        return 10
                    else:
                        return 20

                self.assertEqual(foo(), foo.py_func())
                # str(0) == "0" is a non-empty string and hence truthy; for
                # the numeric and bool types only the zero value is falsy.
                self.assertEqual(foo(), 10 if c == 1 or dt is str else 20)

        # empty string is falsy
        @njit
        def foo(x):
            if x:
                return 10
            else:
                return 20

        s = ""
        self.assertEqual(foo(s), foo.py_func(s))
        self.assertEqual(foo.py_func(s), 20)

    def test_typed_list(self):
        """``numba.typed.List``: falsy when empty, truthy otherwise."""
        @njit
        def foo(x):
            if x:
                return 10
            else:
                return 20

        # empty list
        z = List.empty_list(types.int64)
        self.assertEqual(foo(z), foo.py_func(z))
        self.assertEqual(foo.py_func(z), 20)

        # non-empty list
        z.append(1)
        self.assertEqual(foo(z), foo.py_func(z))
        self.assertEqual(foo.py_func(z), 10)

    def test_reflected_list(self):
        """Python lists, both passed in as arguments and built locally."""
        # non-empty
        @njit
        def foo(x):
            if x:
                return 10
            else:
                return 20

        z = [1]
        self.assertEqual(foo(z), foo.py_func(z))
        self.assertEqual(foo.py_func(z), 10)

        # non-empty local
        @njit
        def foo():
            y = [1, 2]
            if y:
                return 10
            else:
                return 20

        self.assertEqual(foo(), foo.py_func())
        self.assertEqual(foo.py_func(), 10)

        # empty local: start non-empty and pop everything so the list has a
        # concrete element type but no elements
        @njit
        def foo():
            y = [1, 2]
            y.pop()
            y.pop()
            assert len(y) == 0
            if y:
                return 10
            else:
                return 20

        self.assertEqual(foo(), foo.py_func())
        self.assertEqual(foo.py_func(), 20)

    def test_reflected_set(self):
        """Python sets, both passed in as arguments and built locally."""
        # non-empty
        @njit
        def foo(x):
            if x:
                return 10
            else:
                return 20

        z = {1}
        self.assertEqual(foo(z), foo.py_func(z))
        self.assertEqual(foo.py_func(z), 10)

        # non-empty local
        @njit
        def foo():
            y = {1, 2}
            if y:
                return 10
            else:
                return 20

        self.assertEqual(foo(), foo.py_func())
        self.assertEqual(foo.py_func(), 10)

        # empty local: start non-empty and pop everything so the set has a
        # concrete element type but no elements
        @njit
        def foo():
            y = {1, 2}
            y.pop()
            y.pop()
            assert len(y) == 0
            if y:
                return 10
            else:
                return 20

        self.assertEqual(foo(), foo.py_func())
        self.assertEqual(foo.py_func(), 20)

    def test_typed_dict(self):
        """``numba.typed.Dict``: falsy when empty, truthy otherwise."""
        @njit
        def foo(x):
            if x:
                return 10
            else:
                return 20

        # empty
        z = Dict.empty(types.int64, types.int64)
        self.assertEqual(foo(z), foo.py_func(z))
        self.assertEqual(foo.py_func(z), 20)

        # non-empty
        z[2] = 3
        self.assertEqual(foo(z), foo.py_func(z))
        self.assertEqual(foo.py_func(z), 10)

    def test_arrays(self):
        """NumPy arrays as predicates.

        Single-element arrays (0d or nd) take the truth value of their
        element. Multi-element arrays raise ``ValueError``. For empty arrays,
        the compiled function must agree with the installed NumPy, whose
        behaviour is version dependent.
        """
        @njit
        def foo(x):
            if x:
                return 10
            else:
                return 20

        # non-empty 0d, True
        z = np.array(1)
        self.assertEqual(foo(z), foo.py_func(z))
        self.assertEqual(foo.py_func(z), 10)

        # non-empty 0d, False
        z = np.array(0)
        self.assertEqual(foo(z), foo.py_func(z))
        self.assertEqual(foo.py_func(z), 20)

        # non-empty nd True
        z = np.array([[[1]]])
        self.assertEqual(foo(z), foo.py_func(z))
        self.assertEqual(foo.py_func(z), 10)

        # non-empty nd False
        z = np.array([[[0]]])
        self.assertEqual(foo(z), foo.py_func(z))
        self.assertEqual(foo.py_func(z), 20)

        # various problems:

        # empty: older NumPy emits a DeprecationWarning and treats the array
        # as False, while NumPy >= 2.2 raises ValueError instead. Probe the
        # installed NumPy (silencing the warning so it cannot fail a run
        # that treats warnings as errors) and require the compiled function
        # to behave the same way.
        z = np.empty(0)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            try:
                expected = foo.py_func(z)
            except ValueError:
                numpy_raises = True
            else:
                numpy_raises = False

            if numpy_raises:
                with self.assertRaises(ValueError):
                    foo(z)
            else:
                self.assertEqual(expected, 20)
                self.assertEqual(foo(z), expected)

        # nd with more than one element: both NumPy and the compiled
        # function raise with the same message
        z = np.array([1, 2])
        msg = ("The truth value of an array with more than one element is "
               "ambiguous. Use a.any() or a.all()")
        with self.assertRaises(ValueError) as raises:
            foo(z)
        self.assertIn(msg, str(raises.exception))

        with self.assertRaises(ValueError) as raises:
            foo.py_func(z)
        self.assertIn(msg, str(raises.exception))