""" credit: https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py Helper to free Torch cuda memory and determine when a Torch exception might be because of OOM conditions. """ from __future__ import print_function import gc import torch from trainer.utils.cpu_memory import is_out_of_cpu_memory def gc_cuda(): """Gargage collect Torch (CUDA) memory.""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def get_cuda_total_memory(): if torch.cuda.is_available(): return torch.cuda.get_device_properties(0).total_memory return 0 def get_cuda_assumed_available_memory(): if torch.cuda.is_available(): return get_cuda_total_memory() - torch.cuda.memory_reserved() return 0 def get_cuda_available_memory(): # Always allow for 1 GB overhead. if torch.cuda.is_available(): return get_cuda_assumed_available_memory() - get_cuda_blocked_memory() return 0 def get_cuda_blocked_memory(): if not torch.cuda.is_available(): return 0 available_memory = get_cuda_assumed_available_memory() current_block = available_memory - 2**28 # 256 MB steps while True: try: _ = torch.empty((current_block,), dtype=torch.uint8, device="cuda") break except RuntimeError as exception: if is_cuda_out_of_memory(exception): current_block -= 2**30 if current_block <= 0: return available_memory else: raise _ = None gc_cuda() return available_memory - current_block def is_cuda_out_of_memory(exception): return ( isinstance(exception, (RuntimeError, torch.cuda.OutOfMemoryError)) and len(exception.args) == 1 and "CUDA out of memory." in exception.args[0] ) def is_cudnn_snafu(exception): # For/because of https://github.com/pytorch/pytorch/issues/4107 return ( isinstance(exception, RuntimeError) and len(exception.args) == 1 and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0] ) def cuda_meminfo(): if not torch.cuda.is_available(): return print( "Total:", torch.cuda.memory_allocated() / 2**30, " GB Cached: ", torch.cuda.memory_reserved() / 2**30, "GB" ) print( "Max Total:", torch.cuda.max_memory_allocated() / 2**30, " GB Max Cached: ", torch.cuda.max_memory_reserved() / 2**30, "GB", ) def should_reduce_batch_size(exception): return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception)