125 lines
4.4 KiB
Python
125 lines
4.4 KiB
Python
|
import unittest
|
||
|
|
||
|
from numba import njit
|
||
|
from numba.tests.support import TestCase, override_config
|
||
|
from numba.misc import llvm_pass_timings as lpt
|
||
|
|
||
|
|
||
|
# Sample LLVM "Pass execution timing report" with all four timing columns
# (user, system, user+system, wall) populated for two passes plus the
# "Total" row. Rows keep the leading indentation and fixed-width columns
# that LLVM's timer report emits (NOTE(review): the parser in
# ``llvm_pass_timings`` presumably relies on that layout — confirm).
timings_raw1 = """
===-------------------------------------------------------------------------===
                      ... Pass execution timing report ...
===-------------------------------------------------------------------------===
  Total Execution Time: 0.0001 seconds (0.0001 wall clock)

   ---User Time---   --System Time--   --User+System--   ---Wall Time---  --- Name ---
   0.0001 ( 90.1%)   0.0001 ( 90.1%)   0.0001 ( 90.1%)   0.0001 ( 90.1%)  A1
   0.0000 (  9.9%)   0.0000 (  9.9%)   0.0000 (  9.9%)   0.0000 (  9.9%)  A2
   0.0001 (100.0%)   0.0001 (100.0%)   0.0001 (100.0%)   0.0001 (100.0%)  Total

"""  # noqa: E501
|
||
|
|
||
|
# Same report shape as the first fixture, but the system-time column is
# absent for every row and is rendered as a run of dashes ("-----"); this
# exercises parsing of a report with a missing timing column. Rows keep the
# leading indentation and column layout of LLVM's timer report.
timings_raw2 = """
===-------------------------------------------------------------------------===
                      ... Pass execution timing report ...
===-------------------------------------------------------------------------===
  Total Execution Time: 0.0001 seconds (0.0001 wall clock)

   ---User Time---   --System Time--   --User+System--   ---Wall Time---  --- Name ---
   0.0001 ( 90.1%)   -----   0.0001 ( 90.1%)   0.0001 ( 90.1%)  A1
   0.0000 (  9.9%)   -----   0.0000 (  9.9%)   0.0000 (  9.9%)  A2
   0.0001 (100.0%)   -----   0.0001 (100.0%)   0.0001 (100.0%)  Total

"""  # noqa: E501
|
||
|
|
||
|
|
||
|
class TestLLVMPassTimings(TestCase):
    """Tests for recording, querying and parsing LLVM pass timings."""

    def test_usage(self):
        """Timings recorded for a jitted function expose the collection API."""
        @njit
        def foo(n):
            c = 0
            for i in range(n):
                c += i
            return c

        with override_config('LLVM_PASS_TIMINGS', True):
            foo(10)

        md = foo.get_metadata(foo.signatures[0])
        timings = md['llvm_pass_timings']
        # Check: timing is of correct type
        self.assertIsInstance(timings, lpt.PassTimingsCollection)
        # Check: basic for __str__
        text = str(timings)
        self.assertIn("Module passes (full optimization)", text)
        # Check: there must be at least one record
        self.assertGreater(len(timings), 0)
        # Check: __getitem__
        last = timings[-1]
        self.assertIsInstance(last, lpt.NamedTimings)
        # Check: NamedTimings
        self.assertIsInstance(last.name, str)
        self.assertIsInstance(last.timings, lpt.ProcessedPassTimings)

    def test_analyze(self):
        """The collection's analysis helpers return well-formed results."""
        @njit
        def foo(n):
            c = 0
            for i in range(n):
                for j in range(i):
                    c += j
            return c

        with override_config('LLVM_PASS_TIMINGS', True):
            foo(10)

        md = foo.get_metadata(foo.signatures[0])
        timings_collection = md['llvm_pass_timings']
        # Check: get_total_time()
        self.assertIsInstance(timings_collection.get_total_time(), float)
        # Check: summary()
        self.assertIsInstance(timings_collection.summary(), str)
        # Check: list_longest_first() ordering
        longest_first = timings_collection.list_longest_first()
        self.assertEqual(len(longest_first), len(timings_collection))
        # Compare every adjacent pair so the whole sequence is verified to be
        # non-increasing (previously only the first element was compared
        # against the rest because the running value was never advanced).
        totals = [rec.timings.get_total_time() for rec in longest_first]
        for prev, cur in zip(totals, totals[1:]):
            self.assertGreaterEqual(prev, cur)

    def test_parse_raw(self):
        """Raw report text is parsed into timings with the expected total."""
        # Report with all timing columns populated.
        timings1 = lpt.ProcessedPassTimings(timings_raw1)
        self.assertAlmostEqual(timings1.get_total_time(), 0.0001)
        self.assertIsInstance(timings1.summary(), str)

        # Report whose system-time column is rendered as "-----".
        timings2 = lpt.ProcessedPassTimings(timings_raw2)
        self.assertAlmostEqual(timings2.get_total_time(), 0.0001)
        self.assertIsInstance(timings2.summary(), str)
|
||
|
|
||
|
|
||
|
class TestLLVMPassTimingsDisabled(TestCase):
    """Behaviour of the timings metadata when LLVM_PASS_TIMINGS is off."""

    def test_disabled_behavior(self):
        """With recording disabled, every query reports an empty result."""
        @njit
        def foo(n):
            total = 0
            for k in range(n):
                total += k
            return total

        with override_config('LLVM_PASS_TIMINGS', False):
            foo(10)

        timings = foo.get_metadata(foo.signatures[0])['llvm_pass_timings']

        # summary() returns a fixed "nothing recorded" message
        self.assertEqual(timings.summary(), "No pass timings were recorded")
        # there is no total time to report
        self.assertIsNone(timings.get_total_time())
        # and no records to rank
        self.assertEqual(timings.list_longest_first(), [])
|
||
|
|
||
|
|
||
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|