150 lines
4.7 KiB
Python
150 lines
4.7 KiB
Python
import asyncio
|
|
import gc
|
|
import shutil
|
|
|
|
import pytest
|
|
|
|
from joblib.memory import (AsyncMemorizedFunc, AsyncNotMemorizedFunc,
|
|
MemorizedResult, Memory, NotMemorizedResult)
|
|
from joblib.test.common import np, with_numpy
|
|
from joblib.testing import raises
|
|
|
|
from .test_memory import (corrupt_single_cache_item,
|
|
monkeypatch_cached_func_warn)
|
|
|
|
|
|
async def check_identity_lazy_async(func, accumulator, location):
    """Check that a cached coroutine function behaves as a lazy identity.

    Async counterpart of ``check_identity_lazy`` for coroutine functions.

    ``func`` is wrapped with ``Memory(location=location, verbose=0).cache``
    and awaited twice for each argument ``i`` in ``range(3)``. Every call
    must return ``i``. ``accumulator`` is expected to be a list that
    ``func`` appends to whenever its body really runs. Its length must
    therefore be ``i + 1`` after both calls, which means the second call
    for each argument was served from the cache.
    """
    memory = Memory(location=location, verbose=0)
    func = memory.cache(func)
    for i in range(3):
        # The first call computes and stores the result. The second must be
        # a cache hit and so must not grow the accumulator.
        for _ in range(2):
            value = await func(i)
            assert value == i
            assert len(accumulator) == i + 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_memory_integration_async(tmpdir):
    """Integration test of ``Memory`` caching a coroutine function.

    Covers lazy evaluation, ``clear()`` after the cache directory has been
    removed, ``Memory.eval`` reusing a cached result, and caching a
    coroutine whose ``__module__`` is ``'__main__'``.
    """
    accumulator = list()

    async def f(n):
        await asyncio.sleep(0.1)
        accumulator.append(1)
        return n

    await check_identity_lazy_async(f, accumulator, tmpdir.strpath)

    # Now test clearing, for every (compress, mmap_mode) combination.
    for compress in (False, True):
        for mmap_mode in ('r', None):
            memory = Memory(location=tmpdir.strpath, verbose=10,
                            mmap_mode=mmap_mode, compress=compress)
            # First clear the cache directory, to check that our code can
            # handle that.
            # NOTE: this line would raise an exception, as the database
            # file is still open; we ignore the error since we want to
            # test what happens if the directory disappears.
            shutil.rmtree(tmpdir.strpath, ignore_errors=True)
            g = memory.cache(f)
            await g(1)
            g.clear(warn=False)
            current_accumulator = len(accumulator)
            out = await g(1)

            # These checks live inside the inner loop so that every
            # (compress, mmap_mode) combination is verified, not only the
            # last mmap_mode value of each compress setting.
            # clear() dropped the entry, so f's body ran once more.
            assert len(accumulator) == current_accumulator + 1
            # Also, check that Memory.eval works similarly: it must reuse
            # the entry stored just above instead of running f again.
            evaled = await memory.eval(f, 1)
            assert evaled == out
            assert len(accumulator) == current_accumulator + 1

    # Now do a smoke test with a function defined in __main__, as the name
    # mangling rules are more complex
    f.__module__ = '__main__'
    memory = Memory(location=tmpdir.strpath, verbose=0)
    await memory.cache(f)(1)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_no_memory_async():
    """With ``location=None`` nothing is memoized: every await recomputes."""
    calls = []

    async def ff(x):
        await asyncio.sleep(0.1)
        calls.append(1)
        return x

    uncached = Memory(location=None, verbose=0).cache(ff)
    for _ in range(4):
        seen_before = len(calls)
        await uncached(1)
        # No cache available, so the body must have executed again.
        assert len(calls) == seen_before + 1
|
|
|
|
|
|
@with_numpy
@pytest.mark.asyncio
async def test_memory_numpy_check_mmap_mode_async(tmpdir, monkeypatch):
    """mmap_mode='r' must apply from the very first call onwards.

    Also checks that a corrupted cache entry is recomputed (with one
    warning) and that the recomputed value is still a read-only memmap.
    """
    memory = Memory(location=tmpdir.strpath, mmap_mode='r', verbose=0)

    @memory.cache()
    async def twice(a):
        return a * 2

    def assert_readonly_memmap(result):
        # Helper keeps no reference to `result` once it returns.
        assert isinstance(result, np.memmap)
        assert result.mode == 'r'

    source = np.ones(3)
    first_call = await twice(source)
    second_call = await twice(source)

    assert_readonly_memmap(second_call)
    assert_readonly_memmap(first_call)

    # The memmaps hold the cached file open: drop both of them and collect
    # before corrupting that file on disk.
    del first_call, second_call
    gc.collect()
    corrupt_single_cache_item(memory)

    # A corrupted entry must be recomputed and reported by exactly one
    # warning.
    warnings_seen = monkeypatch_cached_func_warn(twice, monkeypatch)
    recomputed = await twice(source)
    assert len(warnings_seen) == 1
    assert 'Exception while loading results' in warnings_seen[0]
    # Even the freshly recomputed result comes back memory-mapped.
    assert_readonly_memmap(recomputed)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_and_shelve_async(tmpdir):
    """``call_and_shelve`` hands back a result reference of the right type.

    Memoizing wrappers must return ``MemorizedResult`` and non-memoizing
    ones ``NotMemorizedResult``. Each reference yields the value, and after
    ``clear()`` its ``get()`` raises ``KeyError``.
    """
    async def f(x, y=1):
        await asyncio.sleep(0.1)
        return x ** 2 + y

    cache_dir = tmpdir.strpath
    # Each async wrapper flavour paired with the reference type it returns.
    cases = [
        (AsyncMemorizedFunc(f, cache_dir), MemorizedResult),
        (AsyncNotMemorizedFunc(f), NotMemorizedResult),
        (Memory(location=cache_dir, verbose=0).cache(f), MemorizedResult),
        (Memory(location=None).cache(f), NotMemorizedResult),
    ]
    for wrapper, expected_type in cases:
        for _ in range(2):
            shelved = await wrapper.call_and_shelve(2)
            assert isinstance(shelved, expected_type)
            assert shelved.get() == 5

            shelved.clear()
            with raises(KeyError):
                shelved.get()
            # A second clear on an emptied reference must be harmless.
            shelved.clear()
|