ai-content-maker/.venv/Lib/site-packages/trainer/utils/cuda_memory.py

"""
credit: https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
Helper to free Torch cuda memory and determine when a Torch exception might be
because of OOM conditions.
"""
import gc

import torch

from trainer.utils.cpu_memory import is_out_of_cpu_memory


def gc_cuda():
    """Garbage collect Torch (CUDA) memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_cuda_total_memory():
    """Return the total memory of CUDA device 0 in bytes (0 if CUDA is unavailable)."""
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(0).total_memory
    return 0


def get_cuda_assumed_available_memory():
    """Return the total memory minus what this process has reserved, in bytes."""
    if torch.cuda.is_available():
        return get_cuda_total_memory() - torch.cuda.memory_reserved()
    return 0


def get_cuda_available_memory():
    """Estimate the memory this process can actually allocate, in bytes.

    Subtracts memory blocked by other processes or by fragmentation from the
    assumed available memory.
    """
    if torch.cuda.is_available():
        return get_cuda_assumed_available_memory() - get_cuda_blocked_memory()
    return 0


def get_cuda_blocked_memory():
    """Estimate how much of the assumed available memory cannot actually be allocated."""
    if not torch.cuda.is_available():
        return 0
    available_memory = get_cuda_assumed_available_memory()
    current_block = available_memory - 2**28  # Start probing 256 MB below the assumed available memory.
    while True:
        try:
            block = torch.empty((current_block,), dtype=torch.uint8, device="cuda")
            break
        except RuntimeError as exception:
            if is_cuda_out_of_memory(exception):
                current_block -= 2**30  # Back off in 1 GB steps.
                if current_block <= 0:
                    return available_memory
            else:
                raise
    block = None  # Drop the probe tensor so gc_cuda() can release it.
    gc_cuda()
    return available_memory - current_block
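

# Worked example for get_cuda_blocked_memory (hypothetical numbers): if the
# assumed available memory is 10 GiB but another process holds 3 GiB of it,
# the first probe requests 10 GiB - 256 MiB and fails with OOM; each failure
# lowers the request by 1 GiB until an allocation of roughly 7 GiB succeeds,
# so the function returns approximately the 3 GiB that are blocked.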


def is_cuda_out_of_memory(exception):
    """Return True if `exception` is a CUDA out-of-memory error."""
    return (
        isinstance(exception, (RuntimeError, torch.cuda.OutOfMemoryError))
        and len(exception.args) == 1
        and "CUDA out of memory." in exception.args[0]
    )


def is_cudnn_snafu(exception):
    """Return True if `exception` is the misleading cuDNN error that can mask an OOM.

    See https://github.com/pytorch/pytorch/issues/4107.
    """
    return (
        isinstance(exception, RuntimeError)
        and len(exception.args) == 1
        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
    )


def cuda_meminfo():
    """Print currently allocated/reserved CUDA memory and the peaks so far."""
    if not torch.cuda.is_available():
        return
    print(
        "Allocated:", torch.cuda.memory_allocated() / 2**30, " GB Reserved: ", torch.cuda.memory_reserved() / 2**30, "GB"
    )
    print(
        "Max Allocated:",
        torch.cuda.max_memory_allocated() / 2**30,
        " GB Max Reserved: ",
        torch.cuda.max_memory_reserved() / 2**30,
        "GB",
    )


def should_reduce_batch_size(exception):
    """Return True if `exception` indicates that a smaller batch might succeed."""
    return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception)
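

# Usage sketch (illustrative, not part of the upstream module): these helpers
# are typically combined into a retry loop that shrinks the batch size when a
# step dies with an OOM-like exception, similar to what toma automates.
# `run_step` is a hypothetical callable standing in for one training step.
def _example_train_with_backoff(run_step, batch_size):
    """Illustrative only: halve `batch_size` until `run_step` fits in memory."""
    while batch_size > 0:
        try:
            return run_step(batch_size)
        except RuntimeError as exception:
            if not should_reduce_batch_size(exception):
                raise
            gc_cuda()  # Release cached CUDA blocks before retrying.
            batch_size //= 2
    raise RuntimeError("No batch size fits in the available memory.")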