ai-content-maker/.venv/Lib/site-packages/numba/tests/npyufunc/test_update_inplace.py

123 lines
5.0 KiB
Python

# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import, division
import unittest
import numpy as np
from numba import guvectorize
from numba.tests.support import TestCase
def py_replace_2nd(x_t, y_1):
    """Write ``y_1[0]`` into every even-indexed slot of *x_t*, in place.

    Indices 0, 2, 4, ... of *x_t* are overwritten; odd indices are left
    untouched. Nothing is returned. This function serves both as the
    pure-Python reference and as the kernel compiled with
    ``guvectorize`` using the ``(t),()`` layout.
    """
    length = x_t.shape[0]
    pos = 0
    while pos < length:
        x_t[pos] = y_1[0]
        pos += 2
def py_update_3(x0_t, x1_t, x2_t, y_1):
    """Fill three arrays in place with 1x, 2x and 3x ``y_1[0]``.

    The loop runs over every index of *x0_t*. At each index it writes
    ``y_1[0]`` to *x0_t*, ``2 * y_1[0]`` to *x1_t* and ``3 * y_1[0]`` to
    *x2_t*. Nothing is returned.
    """
    n_steps = x0_t.shape[0]
    for step in range(n_steps):
        # y_1[0] is re-read for each write, exactly as in a per-line update
        x0_t[step] = y_1[0]
        x1_t[step] = 2 * y_1[0]
        x2_t[step] = 3 * y_1[0]
class TestUpdateInplace(TestCase):
    """Tests for ``guvectorize``'s ``writable_args`` option, which lets a
    gufunc update its (non-output) input arrays in place."""

    def _run_test_for_gufunc(self, gufunc, py_func, expect_f4_to_pass=True,
                             z=2):
        """Compare *gufunc* against the pure-Python reference *py_func*.

        For each dtype ('f8', then 'f4') this builds ``gufunc.nin - 1``
        zero-filled length-10 arrays and makes a copy of each. It calls
        *gufunc* on the originals plus scalar *z*, and calls *py_func* on
        the copies plus ``np.array([z])``. The float64 results must always
        match the reference. The float32 results must match when
        *expect_f4_to_pass* is true and must differ somewhere otherwise.
        """
        dtype_cases = (('f8', True), ('f4', expect_f4_to_pass))
        for dtype, should_match in dtype_cases:
            n_arrays = gufunc.nin - 1
            actual = [np.zeros(10, dtype) for _ in range(n_arrays)]
            expected = [arr.copy() for arr in actual]
            gufunc(*actual, z)
            py_func(*expected, np.array([z]))
            for idx, pair in enumerate(zip(actual, expected)):
                got, want = pair
                label = 'input %s' % idx
                if not should_match:
                    self.assertFalse((got == want).all(), msg=label)
                    continue
                np.testing.assert_equal(got, want, err_msg=label)

    def test_update_inplace(self):
        """One in-place argument: no writable_args, by index, by name."""
        # Without writable_args, the float32 run is expected NOT to see the
        # in-place update (presumably because the kernel is handed a cast
        # copy; verify against the gufunc implementation).
        plain = guvectorize(['void(f8[:], f8[:])'], '(t),()',
                            nopython=True)(py_replace_2nd)
        self._run_test_for_gufunc(plain, py_replace_2nd,
                                  expect_f4_to_pass=False)

        # writable_args given by position first, then by parameter name
        for writable in ((0,), ('x_t',)):
            marked = guvectorize(['void(f8[:], f8[:])'], '(t),()',
                                 nopython=True,
                                 writable_args=writable)(py_replace_2nd)
            self._run_test_for_gufunc(marked, py_replace_2nd)

    def test_update_inplace_with_cache(self):
        """writable_args survive a round trip through the on-disk cache."""
        # The first build populates the cache; the second build is expected
        # to load from it, and only that second gufunc is exercised.
        for _ in range(2):
            gufunc = guvectorize(['void(f8[:], f8[:])'], '(t),()',
                                 nopython=True, writable_args=(0,),
                                 cache=True)(py_replace_2nd)
        self._run_test_for_gufunc(gufunc, py_replace_2nd)

    def test_update_inplace_parallel(self):
        """writable_args work with the 'parallel' target."""
        parallel = guvectorize(['void(f8[:], f8[:])'], '(t),()',
                               nopython=True, writable_args=(0,),
                               target='parallel')(py_replace_2nd)
        self._run_test_for_gufunc(parallel, py_replace_2nd)

    def test_update_inplace_3(self):
        """Three in-place arguments: no writable_args, ints, mixed."""
        plain = guvectorize(['void(f8[:], f8[:], f8[:], f8[:])'],
                            '(t),(t),(t),()',
                            nopython=True)(py_update_3)
        self._run_test_for_gufunc(plain, py_update_3, expect_f4_to_pass=False)

        # all positions as ints, then a mix of parameter names and an int
        for writable in ((0, 1, 2), ('x0_t', 'x1_t', 2)):
            marked = guvectorize(['void(f8[:], f8[:], f8[:], f8[:])'],
                                 '(t),(t),(t),()', nopython=True,
                                 writable_args=writable)(py_update_3)
            self._run_test_for_gufunc(marked, py_update_3)

    def test_exceptions(self):
        """Invalid writable_args (and the CUDA target) are rejected."""
        def build(**options):
            # Decorate py_replace_2nd with the fixed signature/layout plus
            # whatever extra guvectorize options the case needs.
            return guvectorize(['void(f8[:], f8[:])'], '(t),()',
                               nopython=True, **options)(py_replace_2nd)

        # more writable_args than the gufunc has inputs
        with self.assertRaises(ValueError):
            build(writable_args=(0, 1, 2, 5))
        # positions outside the valid input range, above and below
        for out_of_range in ((5,), (-1,)):
            with self.assertRaises(ValueError):
                build(writable_args=out_of_range)
        # a parameter name that py_replace_2nd does not have
        with self.assertRaises(RuntimeError):
            build(writable_args=('z_t',))
        # the CUDA target does not accept writable_args
        with self.assertRaises(TypeError):
            build(writable_args=(0,), target='cuda')
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()