"""
|
|
credit: https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
|
|
|
|
Helper to free Torch cuda memory and determine when a Torch exception might be
|
|
because of OOM conditions.
|
|
"""
|
|
from __future__ import print_function
|
|
|
|
import gc
|
|
|
|
import torch
|
|
|
|
from trainer.utils.cpu_memory import is_out_of_cpu_memory
|
|
|
|
|


def gc_cuda():
    """Garbage collect Torch (CUDA) memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_cuda_total_memory():
    """Return the total memory of CUDA device 0 in bytes (0 if CUDA is unavailable)."""
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(0).total_memory
    return 0


def get_cuda_assumed_available_memory():
    """Return total memory minus what the PyTorch caching allocator has reserved."""
    if torch.cuda.is_available():
        return get_cuda_total_memory() - torch.cuda.memory_reserved()
    return 0


def get_cuda_available_memory():
    # The blocked-memory probe below backs off in 1 GB steps, so this
    # effectively keeps up to ~1 GB of headroom.
    if torch.cuda.is_available():
        return get_cuda_assumed_available_memory() - get_cuda_blocked_memory()
    return 0


def get_cuda_blocked_memory():
    """Estimate memory that is unavailable although not reserved by PyTorch
    (e.g. held by other processes or lost to fragmentation), by probing with
    progressively smaller allocations until one succeeds."""
    if not torch.cuda.is_available():
        return 0

    available_memory = get_cuda_assumed_available_memory()
    current_block = available_memory - 2**28  # start 256 MB below the assumed amount
    while True:
        try:
            _ = torch.empty((current_block,), dtype=torch.uint8, device="cuda")
            break
        except RuntimeError as exception:
            if is_cuda_out_of_memory(exception):
                current_block -= 2**30  # back off in 1 GB steps
                if current_block <= 0:
                    return available_memory
            else:
                raise
    _ = None  # drop the probe tensor before emptying the cache
    gc_cuda()
    return available_memory - current_block
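

# Added usage sketch (not in the original toma module): one way to turn the
# byte counts above into a batch-size bound. ``bytes_per_sample`` is a
# hypothetical, caller-estimated memory cost of a single sample.
def _example_pick_batch_size(bytes_per_sample, max_batch_size=128):
    available = get_cuda_available_memory()
    # Largest batch that fits the probed free memory, clamped to [1, max_batch_size].
    return max(1, min(max_batch_size, available // max(1, bytes_per_sample)))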


def is_cuda_out_of_memory(exception):
    return (
        isinstance(exception, (RuntimeError, torch.cuda.OutOfMemoryError))
        and len(exception.args) == 1
        and "CUDA out of memory." in exception.args[0]
    )


def is_cudnn_snafu(exception):
    # See https://github.com/pytorch/pytorch/issues/4107
    return (
        isinstance(exception, RuntimeError)
        and len(exception.args) == 1
        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
    )


def cuda_meminfo():
    if not torch.cuda.is_available():
        return

    print(
        "Allocated:", torch.cuda.memory_allocated() / 2**30, "GB", "Cached:", torch.cuda.memory_reserved() / 2**30, "GB"
    )
    print(
        "Max Allocated:",
        torch.cuda.max_memory_allocated() / 2**30,
        "GB",
        "Max Cached:",
        torch.cuda.max_memory_reserved() / 2**30,
        "GB",
    )
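

# Added note (not in the original module): the "Max" figures are running
# peaks. They can be reset with torch.cuda.reset_peak_memory_stats() to
# profile a specific region, e.g.:
#
#     torch.cuda.reset_peak_memory_stats()
#     run_epoch()      # hypothetical workload
#     cuda_meminfo()   # peaks now reflect only that epoch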


def should_reduce_batch_size(exception):
    return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception)
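

# Added usage sketch (not part of the original module): a retry loop, in the
# spirit of toma, that halves the batch size whenever a step fails with an
# OOM-like error. ``run_step`` is a hypothetical callable that runs one
# training step for a given batch size.
def _example_retry_with_smaller_batch(run_step, batch_size):
    while batch_size >= 1:
        try:
            return run_step(batch_size)
        except RuntimeError as exception:
            if should_reduce_batch_size(exception):
                gc_cuda()  # free cached blocks before retrying
                batch_size //= 2
            else:
                raise
    raise RuntimeError("no batch size fits in the available memory")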