190 lines
6.0 KiB
Python
190 lines
6.0 KiB
Python
import sys
|
|
|
|
import numpy as np
|
|
|
|
import unittest
|
|
from numba import jit, njit
|
|
from numba.core import types, errors, utils
|
|
from numba.tests.support import (captured_stdout, TestCase, EnableNRTStatsMixin)
|
|
|
|
|
|
def print_value(x):
|
|
print(x)
|
|
|
|
def print_array_item(arr, i):
|
|
print(arr[i].x)
|
|
|
|
def print_values(a, b, c):
|
|
print(a, b, c)
|
|
|
|
def print_empty():
|
|
print()
|
|
|
|
def print_string(x):
|
|
print(x, "hop!", 3.5)
|
|
|
|
def print_vararg(a, b, c):
|
|
print(a, b, *c)
|
|
|
|
def print_string_vararg(a, b, c):
|
|
print(a, "hop!", b, *c)
|
|
|
|
def make_print_closure(x):
|
|
def print_closure():
|
|
return x
|
|
return jit(nopython=True)(x)
|
|
|
|
|
|
class TestPrint(EnableNRTStatsMixin, TestCase):
|
|
|
|
def check_values(self, typ, values):
|
|
cfunc = njit((typ,))(print_value)
|
|
for val in values:
|
|
with captured_stdout():
|
|
cfunc(val)
|
|
self.assertEqual(sys.stdout.getvalue(), str(val) + '\n')
|
|
|
|
def test_print_values(self):
|
|
"""
|
|
Test printing a single argument value.
|
|
"""
|
|
# Various scalars
|
|
self.check_values(types.int32, (1, -234))
|
|
self.check_values(types.int64, (1, -234,
|
|
123456789876543210,
|
|
-123456789876543210))
|
|
self.check_values(types.uint64, (1, 234,
|
|
123456789876543210, 2**63 + 123))
|
|
self.check_values(types.boolean, (True, False))
|
|
self.check_values(types.float64, (1.5, 100.0**10.0, float('nan')))
|
|
self.check_values(types.complex64, (1+1j,))
|
|
self.check_values(types.NPTimedelta('ms'), (np.timedelta64(100, 'ms'),))
|
|
|
|
cfunc = njit((types.float32,))(print_value)
|
|
with captured_stdout():
|
|
cfunc(1.1)
|
|
# Float32 will lose precision
|
|
got = sys.stdout.getvalue()
|
|
expect = '1.10000002384'
|
|
self.assertTrue(got.startswith(expect))
|
|
self.assertTrue(got.endswith('\n'))
|
|
|
|
# Test array
|
|
arraytype = types.Array(types.int32, 1, 'C')
|
|
cfunc = njit((arraytype,))(print_value)
|
|
with captured_stdout():
|
|
cfunc(np.arange(10, dtype=np.int32))
|
|
self.assertEqual(sys.stdout.getvalue(),
|
|
'[0 1 2 3 4 5 6 7 8 9]\n')
|
|
|
|
@unittest.skip("Issue with intermittent NRT leak, see #9355.")
|
|
def test_print_nrt_type(self):
|
|
# NOTE: this check is extracted from the above as it started
|
|
# intermittently leaking since the merge of #9330 (compile_isolated
|
|
# removal patch). It's not clear why this happens, see #9355 for
|
|
# thoughts/details. This test is skipped until it is resolved.
|
|
|
|
# NRT-enabled type
|
|
with self.assertNoNRTLeak():
|
|
x = [1, 3, 5, 7]
|
|
with self.assertRefCount(x):
|
|
self.check_values(types.List(types.intp, reflected=True), (x,))
|
|
|
|
def test_print_array_item(self):
|
|
"""
|
|
Test printing a Numpy character sequence
|
|
"""
|
|
dtype = np.dtype([('x', 'S4')])
|
|
arr = np.frombuffer(bytearray(range(1, 9)), dtype=dtype)
|
|
|
|
pyfunc = print_array_item
|
|
cfunc = jit(nopython=True)(pyfunc)
|
|
for i in range(len(arr)):
|
|
with captured_stdout():
|
|
cfunc(arr, i)
|
|
self.assertEqual(sys.stdout.getvalue(), str(arr[i]['x']) + '\n')
|
|
|
|
def test_print_multiple_values(self):
|
|
pyfunc = print_values
|
|
cfunc = njit((types.intp,) * 3)(pyfunc)
|
|
with captured_stdout():
|
|
cfunc(1, 2, 3)
|
|
self.assertEqual(sys.stdout.getvalue(), '1 2 3\n')
|
|
|
|
def test_print_nogil(self):
|
|
pyfunc = print_values
|
|
cfunc = jit(nopython=True, nogil=True)(pyfunc)
|
|
with captured_stdout():
|
|
cfunc(1, 2, 3)
|
|
self.assertEqual(sys.stdout.getvalue(), '1 2 3\n')
|
|
|
|
def test_print_empty(self):
|
|
pyfunc = print_empty
|
|
cfunc = njit((),)(pyfunc)
|
|
with captured_stdout():
|
|
cfunc()
|
|
self.assertEqual(sys.stdout.getvalue(), '\n')
|
|
|
|
def test_print_strings(self):
|
|
pyfunc = print_string
|
|
cfunc = njit((types.intp,))(pyfunc)
|
|
with captured_stdout():
|
|
cfunc(1)
|
|
self.assertEqual(sys.stdout.getvalue(), '1 hop! 3.5\n')
|
|
|
|
def test_print_vararg(self):
|
|
# Test *args support for print(). This is desired since
|
|
# print() can use a dedicated IR node.
|
|
pyfunc = print_vararg
|
|
cfunc = jit(nopython=True)(pyfunc)
|
|
with captured_stdout():
|
|
cfunc(1, (2, 3), (4, 5j))
|
|
self.assertEqual(sys.stdout.getvalue(), '1 (2, 3) 4 5j\n')
|
|
|
|
pyfunc = print_string_vararg
|
|
cfunc = jit(nopython=True)(pyfunc)
|
|
with captured_stdout():
|
|
cfunc(1, (2, 3), (4, 5j))
|
|
self.assertEqual(sys.stdout.getvalue(), '1 hop! (2, 3) 4 5j\n')
|
|
|
|
def test_inner_fn_print(self):
|
|
@jit(nopython=True)
|
|
def foo(x):
|
|
print(x)
|
|
|
|
@jit(nopython=True)
|
|
def bar(x):
|
|
foo(x)
|
|
foo('hello')
|
|
|
|
# Printing an array requires the Env.
|
|
# We need to make sure the inner function can obtain the Env.
|
|
x = np.arange(5)
|
|
with captured_stdout():
|
|
bar(x)
|
|
self.assertEqual(sys.stdout.getvalue(), '[0 1 2 3 4]\nhello\n')
|
|
|
|
def test_print_w_kwarg_raises(self):
|
|
@jit(nopython=True)
|
|
def print_kwarg():
|
|
print('x', flush=True)
|
|
|
|
with self.assertRaises(errors.UnsupportedError) as raises:
|
|
print_kwarg()
|
|
expected = ("Numba's print() function implementation does not support "
|
|
"keyword arguments.")
|
|
self.assertIn(raises.exception.msg, expected)
|
|
|
|
def test_print_no_truncation(self):
|
|
''' See: https://github.com/numba/numba/issues/3811
|
|
'''
|
|
@jit(nopython=True)
|
|
def foo():
|
|
print(''.join(['a'] * 10000))
|
|
with captured_stdout():
|
|
foo()
|
|
self.assertEqual(sys.stdout.getvalue(), ''.join(['a'] * 10000) + '\n')
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|