# ai-content-maker/.venv/Lib/site-packages/numba/tests/test_func_lifetime.py
#
# 165 lines
# 4.8 KiB
# Python

import gc
import weakref
from numba import jit
from numba.core import types
from numba.tests.support import TestCase
import unittest
class Dummy(object):
    """
    Minimal object supporting ``dummy + other``.

    The addition ignores any state on the instance and simply returns
    ``other + 5``.
    """

    def __add__(self, other):
        return other + 5
def global_usecase1(x):
    """
    Return ``x + 1``.

    Defined at module level so that it can be compiled as a *global*
    function (as opposed to a function local to a test method).
    """
    return x + 1
def global_usecase2():
    """
    Return ``global_obj + 1``.

    ``global_obj`` is looked up in the module globals at call time; it is
    not defined at import time and must be assigned by the caller first
    (``check_global_obj_lifetime`` does so), otherwise a NameError is raised.
    """
    return global_obj + 1
class TestFuncLifetime(TestCase):
    """
    Test the lifetime of compiled function objects and their dependencies.

    Each ``check_*`` helper compiles something with ``jit``, takes weak
    references to the objects involved, drops every strong reference it
    holds, forces a garbage collection and asserts that all the weak
    references are dead.
    """

    def get_impl(self, dispatcher):
        """
        Get the single implementation (a C function object) of a dispatcher.

        Asserts that *dispatcher* has exactly one overload and returns that
        overload's ``entry_point``.
        """
        self.assertEqual(len(dispatcher.overloads), 1)
        cres = list(dispatcher.overloads.values())[0]
        return cres.entry_point

    def check_local_func_lifetime(self, **jitargs):
        """
        Check that a jitted local function, its dispatcher and its compiled
        implementation are all collected once unreferenced.

        *jitargs* are forwarded to ``jit``.
        """
        def f(x):
            return x + 1

        c_f = jit('int32(int32)', **jitargs)(f)
        self.assertPreciseEqual(c_f(1), 2)

        cfunc = self.get_impl(c_f)

        # Since we can't take a weakref to a C function object
        # (see http://bugs.python.org/issue22116), ensure it's
        # collected by taking a weakref to its __self__ instead
        # (a _dynfunc._Closure object).
        refs = [weakref.ref(obj) for obj in (f, c_f, cfunc.__self__)]
        # Drop every strong reference held by this frame before collecting.
        f = c_f = cfunc = None
        gc.collect()
        self.assertEqual([wr() for wr in refs], [None] * len(refs))

    def test_local_func_lifetime(self):
        self.check_local_func_lifetime(forceobj=True)

    def test_local_func_lifetime_npm(self):
        self.check_local_func_lifetime(nopython=True)

    def check_global_func_lifetime(self, **jitargs):
        """
        Check that the dispatcher and compiled implementation of a jitted
        module-level function are collected once unreferenced.

        Only the dispatcher and the implementation's ``__self__`` are
        tracked; ``global_usecase1`` itself stays referenced by the module.
        *jitargs* are forwarded to ``jit``.
        """
        c_f = jit(**jitargs)(global_usecase1)
        self.assertPreciseEqual(c_f(1), 2)

        cfunc = self.get_impl(c_f)

        refs = [weakref.ref(obj) for obj in (c_f, cfunc.__self__)]
        # Drop every strong reference held by this frame before collecting.
        c_f = cfunc = None
        gc.collect()
        self.assertEqual([wr() for wr in refs], [None] * len(refs))

    def test_global_func_lifetime(self):
        self.check_global_func_lifetime(forceobj=True)

    def test_global_func_lifetime_npm(self):
        self.check_global_func_lifetime(nopython=True)

    def check_global_obj_lifetime(self, **jitargs):
        """
        Check that a global object used by a jitted function is collected
        once it is removed from the globals and the dispatcher is dropped.

        *jitargs* are forwarded to ``jit``.
        """
        # Since global objects can be recorded for typing purposes,
        # check that they are not kept around after they are removed
        # from the globals.
        global global_obj
        global_obj = Dummy()
        try:
            c_f = jit(**jitargs)(global_usecase2)
            self.assertPreciseEqual(c_f(), 6)
            refs = [weakref.ref(obj) for obj in (c_f, global_obj)]
        finally:
            # Always reset the module global, even if compilation or the
            # assertion above failed, so a Dummy instance cannot leak into
            # other tests.
            c_f = global_obj = None
        gc.collect()
        self.assertEqual([wr() for wr in refs], [None] * len(refs))

    def test_global_obj_lifetime(self):
        self.check_global_obj_lifetime(forceobj=True)

    def check_inner_function_lifetime(self, **jitargs):
        """
        When a jitted function calls into another jitted function, check
        that everything is collected as desired.
        """
        def mult_10(a):
            return a * 10

        c_mult_10 = jit('intp(intp)', **jitargs)(mult_10)
        c_mult_10.disable_compile()

        def do_math(x):
            return c_mult_10(x + 4)

        c_do_math = jit('intp(intp)', **jitargs)(do_math)
        c_do_math.disable_compile()

        # (1 + 4) * 10
        self.assertEqual(c_do_math(1), 50)

        wrs = [weakref.ref(obj) for obj in
               (mult_10, c_mult_10, do_math, c_do_math,
                self.get_impl(c_mult_10).__self__,
                self.get_impl(c_do_math).__self__,
                )]
        # Drop every strong reference held by this frame before collecting.
        mult_10 = c_mult_10 = do_math = c_do_math = None
        gc.collect()
        self.assertEqual([w() for w in wrs], [None] * len(wrs))

    def test_inner_function_lifetime(self):
        self.check_inner_function_lifetime(forceobj=True)

    def test_inner_function_lifetime_npm(self):
        self.check_inner_function_lifetime(nopython=True)
class TestLifeTimeIssue(TestCase):
    """
    Regression test for a reference-counting problem in compiled code.
    """

    def test_double_free(self):
        """
        Call a jitted function containing an unused array indexing twice,
        interleaved with array allocations.

        There is no explicit assertion: per the comments below, the failure
        mode being guarded against is a crash on the second call.
        """
        from numba import njit
        import numpy as np

        # This is the function that causes the crash
        @njit
        def is_point_in_polygons(point, polygons):
            num_polygons = polygons.shape[0]
            if num_polygons != 0:
                # An extra decref is inserted in this block
                intentionally_unused_variable = polygons[0]
            return 0

        # This function creates some NRT objects for the previous function
        # to corrupt.
        @njit
        def dummy():
            return np.empty(10, dtype=np.int64)

        polygons = np.array([[[0, 1]]])
        points = np.array([[-1.5, 0.5]])
        # a, b and c are never read; they are kept bound so the arrays
        # returned by dummy() stay referenced across the calls below.
        a = dummy()
        is_point_in_polygons(points[0], polygons)
        b = dummy()
        # Crash happens at second call
        is_point_in_polygons(points[0], polygons)
        c = dummy()
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()