# ai-content-maker/.venv/Lib/site-packages/numba/tests/test_num_threads.py
#
# 633 lines
# 23 KiB
# Python

# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import, division
import sys
import os
import re
import multiprocessing
import unittest
import numpy as np
from numba import (njit, set_num_threads, get_num_threads, prange, config,
threading_layer, guvectorize)
from numba.np.ufunc.parallel import get_thread_id
from numba.core.errors import TypingError
from numba.tests.support import TestCase, skip_parfors_unsupported, tag
from numba.tests.test_parallel_backend import TestInSubprocess
class TestNumThreads(TestCase):
    """Checks of ``set_num_threads`` / ``get_num_threads`` / ``get_thread_id``.

    Only ``test_set_num_threads_type`` carries the ``test_`` prefix. Every
    method prefixed ``_test_`` is a template: ``TestNumThreadsBackends``
    re-exposes each of them once per threading backend and thread count and
    runs it in a subprocess.
    """

    _numba_parallel_test_ = False

    def setUp(self):
        # Make sure the num_threads is set to the max. This also makes sure
        # the threads are launched.
        set_num_threads(config.NUMBA_NUM_THREADS)

    def tearDown(self):
        # Restore the full thread mask so a reduced mask set by one test does
        # not carry over into the next one.
        set_num_threads(config.NUMBA_NUM_THREADS)

    def check_mask(self, expected, result):
        """Compare an observed thread count against the requested mask.

        For the 'tbb' layer ``result`` only has to be ``<= expected``; for
        'omp' and 'workqueue' it has to be equal. Any other layer name is
        reported as a test failure.
        """
        layer = threading_layer()
        # There's no guarantee that TBB will use a full mask worth of
        # threads if it deems it inefficient to do so
        if layer == 'tbb':
            self.assertTrue(np.all(result <= expected))
        elif layer in ('omp', 'workqueue'):
            np.testing.assert_equal(expected, result)
        else:
            # self.fail() rather than a bare ``assert 0`` so the check is not
            # compiled away when Python runs with -O.
            self.fail('unreachable: unknown threading layer %r' % (layer,))

    @skip_parfors_unsupported
    def test_set_num_threads_type(self):
        # A non-integer argument must be rejected both when compiled
        # (TypingError) and when run as plain Python (TypeError).

        @njit
        def foo():
            set_num_threads('wrong_type')

        expected = "The number of threads specified must be an integer"
        for fn, errty in ((foo, TypingError), (foo.py_func, TypeError)):
            with self.assertRaises(errty) as raises:
                fn()
            self.assertIn(expected, str(raises.exception))

    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    def _test_set_num_threads_basic(self):
        # Round trip set/get from the interpreter, plus range validation.
        max_threads = config.NUMBA_NUM_THREADS

        self.assertEqual(get_num_threads(), max_threads)
        set_num_threads(2)
        self.assertEqual(get_num_threads(), 2)
        set_num_threads(max_threads)
        self.assertEqual(get_num_threads(), max_threads)

        # Values outside 1..max_threads are rejected.
        with self.assertRaises(ValueError):
            set_num_threads(0)

        with self.assertRaises(ValueError):
            set_num_threads(max_threads + 1)

    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    def _test_set_num_threads_basic_jit(self):
        # Same round trip as the basic test, but read (and then also set)
        # from inside @njit functions.
        max_threads = config.NUMBA_NUM_THREADS

        @njit
        def get_n():
            return get_num_threads()

        self.assertEqual(get_n(), max_threads)
        set_num_threads(2)
        self.assertEqual(get_n(), 2)
        set_num_threads(max_threads)
        self.assertEqual(get_n(), max_threads)

        @njit
        def set_get_n(n):
            set_num_threads(n)
            return get_num_threads()

        self.assertEqual(set_get_n(2), 2)
        self.assertEqual(set_get_n(max_threads), max_threads)

    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    def _test_set_num_threads_basic_guvectorize(self):
        # Same round trip again, with the value read (and then also set) from
        # inside parallel-target gufuncs.
        max_threads = config.NUMBA_NUM_THREADS

        @guvectorize(['void(int64[:])'],
                     '(n)',
                     nopython=True,
                     target='parallel')
        def get_n(x):
            x[:] = get_num_threads()

        x = np.zeros((5000000,), dtype=np.int64)
        get_n(x)
        np.testing.assert_equal(x, max_threads)
        set_num_threads(2)
        x = np.zeros((5000000,), dtype=np.int64)
        get_n(x)
        np.testing.assert_equal(x, 2)
        set_num_threads(max_threads)
        x = np.zeros((5000000,), dtype=np.int64)
        get_n(x)
        np.testing.assert_equal(x, max_threads)

        @guvectorize(['void(int64[:])'],
                     '(n)',
                     nopython=True,
                     target='parallel')
        def set_get_n(n):
            set_num_threads(n[0])
            n[:] = get_num_threads()

        # The first element carries the requested thread count into the
        # kernel; the kernel then overwrites the whole array with the value
        # it reads back.
        x = np.zeros((5000000,), dtype=np.int64)
        x[0] = 2
        set_get_n(x)
        np.testing.assert_equal(x, 2)
        x = np.zeros((5000000,), dtype=np.int64)
        x[0] = max_threads
        set_get_n(x)
        np.testing.assert_equal(x, max_threads)

    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    def _test_set_num_threads_outside_jit(self):

        # Test set_num_threads outside a jitted function
        set_num_threads(2)

        @njit(parallel=True)
        def test_func():
            x = 5
            buf = np.empty((x,))
            for i in prange(x):
                buf[i] = get_num_threads()
            return buf

        @guvectorize(['void(int64[:])'],
                     '(n)',
                     nopython=True,
                     target='parallel')
        def test_gufunc(x):
            x[:] = get_num_threads()

        out = test_func()
        np.testing.assert_equal(out, 2)

        x = np.zeros((5000000,), dtype=np.int64)
        test_gufunc(x)
        np.testing.assert_equal(x, 2)

    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    def _test_set_num_threads_inside_jit(self):
        # Test set_num_threads inside a jitted function
        @njit(parallel=True)
        def test_func(nthreads):
            x = 5
            buf = np.empty((x,))
            set_num_threads(nthreads)
            for i in prange(x):
                buf[i] = get_num_threads()
            return buf

        mask = 2
        out = test_func(mask)
        np.testing.assert_equal(out, mask)

    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    def _test_set_num_threads_inside_guvectorize(self):
        # Test set_num_threads inside a jitted guvectorize function
        @guvectorize(['void(int64[:])'],
                     '(n)',
                     nopython=True,
                     target='parallel')
        def test_func(x):
            set_num_threads(x[0])
            x[:] = get_num_threads()

        x = np.zeros((5000000,), dtype=np.int64)
        mask = 2
        x[0] = mask
        test_func(x)
        np.testing.assert_equal(x, mask)

    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    def _test_get_num_threads_truth_outside_jit(self):
        # For each mask (set from the interpreter) count the distinct thread
        # ids that actually took part and compare that with the mask.

        for mask in range(2, min(6, config.NUMBA_NUM_THREADS + 1)):
            set_num_threads(mask)

            # a lot of work, hopefully will trigger "mask" count of threads to
            # join the parallel region (for those backends with dynamic
            # threads)
            @njit(parallel=True)
            def test_func():
                x = 5000000
                buf = np.empty((x,))
                for i in prange(x):
                    buf[i] = get_thread_id()
                return len(np.unique(buf)), get_num_threads()

            out = test_func()
            self.check_mask((mask, mask), out)

            @guvectorize(['void(int64[:], int64[:])'],
                         '(n), (m)',
                         nopython=True,
                         target='parallel')
            def test_gufunc(x, out):
                x[:] = get_thread_id()
                out[0] = get_num_threads()

            # Reshape to force parallelism
            x = np.full((5000000,), -1, dtype=np.int64).reshape((100, 50000))
            out = np.zeros((1,), dtype=np.int64)
            test_gufunc(x, out)
            self.check_mask(mask, out)
            self.check_mask(mask, len(np.unique(x)))

    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    def _test_get_num_threads_truth_inside_jit(self):
        # Same measurement as the "outside" variant, except that the mask is
        # applied from inside the compiled kernels.

        for mask in range(2, min(6, config.NUMBA_NUM_THREADS + 1)):

            # a lot of work, hopefully will trigger "mask" count of threads to
            # join the parallel region (for those backends with dynamic
            # threads)
            @njit(parallel=True)
            def test_func():
                set_num_threads(mask)
                x = 5000000
                buf = np.empty((x,))
                for i in prange(x):
                    buf[i] = get_thread_id()
                return len(np.unique(buf)), get_num_threads()

            out = test_func()
            self.check_mask((mask, mask), out)

            @guvectorize(['void(int64[:], int64[:])'],
                         '(n), (m)',
                         nopython=True,
                         target='parallel')
            def test_gufunc(x, out):
                set_num_threads(mask)
                x[:] = get_thread_id()
                out[0] = get_num_threads()

            # Reshape to force parallelism
            x = np.full((5000000,), -1, dtype=np.int64).reshape((100, 50000))
            out = np.zeros((1,), dtype=np.int64)
            test_gufunc(x, out)
            self.check_mask(mask, out)
            self.check_mask(mask, len(np.unique(x)))

    # this test can only run on OpenMP (providing OMP_MAX_ACTIVE_LEVELS is not
    # set or >= 2) and TBB backends
    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    def _test_nested_parallelism_1(self):
        if threading_layer() == 'workqueue':
            self.skipTest("workqueue is not threadsafe")

        # check that get_num_threads is ok in nesting
        mask = config.NUMBA_NUM_THREADS - 1

        N = config.NUMBA_NUM_THREADS
        M = 2 * config.NUMBA_NUM_THREADS

        @njit(parallel=True)
        def child_func(buf, fid):
            M, N = buf.shape
            for i in prange(N):
                buf[fid, i] = get_num_threads()

        def get_test(test_type):
            # Build a driver ``test_func(nthreads, py_func=False)`` for the
            # requested parent flavour; py_func=True selects the interpreted
            # (or object mode) variant used as the reference result.
            if test_type == 'njit':
                def test_func(nthreads, py_func=False):
                    @njit(parallel=True)
                    def _test_func(nthreads):
                        acc = 0
                        buf = np.zeros((M, N))
                        set_num_threads(nthreads)
                        for i in prange(M):
                            local_mask = 1 + i % mask
                            # set threads in parent function
                            set_num_threads(local_mask)
                            if local_mask < N:
                                child_func(buf, local_mask)
                            acc += get_num_threads()
                        return acc, buf
                    if py_func:
                        return _test_func.py_func(nthreads)
                    else:
                        return _test_func(nthreads)

            elif test_type == 'guvectorize':
                def test_func(nthreads, py_func=False):
                    def _test_func(acc, buf, local_mask):
                        set_num_threads(nthreads)
                        # set threads in parent function
                        set_num_threads(local_mask[0])
                        if local_mask[0] < N:
                            child_func(buf, local_mask[0])
                        acc[0] += get_num_threads()

                    buf = np.zeros((M, N), dtype=np.int64)
                    acc = np.zeros((M, 1), dtype=np.int64)
                    local_mask = (1 + np.arange(M) % mask).reshape((M, 1))
                    sig = ['void(int64[:], int64[:, :], int64[:])']
                    layout = '(p), (n, m), (p)'
                    if not py_func:
                        _test_func = guvectorize(sig, layout, nopython=True,
                                                 target='parallel')(_test_func)
                    else:
                        _test_func = guvectorize(sig, layout,
                                                 forceobj=True)(_test_func)
                    _test_func(acc, buf, local_mask)
                    return acc, buf

            return test_func

        for test_type in ['njit', 'guvectorize']:
            test_func = get_test(test_type)
            got_acc, got_arr = test_func(mask)
            exp_acc, exp_arr = test_func(mask, py_func=True)
            np.testing.assert_equal(exp_acc, got_acc)
            np.testing.assert_equal(exp_arr, got_arr)

            # check the maths reconciles, guvectorize does not reduce, njit
            # does
            math_acc_exp = 1 + np.arange(M) % mask
            if test_type == 'guvectorize':
                math_acc = math_acc_exp.reshape((M, 1))
            else:
                math_acc = np.sum(math_acc_exp)

            np.testing.assert_equal(math_acc, got_acc)

            math_arr = np.zeros((M, N))
            for i in range(1, N):
                # there's branches on 1, ..., num_threads - 1
                math_arr[i, :] = i
            np.testing.assert_equal(math_arr, got_arr)

    # this test can only run on OpenMP (providing OMP_MAX_ACTIVE_LEVELS is not
    # set or >= 2) and TBB backends
    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    def _test_nested_parallelism_2(self):
        if threading_layer() == 'workqueue':
            self.skipTest("workqueue is not threadsafe")

        # check that get_num_threads is ok in nesting

        N = config.NUMBA_NUM_THREADS + 1
        M = 4 * config.NUMBA_NUM_THREADS + 1

        def get_impl(child_type, test_type):
            # Build a driver for one (parent flavour, child flavour) pair.
            # ``mask`` is read from the enclosing scope when the driver runs;
            # it is assigned further down, before get_impl is first called.

            if child_type == 'parallel':
                child_dec = njit(parallel=True)
            elif child_type == 'njit':
                child_dec = njit(parallel=False)
            elif child_type == 'none':
                def child_dec(x):
                    return x

            @child_dec
            def child(buf, fid):
                M, N = buf.shape
                set_num_threads(fid)  # set threads in child function
                for i in prange(N):
                    buf[fid, i] = get_num_threads()

            if test_type in ['parallel', 'njit', 'none']:
                if test_type == 'parallel':
                    test_dec = njit(parallel=True)
                elif test_type == 'njit':
                    test_dec = njit(parallel=False)
                elif test_type == 'none':
                    def test_dec(x):
                        return x

                @test_dec
                def test_func(nthreads):
                    buf = np.zeros((M, N))
                    set_num_threads(nthreads)
                    for i in prange(M):
                        local_mask = 1 + i % mask
                        # when the threads exit the child functions they
                        # should have a TLS slot value of the local mask as it
                        # was set in child
                        if local_mask < config.NUMBA_NUM_THREADS:
                            child(buf, local_mask)
                            # plain assert on purpose: this body is also
                            # compiled by njit
                            assert get_num_threads() == local_mask
                    return buf
            else:
                if test_type == 'guvectorize':
                    test_dec = guvectorize(['int64[:,:], int64[:]'],
                                           '(n, m), (k)', nopython=True,
                                           target='parallel')
                elif test_type == 'guvectorize-obj':
                    test_dec = guvectorize(['int64[:,:], int64[:]'],
                                           '(n, m), (k)', forceobj=True)

                def test_func(nthreads):
                    @test_dec
                    def _test_func(buf, local_mask):
                        set_num_threads(nthreads)
                        # when the threads exit the child functions they
                        # should have a TLS slot value of the local mask as it
                        # was set in child
                        if local_mask[0] < config.NUMBA_NUM_THREADS:
                            child(buf, local_mask[0])
                            # plain assert on purpose: this body is also
                            # compiled by guvectorize
                            assert get_num_threads() == local_mask[0]

                    buf = np.zeros((M, N), dtype=np.int64)
                    local_mask = (1 + np.arange(M) % mask).reshape((M, 1))
                    _test_func(buf, local_mask)
                    return buf

            return test_func

        mask = config.NUMBA_NUM_THREADS - 1

        res_arrays = {}
        for test_type in ['parallel', 'njit', 'none',
                          'guvectorize', 'guvectorize-obj']:
            for child_type in ['parallel', 'njit', 'none']:
                # an interpreted child is only paired with an interpreted
                # parent
                if child_type == 'none' and test_type != 'none':
                    continue
                set_num_threads(mask)
                res_arrays[test_type, child_type] = get_impl(
                    child_type, test_type)(mask)

        # every combination has to agree with the pure Python result
        py_arr = res_arrays['none', 'none']
        for arr in res_arrays.values():
            np.testing.assert_equal(arr, py_arr)

        # check the maths reconciles
        math_arr = np.zeros((M, N))
        # there's branches on modulo mask but only NUMBA_NUM_THREADS funcs
        for i in range(1, config.NUMBA_NUM_THREADS):
            math_arr[i, :] = i

        np.testing.assert_equal(math_arr, py_arr)

    # this test can only run on OpenMP (providing OMP_MAX_ACTIVE_LEVELS is not
    # set or >= 2) and TBB backends
    # This test needs at least 3 threads to run, N>=2 for the launch, M>=N+1
    # for the nested function
    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 3, "Not enough CPU cores")
    def _test_nested_parallelism_3(self):
        if threading_layer() == 'workqueue':
            self.skipTest("workqueue is not threadsafe")

        # check that the right number of threads are present in nesting
        # this relies on there being a load of cores present
        BIG = 1000000

        @njit(parallel=True)
        def work(local_nt):  # arg is value 3
            tid = np.zeros(BIG)
            acc = 0
            set_num_threads(local_nt)  # set to 3 threads
            for i in prange(BIG):
                acc += 1
                tid[i] = get_thread_id()
            return acc, np.unique(tid)

        @njit(parallel=True)
        def test_func_jit(nthreads):
            set_num_threads(nthreads)  # set to 2 threads
            lens = np.zeros(nthreads)
            total = 0
            for i in prange(nthreads):
                my_acc, tids = work(nthreads + 1)  # call with value 3
                lens[i] = len(tids)
                total += my_acc
            return total, np.unique(lens)

        NT = 2
        expected_acc = BIG * NT
        expected_thread_count = NT + 1

        got_acc, got_tc = test_func_jit(NT)
        self.assertEqual(expected_acc, got_acc)
        self.check_mask(expected_thread_count, got_tc)

        def test_guvectorize(nthreads):
            @guvectorize(['int64[:], int64[:]'],
                         '(n), (n)',
                         nopython=True,
                         target='parallel')
            def test_func_guvectorize(total, lens):
                my_acc, tids = work(nthreads + 1)
                lens[0] = len(tids)
                total[0] += my_acc

            total = np.zeros((nthreads, 1), dtype=np.int64)
            lens = np.zeros(nthreads, dtype=np.int64).reshape((nthreads, 1))

            test_func_guvectorize(total, lens)
            # vectorize does not reduce, so total is summed
            return total.sum(), np.unique(lens)

        got_acc, got_tc = test_guvectorize(NT)

        self.assertEqual(expected_acc, got_acc)
        self.check_mask(expected_thread_count, got_tc)

    @skip_parfors_unsupported
    @unittest.skipIf(config.NUMBA_NUM_THREADS < 2, "Not enough CPU cores")
    @unittest.skipIf(not sys.platform.startswith('linux'), "Linux only")
    def _test_threadmask_across_fork(self):
        # A mask set in the parent has to be what a forked child reads back.
        forkctx = multiprocessing.get_context('fork')

        @njit
        def foo():
            return get_num_threads()

        def wrap(queue):
            queue.put(foo())

        mask = 1
        self.assertEqual(foo(), config.NUMBA_NUM_THREADS)
        set_num_threads(mask)
        self.assertEqual(foo(), mask)
        shared_queue = forkctx.Queue()
        # check TLS slot inheritance in fork
        p = forkctx.Process(target=wrap, args=(shared_queue,))
        p.start()
        # Take the item off the queue *before* joining: the multiprocessing
        # guidelines warn that joining a process which still has data
        # buffered for a queue can deadlock.
        child_threads = shared_queue.get()
        p.join()
        self.assertEqual(child_threads, mask)

    @skip_parfors_unsupported
    def _test_get_thread_id_not_parallel(self):
        # Outside of an actual parallel region get_thread_id() is expected to
        # be 0: in the interpreter, under parallel=False (prange used), and
        # under parallel=True (only a plain range loop).
        python_get_thread_id = get_thread_id()
        check_array_size = 8

        @njit(parallel=False)
        def par_false(size):
            njit_par_false_tid = get_thread_id()
            res = np.ones(size)
            for i in prange(size):
                res[i] = get_thread_id()
            return njit_par_false_tid, res

        @njit(parallel=True)
        def par_true(size):
            njit_par_true_tid = get_thread_id()
            res = np.ones(size)
            for i in range(size):
                res[i] = get_thread_id()
            return njit_par_true_tid, res

        self.assertEqual(python_get_thread_id, 0)

        njit_par_false_tid, njit_par_false_arr = par_false(check_array_size)
        self.assertEqual(njit_par_false_tid, 0)
        np.testing.assert_equal(njit_par_false_arr, 0)

        njit_par_true_tid, njit_par_true_arr = par_true(check_array_size)
        self.assertEqual(njit_par_true_tid, 0)
        np.testing.assert_equal(njit_par_true_arr, 0)
class TestNumThreadsBackends(TestInSubprocess, TestCase):
    """Run every ``_test_*`` template of ``_class`` in a subprocess.

    ``generate()`` adds one test method to this class for each combination
    of template name, entry of ``cls.backends`` and entry of ``num_threads``.
    Each generated test launches ``numba.runtests`` in a child process with
    ``NUMBA_THREADING_LAYER`` and ``NUMBA_NUM_THREADS`` set accordingly.
    """

    # Class whose ``_test_*`` methods serve as templates.
    _class = TestNumThreads
    # Set to True to echo the captured subprocess output.
    _DEBUG = False

    # 1 is mainly here to ensure tests skip correctly
    num_threads = [i for i in [1, 2, 4, 8, 16] if i <= config.NUMBA_NUM_THREADS]

    def run_test_in_separate_process(self, test, threading_layer, num_threads):
        """Run the test id ``test`` through ``numba.runtests`` in a child.

        The child inherits a copy of the current environment with the
        threading layer and thread count overridden. Returns whatever
        ``self.run_cmd`` returns (unpacked by the caller as a pair).
        """
        env_copy = os.environ.copy()
        env_copy['NUMBA_THREADING_LAYER'] = str(threading_layer)
        env_copy['NUMBA_NUM_THREADS'] = str(num_threads)
        cmdline = [sys.executable, "-m", "numba.runtests", "-v", test]
        return self.run_cmd(cmdline, env_copy)

    @classmethod
    def _inject(cls, name, backend, backend_guard, num_threads):
        """Attach one generated test for ``name`` / ``backend`` / thread count.

        ``name`` is the template method name on ``_class`` (with its leading
        underscore); ``backend_guard`` is a decorator applied to the
        generated test before it is tagged 'long_running' and set on ``cls``.
        """
        themod = cls.__module__
        thecls = cls._class.__name__
        injected_method = '%s.%s.%s' % (themod, thecls, name)

        def test_template(self):
            o, e = self.run_test_in_separate_process(injected_method, backend,
                                                     num_threads)
            if self._DEBUG:
                print('stdout:\n "%s"\n stderr:\n "%s"' % (o, e))
            # If the test was skipped in the subprocess, then mark this as a
            # skipped test.
            m = re.search(r"\.\.\. skipped '(.*?)'", e)
            if m is not None:
                self.skipTest(m.group(1))
            self.assertIn('OK', e)
            # assertNotIn includes the offending output in the failure
            # message, unlike assertTrue(... not in ...).
            self.assertNotIn('FAIL', e)
            self.assertNotIn('ERROR', e)

        # name[1:] drops the leading underscore so the generated attribute
        # starts with "test_".
        injected_test = "%s_%s_%s_threads" % (name[1:], backend, num_threads)
        setattr(cls, injected_test,
                tag('long_running')(backend_guard(test_template)))

    @classmethod
    def generate(cls):
        """Inject a test for every template, backend and thread count."""
        # NOTE(review): ``cls.backends`` is not defined in this class;
        # presumably it comes from TestInSubprocess -- confirm.
        for name in cls._class.__dict__.copy():
            # Only names with the ``_test_`` prefix are templates; checking
            # once per name avoids repeating the test in the inner loops.
            if not name.startswith('_test_'):
                continue
            for backend, backend_guard in cls.backends.items():
                for num_threads in cls.num_threads:
                    cls._inject(name, backend, backend_guard, num_threads)
# Generate the per-backend / per-thread-count tests at import time so that
# they exist on the class before any test collection looks at it.
TestNumThreadsBackends.generate()


if __name__ == '__main__':
    unittest.main()