47 lines
2.2 KiB
Python
47 lines
2.2 KiB
Python
|
from typing import cast
|
||
|
|
||
|
from ..compat import cupy, tensorflow, torch
|
||
|
from ..types import ArrayXd
|
||
|
from ..util import get_torch_default_device, tensorflow2xp
|
||
|
|
||
|
|
||
|
def cupy_tensorflow_allocator(size_in_bytes: int):
    """Function that can be passed into cupy.cuda.set_allocator, to have cupy
    allocate memory via TensorFlow. This is important when using the two libraries
    together, as otherwise OOM errors can occur when there's available memory
    sitting in the other library's pool.

    size_in_bytes (int): Number of bytes cupy is requesting. Requests smaller
        than 1024 bytes are padded up to 1024.
    RETURNS (cupy.cuda.memory.MemoryPointer): Pointer to offset 0 of a buffer
        allocated by TensorFlow that is at least `size_in_bytes` bytes long.
    """
    # Very small requests are padded up to a 1024-byte minimum.
    size_in_bytes = max(1024, size_in_bytes)
    # The backing buffer is made of float32 (4-byte) elements. Round the
    # element count *up*: flooring would make the buffer smaller than the
    # byte count reported to cupy below whenever size_in_bytes is not a
    # multiple of 4.
    n_floats = (size_in_bytes + 3) // 4
    tensor = tensorflow.zeros((n_floats,), dtype=tensorflow.dtypes.float32)  # type: ignore
    # We convert to cupy via dlpack, so that we can get a memory pointer.
    cupy_array = cast(ArrayXd, tensorflow2xp(tensor))
    address = int(cupy_array.data)
    # UnownedMemory wraps the pointer without taking ownership, so cupy will
    # not try to free it. Passing cupy_array as the owner keeps that array
    # (and, via dlpack, presumably the TensorFlow buffer) alive while cupy
    # uses the memory.
    memory = cupy.cuda.memory.UnownedMemory(address, size_in_bytes, cupy_array)
    # Now return a new memory pointer.
    return cupy.cuda.memory.MemoryPointer(memory, 0)
|
||
|
|
||
|
|
||
|
def cupy_pytorch_allocator(size_in_bytes: int):
    """Function that can be passed into cupy.cuda.set_allocator, to have cupy
    allocate memory via PyTorch. This is important when using the two libraries
    together, as otherwise OOM errors can occur when there's available memory
    sitting in the other library's pool.

    size_in_bytes (int): Number of bytes cupy is requesting. Requests smaller
        than 1024 bytes are padded up to 1024.
    RETURNS (cupy.cuda.memory.MemoryPointer): Pointer to offset 0 of a buffer
        allocated by PyTorch on the default torch device that is at least
        `size_in_bytes` bytes long.
    """
    device = get_torch_default_device()
    # Cupy was having trouble with very small allocations?
    size_in_bytes = max(1024, size_in_bytes)
    # Back the allocation with a plain tensor. An earlier note found this
    # faster than allocating a FloatStorage directly. The dtype is pinned to
    # float32 so the 4-bytes-per-element arithmetic holds even if torch's
    # default dtype has been changed. The element count is rounded *up* so
    # the tensor is never smaller than the size_in_bytes reported to cupy.
    n_floats = (size_in_bytes + 3) // 4
    torch_tensor = torch.zeros(
        (n_floats,), dtype=torch.float32, requires_grad=False, device=device
    )
    # data_ptr() is missing from the available type stubs, hence the ignore.
    address = torch_tensor.data_ptr()  # type: ignore
    # UnownedMemory wraps the pointer without taking ownership, so cupy will
    # not try to free it. Passing torch_tensor as the owner keeps the tensor
    # alive while cupy uses the memory.
    memory = cupy.cuda.memory.UnownedMemory(address, size_in_bytes, torch_tensor)
    # Now return a new memory pointer.
    return cupy.cuda.memory.MemoryPointer(memory, 0)
|