135 lines
3.7 KiB
Python
135 lines
3.7 KiB
Python
import numpy as np
|
|
|
|
from numba import jit
|
|
from numba.core import types
|
|
from numba.tests.support import TestCase, tag
|
|
import unittest
|
|
|
|
|
|
# Array overlaps involving a displacement
|
|
|
|
def array_overlap1(src, dest, k=1):
    """Copy ``src`` into ``dest`` shifted forward by ``k`` on the first axis.

    Everything except the last ``k`` entries of ``src`` is written to
    ``dest`` starting at index ``k``.  Both arrays must have the same shape.
    """
    assert dest.shape == src.shape
    leading = src[:-k]
    dest[k:] = leading
|
|
|
|
def array_overlap2(src, dest, k=1):
    """Copy ``src`` into ``dest`` shifted backward by ``k`` on the first axis.

    Everything in ``src`` from index ``k`` onwards is written to the start of
    ``dest``; the last ``k`` entries of ``dest`` are left untouched.  Both
    arrays must have the same shape.
    """
    assert dest.shape == src.shape
    trailing = src[k:]
    dest[:-k] = trailing
|
|
|
|
def array_overlap3(src, dest, k=1):
    """Copy ``src`` into ``dest`` shifted backward by ``k`` on the second axis.

    Columns ``k:`` of ``src`` are written to columns ``:-k`` of ``dest``.
    Both arrays must have the same shape and at least two dimensions.
    """
    assert dest.shape == src.shape
    right_columns = src[:, k:]
    dest[:, :-k] = right_columns
|
|
|
|
def array_overlap4(src, dest, k=1):
    """Copy ``src`` into ``dest`` shifted forward by ``k`` on the second axis.

    Columns ``:-k`` of ``src`` are written to columns ``k:`` of ``dest``.
    Both arrays must have the same shape and at least two dimensions.
    """
    assert dest.shape == src.shape
    left_columns = src[:, :-k]
    dest[:, k:] = left_columns
|
|
|
|
def array_overlap5(src, dest, k=1):
    """Copy ``src`` into ``dest`` shifted backward by ``k`` on the last axis.

    Uses an ellipsis index, so it accepts arrays of any dimensionality from
    one upwards.  Both arrays must have the same shape.
    """
    assert dest.shape == src.shape
    tail = src[..., k:]
    dest[..., :-k] = tail
|
|
|
|
def array_overlap6(src, dest, k=1):
    """Copy ``src`` into ``dest`` shifted forward by ``k`` on the last axis.

    Uses an ellipsis index, so it accepts arrays of any dimensionality from
    one upwards.  Both arrays must have the same shape.
    """
    assert dest.shape == src.shape
    head = src[..., :-k]
    dest[..., k:] = head
|
|
|
|
# Array overlaps involving an in-place reversal
|
|
|
|
def array_overlap11(src, dest):
    """Write ``src`` into ``dest`` viewed in reverse along the first axis.

    Both arrays must have the same shape.
    """
    assert dest.shape == src.shape
    # The reversal is on the destination side of the assignment.
    dest[::-1] = src
|
|
|
|
def array_overlap12(src, dest):
    """Fill ``dest`` with ``src`` reversed along the first axis.

    Both arrays must have the same shape.
    """
    assert dest.shape == src.shape
    # The reversal is on the source side of the assignment.
    flipped = src[::-1]
    dest[:] = flipped
|
|
|
|
def array_overlap13(src, dest):
    """Write ``src`` into ``dest`` viewed in reverse along the second axis.

    Both arrays must have the same shape and at least two dimensions.
    """
    assert dest.shape == src.shape
    # The reversal is on the destination side of the assignment.
    dest[:, ::-1] = src
|
|
|
|
def array_overlap14(src, dest):
    """Fill ``dest`` with ``src`` reversed along the second axis.

    Both arrays must have the same shape and at least two dimensions.
    """
    assert dest.shape == src.shape
    # The reversal is on the source side of the assignment.
    flipped = src[:, ::-1]
    dest[:] = flipped
|
|
|
|
def array_overlap15(src, dest):
    """Write ``src`` into ``dest`` viewed in reverse along the last axis.

    Uses an ellipsis index, so it accepts arrays of any dimensionality from
    one upwards.  Both arrays must have the same shape.
    """
    assert dest.shape == src.shape
    # The reversal is on the destination side of the assignment.
    dest[..., ::-1] = src
|
|
|
|
def array_overlap16(src, dest):
    """Fill ``dest`` with ``src`` reversed along the last axis.

    Uses an ellipsis index, so it accepts arrays of any dimensionality from
    one upwards.  Both arrays must have the same shape.
    """
    assert dest.shape == src.shape
    # The reversal is on the source side of the assignment.
    flipped = src[..., ::-1]
    dest[:] = flipped
|
|
|
|
|
|
class TestArrayOverlap(TestCase):
    """Compare interpreted and nopython-jitted slice assignments when the
    source and the destination are one and the same array."""

    def check_overlap(self, pyfunc, min_ndim, have_k_argument=False):
        """Run *pyfunc* and its jitted counterpart on equal inputs and
        require precisely equal results.

        Arrays with ``min_ndim`` up to 3 dimensions are tried in three
        memory layouts.  Each array is passed as both ``src`` and ``dest``.
        When *have_k_argument* is true, the call is repeated with
        ``k = 1 .. N - 1`` on the same (already modified) arrays.
        """
        N = 4

        def vary_layouts(orig):
            # Same values in C order, Fortran order, and a doubly reversed
            # copy that is contiguous in neither sense.
            yield orig.copy(order='C')
            yield orig.copy(order='F')
            noncontig = orig[::-1].copy()[::-1]
            flags = noncontig.flags
            assert not (flags.c_contiguous or flags.f_contiguous)
            yield noncontig

        cfunc = jit(nopython=True)(pyfunc)

        # One keyword dict per call to make on every array pair.
        if have_k_argument:
            call_kwargs = [dict(k=k) for k in range(1, N)]
        else:
            call_kwargs = [{}]

        # Check for up to 3d arrays
        for ndim in range(min_ndim, 4):
            orig = np.arange(0, N ** ndim).reshape((N,) * ndim)
            # An 'A' layout array cannot be copied exactly (bitwise), so the
            # two sets of inputs come from two separate vary_layouts() runs.
            py_inputs = vary_layouts(orig)
            jit_inputs = vary_layouts(orig)
            for pydest, cdest in zip(py_inputs, jit_inputs):
                for kwargs in call_kwargs:
                    pyfunc(pydest, pydest, **kwargs)
                    cfunc(cdest, cdest, **kwargs)
                    self.assertPreciseEqual(pydest, cdest)

    def check_overlap_with_k(self, pyfunc, min_ndim):
        """Same as check_overlap(), for functions taking a ``k`` argument."""
        self.check_overlap(pyfunc, min_ndim, have_k_argument=True)

    # Displacement along the first axis.

    def test_overlap1(self):
        self.check_overlap_with_k(array_overlap1, min_ndim=1)

    def test_overlap2(self):
        self.check_overlap_with_k(array_overlap2, min_ndim=1)

    # Displacement along the second axis.

    def test_overlap3(self):
        self.check_overlap_with_k(array_overlap3, min_ndim=2)

    def test_overlap4(self):
        self.check_overlap_with_k(array_overlap4, min_ndim=2)

    # Displacement along the last axis (ellipsis indexing).

    def test_overlap5(self):
        self.check_overlap_with_k(array_overlap5, min_ndim=1)

    def test_overlap6(self):
        self.check_overlap_with_k(array_overlap6, min_ndim=1)

    # Reversal along the first axis.

    def test_overlap11(self):
        self.check_overlap(array_overlap11, min_ndim=1)

    def test_overlap12(self):
        self.check_overlap(array_overlap12, min_ndim=1)

    # Reversal along the second axis.

    def test_overlap13(self):
        self.check_overlap(array_overlap13, min_ndim=2)

    def test_overlap14(self):
        self.check_overlap(array_overlap14, min_ndim=2)

    # Reversal along the last axis (ellipsis indexing).

    def test_overlap15(self):
        self.check_overlap(array_overlap15, min_ndim=1)

    def test_overlap16(self):
        self.check_overlap(array_overlap16, min_ndim=1)
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|