"""
Tests for sub-components of parfors.

These tests aim to provide good-enough coverage of the parfor passes so that
refactoring those passes is easier, with a faster testing turnaround.
"""
import copy
import unittest
from functools import reduce

import numpy as np

from numba import njit, typeof, prange, pndindex
import numba.parfors.parfor
from numba.core import (
    rewrites,
    typed_passes,
    untyped_passes,
    inline_closurecall,
    compiler,
    cpu,
    errors
)
from numba.core.registry import cpu_target
from numba.tests.support import (TestCase, is_parfors_unsupported)
class MyPipeline:
    """Lightweight stand-in for a compiler pipeline object.

    Only exposes ``state``: a ``compiler.StateDict`` seeded with the typing
    and target contexts, the argument types and the function IR.  The
    type-inference outputs (``typemap``, ``return_type``, ``calltypes``)
    start out as ``None`` and ``metadata`` starts as an empty dict.
    """

    def __init__(self, typingctx, targetctx, args, test_ir):
        state = compiler.StateDict()
        state.typingctx = typingctx
        state.targetctx = targetctx
        state.args = args
        state.func_ir = test_ir
        # Filled in later by type inference.
        for unset_attr in ("typemap", "return_type", "calltypes"):
            setattr(state, unset_attr, None)
        state.metadata = {}
        self.state = state


class BaseTest(TestCase):
    """Shared helpers for exercising individual parfor passes in isolation.

    Subclasses must define ``sub_pass_class``: a class constructed with a
    ``ParforPass`` instance and exposing ``run(blocks)``.  It is used by
    ``run_parfor_sub_pass``.
    """

    @classmethod
    def _run_parfor(cls, test_func, args, swap_map=None):
        """Run the passes that precede the parfor conversion sub-passes.

        The sequence is: front-end, closure inlining, "before-inference"
        rewrites, SSA reconstruction, type inference, phi stripping,
        ``PreParforPass`` and finally the "after-inference" rewrites.

        Parameters
        ----------
        test_func : function
            Python function whose IR is built and transformed.
        args : sequence
            Numba argument types used for type inference.
        swap_map : dict, optional
            Passed to ``PreParforPass`` as ``replace_functions_map``.

        Returns
        -------
        tuple
            ``(pipeline, parallel_options, diagnostics, preparfor_pass)``.
        """
        # TODO: refactor this with get_optimized_numba_ir() where this is
        # copied from
        typingctx = cpu_target.typing_context
        targetctx = cpu_target.target_context
        test_ir = compiler.run_frontend(test_func)
        options = cpu.ParallelOptions(True)

        tp = MyPipeline(typingctx, targetctx, args, test_ir)

        typingctx.refresh()
        targetctx.refresh()

        inline_pass = inline_closurecall.InlineClosureCallPass(
            tp.state.func_ir, options, typed=True
        )
        inline_pass.run()

        rewrites.rewrite_registry.apply("before-inference", tp.state)

        untyped_passes.ReconstructSSA().run_pass(tp.state)

        (
            tp.state.typemap,
            tp.state.return_type,
            tp.state.calltypes,
            _
        ) = typed_passes.type_inference_stage(
            tp.state.typingctx, tp.state.targetctx, tp.state.func_ir,
            tp.state.args, None
        )

        typed_passes.PreLowerStripPhis().run_pass(tp.state)

        diagnostics = numba.parfors.parfor.ParforDiagnostics()

        preparfor_pass = numba.parfors.parfor.PreParforPass(
            tp.state.func_ir,
            tp.state.typemap,
            tp.state.calltypes,
            tp.state.typingctx,
            tp.state.targetctx,
            options,
            swapped=diagnostics.replaced_fns,
            replace_functions_map=swap_map,
        )
        preparfor_pass.run()

        rewrites.rewrite_registry.apply("after-inference", tp.state)
        return tp, options, diagnostics, preparfor_pass

    @classmethod
    def run_parfor_sub_pass(cls, test_func, args):
        """Run ``cls.sub_pass_class`` over ``test_func`` and return it.

        A ``ParforPass`` is built on top of the pre-pass state and its
        ``_pre_run()`` is called.  The sub-pass is then constructed from that
        ``ParforPass`` and run on the function IR blocks, so tests can inspect
        the returned sub-pass instance (e.g. its ``rewritten`` records).
        """
        tp, options, diagnostics, _ = cls._run_parfor(test_func, args)

        flags = compiler.Flags()
        parfor_pass = numba.parfors.parfor.ParforPass(
            tp.state.func_ir,
            tp.state.typemap,
            tp.state.calltypes,
            tp.state.return_type,
            tp.state.typingctx,
            tp.state.targetctx,
            options,
            flags,
            tp.state.metadata,
            diagnostics=diagnostics,
        )
        parfor_pass._pre_run()
        # Run subpass
        sub_pass = cls.sub_pass_class(parfor_pass)
        sub_pass.run(parfor_pass.func_ir.blocks)

        return sub_pass

    @classmethod
    def run_parfor_pre_pass(cls, test_func, args, swap_map=None):
        """Run only up to ``PreParforPass`` and return that pass instance."""
        *_, preparfor_pass = cls._run_parfor(test_func, args, swap_map)
        return preparfor_pass

    def _run_parallel(self, func, *args, **kwargs):
        """Call ``func`` in the interpreter and compiled with parallel=True.

        Each call receives its own deep copy of ``args``/``kwargs``.  Without
        this, a function that mutates its inputs (e.g. ``arr.fill(3)`` then
        ``return arr``) would make the interpreted result alias the same
        array the compiled call later overwrites, and the comparison would
        be vacuous.  Copying ``(args, kwargs)`` together keeps any aliasing
        between arguments intact within each copy.

        Returns ``(expect, got)``.
        """
        cfunc = njit(parallel=True)(func)
        expect_args, expect_kwargs = copy.deepcopy((args, kwargs))
        got_args, got_kwargs = copy.deepcopy((args, kwargs))
        expect = func(*expect_args, **expect_kwargs)
        got = cfunc(*got_args, **got_kwargs)
        return expect, got

    def run_parallel(self, func, *args, **kwargs):
        """Check that the parallel-compiled ``func`` matches the interpreter.

        Returns without checking anything when parfors are unsupported.
        """
        if is_parfors_unsupported:
            # Skip
            return
        expect, got = self._run_parallel(func, *args, **kwargs)
        self.assertPreciseEqual(expect, got)

    def run_parallel_check_output_array(self, func, *args, **kwargs):
        """Like ``run_parallel`` but compares only array type and shape.

        Used where values cannot be compared (e.g. random number output).
        """
        if is_parfors_unsupported:
            # Skip
            return
        expect, got = self._run_parallel(func, *args, **kwargs)
        # Don't match the value, just the return type. must return array
        self.assertIsInstance(expect, np.ndarray)
        self.assertIsInstance(got, np.ndarray)
        self.assertEqual(expect.shape, got.shape)

    def check_records(self, records):
        """Assert every rewrite record's ``"new"`` entry is a ``Parfor``."""
        for rec in records:
            self.assertIsInstance(rec["new"], numba.parfors.parfor.Parfor)


class TestConvertSetItemPass(BaseTest):
    """``ConvertSetItemPass``: each slice or boolean-mask assignment below is
    expected to produce exactly one parfor."""

    sub_pass_class = numba.parfors.parfor.ConvertSetItemPass

    def _expect_single_rewrite(self, impl, reason):
        """Run the sub-pass on the zero-argument ``impl``, require exactly one
        rewrite tagged ``reason``, then compare parallel vs. Python output."""
        rewritten = self.run_parfor_sub_pass(impl, ()).rewritten
        self.assertEqual(len(rewritten), 1)
        (only_record,) = rewritten
        self.assertEqual(only_record["reason"], reason)
        self.check_records(rewritten)
        self.run_parallel(impl)

    def test_setitem_full_slice(self):
        def test_impl():
            n = 10
            a = np.ones(n)
            a[:] = 7
            return a

        self._expect_single_rewrite(test_impl, "slice")

    def test_setitem_slice_stop_bound(self):
        def test_impl():
            n = 10
            a = np.ones(n)
            a[:5] = 7
            return a

        self._expect_single_rewrite(test_impl, "slice")

    def test_setitem_slice_start_bound(self):
        def test_impl():
            n = 10
            a = np.ones(n)
            a[4:] = 7
            return a

        self._expect_single_rewrite(test_impl, "slice")

    def test_setitem_gather_if_scalar(self):
        def test_impl():
            n = 10
            a = np.ones(n)
            b = np.ones_like(a, dtype=np.bool_)
            a[b] = 7
            return a

        self._expect_single_rewrite(
            test_impl, "masked_assign_broadcast_scalar"
        )

    def test_setitem_gather_if_array(self):
        def test_impl():
            n = 10
            a = np.ones(n)
            b = np.ones_like(a, dtype=np.bool_)
            c = np.ones_like(a)
            a[b] = c[b]
            return a

        self._expect_single_rewrite(test_impl, "masked_assign_array")


class TestConvertNumpyPass(BaseTest):
    """``ConvertNumpyPass``: NumPy allocators, ``np.random`` calls and array
    expressions are each expected to become a single parfor."""

    sub_pass_class = numba.parfors.parfor.ConvertNumpyPass

    def _check_single_rewrite(self, test_impl, argtypes, reason):
        """Run the sub-pass and require exactly one rewrite tagged ``reason``
        whose new node is a ``Parfor``."""
        sub_pass = self.run_parfor_sub_pass(test_impl, argtypes)
        self.assertEqual(len(sub_pass.rewritten), 1)
        [record] = sub_pass.rewritten
        self.assertEqual(record["reason"], reason)
        self.check_records(sub_pass.rewritten)

    def _check_allocator(self, fn, compare_values):
        """Check that ``fn(n)`` is rewritten as a ``"numpy_allocator"``.

        ``compare_values`` selects a full value comparison against the
        interpreter; when False, only array type and shape are compared.
        """
        def test_impl():
            n = 10
            a = fn(n)
            return a

        self._check_single_rewrite(test_impl, (), "numpy_allocator")
        if compare_values:
            self.run_parallel(test_impl)
        else:
            self.run_parallel_check_output_array(test_impl)

    def _check_arrayexpr(self, test_impl, *args):
        """Check ``test_impl(*args)`` is rewritten as one ``"arrayexpr"``
        parfor and that the parallel result matches the interpreter."""
        argtypes = [typeof(x) for x in args]
        self._check_single_rewrite(test_impl, argtypes, "arrayexpr")
        self.run_parallel(test_impl, *args)

    def check_numpy_allocators(self, fn):
        self._check_allocator(fn, compare_values=True)

    def check_numpy_random(self, fn):
        # Random output cannot be compared by value.
        self._check_allocator(fn, compare_values=False)

    def test_numpy_allocators(self):
        fns = [np.ones, np.zeros]
        for fn in fns:
            with self.subTest(fn.__name__):
                self.check_numpy_allocators(fn)

    def test_numpy_random(self):
        fns = [np.random.random]
        for fn in fns:
            with self.subTest(fn.__name__):
                self.check_numpy_random(fn)

    def test_numpy_arrayexpr(self):
        def test_impl(a, b):
            return a + b

        a = b = np.ones(10)
        self._check_arrayexpr(test_impl, a, b)

    def test_numpy_arrayexpr_ufunc(self):
        def test_impl(a, b):
            return np.sin(-a) + np.float64(1) / np.sqrt(b)

        a = b = np.ones(10)
        self._check_arrayexpr(test_impl, a, b)

    def test_numpy_arrayexpr_broadcast(self):
        def test_impl(a, b):
            return a + b + np.array(1)

        a = np.ones(10)
        b = np.ones((3, 10))
        self._check_arrayexpr(test_impl, a, b)

    def test_numpy_arrayexpr_reshaped(self):
        def test_impl(a, b):
            a = a.reshape(1, a.size)  # shape[0] is now constant
            return a + b

        a = np.ones(10)
        b = np.ones(10)
        self._check_arrayexpr(test_impl, a, b)


class TestConvertReducePass(BaseTest):
    """``ConvertReducePass``: ``functools.reduce`` over an array (plain or
    boolean-masked) is expected to become one parfor tagged ``"reduce"``."""

    sub_pass_class = numba.parfors.parfor.ConvertReducePass

    def _expect_reduce_rewrite(self, impl, *call_args):
        """Type ``call_args``, run the sub-pass, require a single "reduce"
        rewrite and compare the parallel result with the interpreter."""
        arg_types = list(map(typeof, call_args))
        rewritten = self.run_parfor_sub_pass(impl, arg_types).rewritten
        self.assertEqual(len(rewritten), 1)
        (only_record,) = rewritten
        self.assertEqual(only_record["reason"], "reduce")
        self.check_records(rewritten)
        self.run_parallel(impl, *call_args)

    def test_reduce_max_basic(self):
        def test_impl(arr):
            return reduce(lambda x, y: max(x, y), arr, 0.0)

        self._expect_reduce_rewrite(test_impl, np.ones(10))

    def test_reduce_max_masked(self):
        def test_impl(arr):
            return reduce(lambda x, y: max(x, y), arr[arr > 5], 0.0)

        self._expect_reduce_rewrite(test_impl, np.ones(10))


class TestConvertLoopPass(BaseTest):
    """``ConvertLoopPass``: ``prange``/``pndindex``/``internal_prange`` loops
    are expected to become parfors tagged ``"loop"``.  The ``np.sum`` and
    ``ndarray.fill`` cases below are also expected to yield ``"loop"``
    rewrites, and unsupported loop forms must raise
    ``UnsupportedRewriteError``."""

    sub_pass_class = numba.parfors.parfor.ConvertLoopPass

    def _check_loop_rewrites(self, test_impl, argtypes=(), count=1):
        """Run the sub-pass and require ``count`` rewrites, each a ``Parfor``
        tagged ``"loop"``."""
        sub_pass = self.run_parfor_sub_pass(test_impl, argtypes)
        self.assertEqual(len(sub_pass.rewritten), count)
        self.check_records(sub_pass.rewritten)
        for record in sub_pass.rewritten:
            self.assertEqual(record["reason"], "loop")

    def _check_unsupported(self, test_impl, message):
        """Require the sub-pass to reject ``test_impl`` with ``message``."""
        with self.assertRaises(errors.UnsupportedRewriteError) as raises:
            self.run_parfor_sub_pass(test_impl, ())
        self.assertIn(message, str(raises.exception))

    def test_prange_reduce_simple(self):
        def test_impl():
            n = 20
            c = 0
            for i in prange(n):
                c += i
            return c

        self._check_loop_rewrites(test_impl)
        self.run_parallel(test_impl)

    def test_prange_map_simple(self):
        def test_impl():
            n = 20
            arr = np.ones(n)
            for i in prange(n):
                arr[i] += i
            return arr

        self._check_loop_rewrites(test_impl)
        self.run_parallel(test_impl)

    def test_prange_two_args(self):
        def test_impl():
            n = 20
            arr = np.ones(n)
            for i in prange(3, n):
                arr[i] += i
            return arr

        self._check_loop_rewrites(test_impl)
        self.run_parallel(test_impl)

    def test_prange_three_args(self):
        def test_impl():
            n = 20
            arr = np.ones(n)
            for i in prange(3, n, 2):
                arr[i] += i
            return arr

        self._check_unsupported(
            test_impl,
            "Only constant step size of 1 is supported for prange",
        )

    def test_prange_map_inner_loop(self):
        def test_impl():
            n = 20
            arr = np.ones((n, n))
            for i in prange(n):
                for j in range(i):
                    arr[i, j] += i + j * n
            return arr

        self._check_loop_rewrites(test_impl)
        self.run_parallel(test_impl)

    def test_prange_map_nested_prange(self):
        def test_impl():
            n = 20
            arr = np.ones((n, n))
            for i in prange(n):
                for j in prange(i):
                    arr[i, j] += i + j * n
            return arr

        # One rewrite per prange loop.
        self._check_loop_rewrites(test_impl, count=2)
        self.run_parallel(test_impl)

    def test_prange_map_none_index(self):
        def test_impl():
            n = 20
            arr = np.ones(n)
            for i in prange(n):
                inner = arr[i : i + 1]
                inner[()] += 1
            return arr

        self._check_loop_rewrites(test_impl)
        self.run_parallel(test_impl)

    def test_prange_map_overwrite_index(self):
        def test_impl():
            n = 20
            arr = np.ones(n)
            for i in prange(n):
                i += 1
                arr[i - 1] = i
            return arr

        self._check_unsupported(
            test_impl,
            "Overwrite of parallel loop index",
        )

    def test_init_prange(self):
        def test_impl():
            n = 20
            arr = np.ones(n)
            numba.parfors.parfor.init_prange()
            val = 0
            for i in numba.parfors.parfor.internal_prange(len(arr)):
                val += arr[i]
            return val

        self._check_loop_rewrites(test_impl)
        self.run_parallel(test_impl)

    def test_pndindex(self):
        def test_impl():
            n = 20
            arr = np.ones((n, n))
            val = 0
            for idx in pndindex(arr.shape):
                val += idx[0] * idx[1]
            return val

        self._check_loop_rewrites(test_impl)
        self.run_parallel(test_impl)

    def test_numpy_sum(self):
        def test_impl(arr):
            return np.sum(arr)

        shape = 11, 13
        arr = np.arange(np.prod(shape)).reshape(shape)
        args = (arr,)
        argtypes = [typeof(x) for x in args]

        self._check_loop_rewrites(test_impl, argtypes)
        self.run_parallel(test_impl, *args)

    def test_numpy_sum_bool_array_masked(self):
        def test_impl(arr):
            sliced = arr[:, 0]
            return np.sum(arr[sliced >= 3, 1:2])

        shape = 11, 13
        arr = np.arange(np.prod(shape)).reshape(shape)
        args = (arr,)
        argtypes = [typeof(x) for x in args]

        self._check_loop_rewrites(test_impl, argtypes)
        self.run_parallel(test_impl, *args)

    def test_numpy_sum_int_array_masked(self):
        def test_impl(arr):
            sel = np.arange(arr.shape[1])
            return np.sum(arr[:, sel])

        shape = 11, 13
        arr = np.arange(np.prod(shape)).reshape(shape)
        args = (arr,)
        argtypes = [typeof(x) for x in args]

        # 1 for arange; 1 for sum
        self._check_loop_rewrites(test_impl, argtypes, count=2)
        self.run_parallel(test_impl, *args)

    def test_numpy_fill_method(self):
        def test_impl(arr):
            arr.fill(3)
            return arr

        shape = 11, 13
        arr = np.arange(np.prod(shape)).reshape(shape)
        args = (arr,)
        argtypes = [typeof(x) for x in args]

        # A single rewrite: arr is an argument, so the fill() call is the
        # only candidate inside test_impl.
        self._check_loop_rewrites(test_impl, argtypes)
        self.run_parallel(test_impl, *args)


class TestPreParforPass(BaseTest):
    """``PreParforPass``: checks its ``replaced_func``/``replaced_dtype``
    counters and that the parallel results still match the interpreter."""

    class sub_pass_class:
        """Do-nothing placeholder; these tests only drive the pre-pass."""

        def __init__(self, pass_states):
            del pass_states  # intentionally unused

        def run(self, blocks):
            return None

    def _check_pre_pass_stats(self, test_impl, replaced_func,
                              replaced_dtype, swap_map=None):
        """Run the pre-pass on ``test_impl(np.arange(10))`` and compare its
        counters against the expected values."""
        sample = np.arange(10)
        call_args = (sample,)
        arg_types = [typeof(value) for value in call_args]

        pre_pass = self.run_parfor_pre_pass(test_impl, arg_types, swap_map)
        stats = pre_pass.stats
        self.assertEqual(stats["replaced_func"], replaced_func)
        self.assertEqual(stats["replaced_dtype"], replaced_dtype)
        self.run_parallel(test_impl, *call_args)

    def test_dtype_conversion(self):
        # array.dtype are converted to np.dtype(array) in the PreParforPass
        def test_impl(a):
            b = np.ones(20, dtype=a.dtype)
            return b

        self._check_pre_pass_stats(test_impl, replaced_func=0,
                                   replaced_dtype=1)

    def test_sum_replacement(self):
        def test_impl(a):
            return np.sum(a)

        self._check_pre_pass_stats(test_impl, replaced_func=1,
                                   replaced_dtype=0)

    def test_replacement_map(self):
        def test_impl(a):
            return np.sum(a)

        # Without the ("sum", "numpy") entry, np.sum must not be replaced.
        swap_map = numba.parfors.parfor.swap_functions_map.copy()
        swap_map.pop(("sum", "numpy"))
        self._check_pre_pass_stats(test_impl, replaced_func=0,
                                   replaced_dtype=0, swap_map=swap_map)


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()