253 lines
7.4 KiB
Python
253 lines
7.4 KiB
Python
import re
|
|
from io import StringIO
|
|
|
|
import numba
|
|
from numba.core import types
|
|
from numba import jit, njit
|
|
from numba.tests.support import override_config, TestCase
|
|
import unittest
|
|
|
|
try:
|
|
import jinja2
|
|
except ImportError:
|
|
jinja2 = None
|
|
|
|
try:
|
|
import pygments
|
|
except ImportError:
|
|
pygments = None
|
|
|
|
|
|
@unittest.skipIf(jinja2 is None, "please install the 'jinja2' package")
class TestAnnotation(TestCase):
    """Exercise HTML annotation and pretty-printed type inspection output.

    The HTML tests run in a subprocess because, per the inline notes,
    annotations compound per module: test_html_output_with_lifted_loop
    relies on this, expecting two function entries after a second compile.
    """

    @staticmethod
    def _html_annotate(type_annotation):
        """Return the HTML rendering of *type_annotation* as a string.

        ``html_annotate`` writes into a file-like object; the ``StringIO``
        is managed by ``with`` so it is closed even if rendering raises.
        """
        with StringIO() as buf:
            type_annotation.html_annotate(buf)
            return buf.getvalue()

    @TestCase.run_test_in_subprocess  # annotations compound per module
    def test_exercise_code_path(self):
        """
        Ensures template.html is available
        """

        def foo(n, a):
            s = a
            for i in range(n):
                s += i
            return s

        cfunc = njit((types.int32, types.int32))(foo)
        cres = cfunc.overloads[cfunc.signatures[0]]
        output = self._html_annotate(cres.type_annotation)
        self.assertIn("foo", output)

    @TestCase.run_test_in_subprocess  # annotations compound per module
    def test_exercise_code_path_with_lifted_loop(self):
        """
        Ensures that lifted loops are handled correctly in obj mode
        """
        # the functions to jit
        def bar(x):
            return x

        def foo(x):
            h = 0.
            for i in range(x): # py 38 needs two loops for one to lift?!
                h = h + i
            for k in range(x):
                h = h + k
            if x:
                h = h - bar(x)
            return h

        # compile into an isolated context
        cfunc = jit((types.intp,), forceobj=True, looplift=True)(foo)
        cres = cfunc.overloads[cfunc.signatures[0]]
        output = self._html_annotate(cres.type_annotation)

        self.assertIn("bar", output)
        self.assertIn("foo", output)
        self.assertIn("LiftedLoop", output)

    @TestCase.run_test_in_subprocess  # annotations compound per module
    def test_html_output_with_lifted_loop(self):
        """
        Test some format and behavior of the html annotation with lifted loop
        """
        @numba.jit(forceobj=True)
        def udt(x):
            object()  # to force object mode
            z = 0
            for i in range(x): # this line is tagged
                z += i
            return z

        # Regex pattern to check for the "lifted_tag" in the line of the loop.
        # The loop header in ``udt`` is matched verbatim, so the single space
        # before '#' on that source line must be kept as-is.
        re_lifted_tag = re.compile(
            r'<td class="lifted_tag">\s*'
            r'\s*<details>'
            r'\s*<summary>'
            r'\s*<code>'
            r'\s*[0-9]+:'
            r'\s*[ ]+for i in range\(x\): # this line is tagged\s*',
            re.MULTILINE)

        sigfmt = "with signature: {} -> pyobject"

        # Compile int64 version (with lifted loop)
        sig_i64 = (types.int64,)
        udt.compile(sig_i64)
        output = self._html_annotate(udt.overloads[sig_i64].type_annotation)

        # There should be only one function output.
        self.assertEqual(output.count("Function name: udt"), 1)
        self.assertEqual(output.count(sigfmt.format(sig_i64)), 1)
        # Ensure the loop is tagged
        self.assertEqual(len(re_lifted_tag.findall(output)), 1,
                         msg='%s not found in %s' % (re_lifted_tag, output))

        # Compile float64 version
        sig_f64 = (types.float64,)
        udt.compile(sig_f64)
        output = self._html_annotate(udt.overloads[sig_f64].type_annotation)

        # Annotations compound: both compiled signatures are now reported.
        self.assertEqual(output.count("Function name: udt"), 2)
        self.assertEqual(output.count(sigfmt.format(sig_i64)), 1)
        self.assertEqual(output.count(sigfmt.format(sig_f64)), 1)
        # Ensure the loop is tagged in both outputs
        self.assertEqual(len(re_lifted_tag.findall(output)), 2,
                         msg='expected 2 matches of %s in %s'
                         % (re_lifted_tag, output))

    @unittest.skipIf(pygments is None, "please install the 'pygments' package")
    def test_pretty_print(self):
        """Exercise ``inspect_types(pretty=True)`` and its ``file`` check."""

        @numba.njit
        def foo(x, y):
            return x, y

        foo(1, 2)
        # Exercise the method
        foo.inspect_types(pretty=True)

        # Exercise but supply a not None file kwarg, this is invalid
        with self.assertRaises(ValueError) as raises:
            foo.inspect_types(pretty=True, file='should be None')
        self.assertIn('`file` must be None if `pretty=True`',
                      str(raises.exception))
|
|
|
|
|
|
class TestTypeAnnotation(unittest.TestCase):
    """Check statement ordering (notably ``del``) in ``inspect_types`` text."""

    def findpatloc(self, lines, pat):
        """Return the index of the first entry of *lines* containing *pat*.

        Raises ValueError when no entry contains the substring.
        """
        found = next((idx for idx, text in enumerate(lines) if pat in text),
                     None)
        if found is None:
            raise ValueError("can't find {!r}".format(pat))
        return found

    def getlines(self, func):
        """Return the ``inspect_types`` listing of *func* as a list of lines."""
        with StringIO() as sink:
            func.inspect_types(sink)
            text = sink.getvalue()
        return text.splitlines()

    def test_delete(self):
        @numba.njit
        def foo(appleorange, berrycherry):
            return appleorange + berrycherry

        foo(1, 2)
        listing = self.getlines(foo)

        # Every argument's ``del`` must appear after the line that binds it.
        bind_a, bind_b, drop_a, drop_b = (
            self.findpatloc(listing, needle) for needle in (
                'appleorange = arg(0, name=appleorange)',
                'berrycherry = arg(1, name=berrycherry)',
                'del appleorange',
                'del berrycherry',
            )
        )
        self.assertLess(bind_a, drop_a)
        self.assertLess(bind_b, drop_b)

    def _lifetimes_impl(self, extend):
        """Compile a tiny function under EXTEND_VARIABLE_LIFETIMES=*extend*.

        Returns the listing indices of, in order: the binding of ``a``, the
        binding of ``b``, the return cast, ``del a`` and ``del b``.
        """
        # Definition and first call both happen inside the override so that
        # compilation sees the requested config value. The names ``a`` and
        # ``b`` must stay as-is: the searched patterns depend on them.
        with override_config('EXTEND_VARIABLE_LIFETIMES', extend):
            @njit
            def foo(a):
                b = a
                return b
            value = 10
            result = foo(value)
            self.assertEqual(result, value)

        listing = self.getlines(foo)
        needles = (
            'a = arg(0, name=a)',
            'b = a',
            'cast(value=b)',
            'del a',
            'del b',
        )
        return tuple(self.findpatloc(listing, needle) for needle in needles)

    def test_delete_standard_lifetimes(self):
        # With lifetimes NOT extended, a variable is deleted as soon as it is
        # dead. Expected listing shape:
        #
        #   label 0
        #       a = arg(0, name=a) :: int64
        #       b = a :: int64
        #       del a
        #       $8return_value.2 = cast(value=b) :: int64
        #       del b
        #       return $8return_value.2
        positions = self._lifetimes_impl(extend=0)
        bind_a, bind_b, ret_cast, del_a, del_b = positions

        self.assertLess(bind_a, del_a)
        self.assertLess(bind_b, del_b)
        # ``del a`` lands before the cast, ``del b`` after it.
        self.assertLess(del_a, ret_cast)
        self.assertGreater(del_b, ret_cast)

    def test_delete_extended_lifetimes(self):
        # With lifetimes extended, every ``del`` is pushed to the end of the
        # block. Expected listing shape:
        #
        #   label 0
        #       a = arg(0, name=a) :: int64
        #       b = a :: int64
        #       $8return_value.2 = cast(value=b) :: int64
        #       del a
        #       del b
        #       return $8return_value.2
        positions = self._lifetimes_impl(extend=1)
        bind_a, bind_b, ret_cast, del_a, del_b = positions

        self.assertLess(bind_a, del_a)
        self.assertLess(bind_b, del_b)
        # Both ``del`` statements follow the cast.
        self.assertGreater(del_a, ret_cast)
        self.assertGreater(del_b, ret_cast)
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this test module directly via the stdlib unittest runner.
    unittest.main()
|