45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
    r"""Clip the gradient norm and flag non-finite gradients.

    Gradients are clipped with ``torch.nn.utils.clip_grad_norm_`` and the norm
    it returns is inspected so the caller can skip the update after a failed
    backward pass.

    Args:
        model: Module whose parameters are clipped when ``amp_opt_params`` is
            empty or not given.
        grad_clip: Maximum norm handed to ``clip_grad_norm_``.
        ignore_stopnet: If True, parameters whose name contains ``"stopnet"``
            are left out of the clipping. Has no effect when
            ``amp_opt_params`` is given.
        amp_opt_params: Optional parameters to clip instead of the model's own
            (presumably mixed-precision master parameters; confirm with the
            callers).

    Returns:
        Tuple ``(grad_norm, skip_flag)``: the value returned by
        ``clip_grad_norm_`` and a bool that is True when that norm is
        infinite or NaN.
    """
    # Explicitly supplied parameters take precedence over the model's own.
    if amp_opt_params:
        params = amp_opt_params
    elif ignore_stopnet:
        params = [param for name, param in model.named_parameters() if "stopnet" not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)

    # compatibility with different torch versions: the norm may come back as a
    # plain number or as a tensor, float() normalises both
    norm_value = float(grad_norm)

    skip_flag = False
    if np.isinf(norm_value):
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif np.isnan(norm_value):
        # A NaN norm means the gradients are unusable as well.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
|
||
|
|
||
|
|
||
|
def gradual_training_scheduler(global_step, config):
    """Select the gradual-training stage that is active at ``global_step``.

    The step is multiplied by the number of visible CUDA devices (at least 1)
    before it is compared with each stage's start step.

    Args:
        global_step: Current training step.
        config: Object with a ``gradual_training`` attribute, an iterable of
            entries indexed as ``[start_step, value_1, value_2]``.

    Returns:
        Tuple ``(entry[1], entry[2])`` taken from the last entry in
        ``config.gradual_training`` whose start step has been reached
        (presumably reduction factor and batch size; confirm against the
        config schema).

    Raises:
        ValueError: If no entry has been reached yet, including the case of an
            empty schedule.
    """
    # A machine without CUDA reports 0 devices; count it as a single worker.
    num_gpus = max(torch.cuda.device_count(), 1)
    # we set the scheduling wrt num_gpus
    scaled_step = global_step * num_gpus

    active_stage = None
    for stage in config.gradual_training:
        # No early exit: the last matching entry wins, as entries are not
        # required to be sorted.
        if scaled_step >= stage[0]:
            active_stage = stage

    if active_stage is None:
        raise ValueError(
            f"No gradual training stage is active at step {global_step} "
            f"(scaled step {scaled_step}); the schedule needs an entry whose "
            "start step has already been reached, e.g. one starting at 0."
        )
    return active_stage[1], active_stage[2]
|