118 lines
4.3 KiB
Python
118 lines
4.3 KiB
Python
import numpy as np
|
|
import torch
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
|
|
|
|
class DistributedSamplerWrapper(DistributedSampler):
    """Wrapper over Sampler for distributed training. It allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a torch.utils.data.DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    On every iteration the indices produced by the wrapped sampler are truncated to ``total_size`` and, if the
    sampler produced fewer, padded by cycling through them from the start. Each rank then takes one contiguous
    chunk of ``num_samples`` indices, so all replicas receive the same number of samples.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        sampler: Sampler used for subsampling. Its ``len()`` determines ``num_samples`` and ``total_size``.
        num_replicas (int, optional): Number of processes participating in distributed training. By default,
            world_size is retrieved from the current distributed group.
        rank (int, optional): Rank of the current process within num_replicas. By default, rank is retrieved
            from the current distributed group.
        shuffle (bool, optional): Forwarded to ``DistributedSampler``. ``__iter__`` here does not shuffle;
            the index order is whatever the wrapped sampler yields. Default: True.
        seed (int, optional): Base seed. If the wrapped sampler has no ``set_epoch`` method but has a
            ``generator`` attribute, :meth:`set_epoch` reseeds that generator with ``seed + epoch``. This number
            should be identical across all processes in the distributed group. Default: 0.

    Reference: https://github.com/pytorch/pytorch/issues/23430
    """

    def __init__(
        self,
        sampler,
        num_replicas: int = None,
        rank: int = None,
        shuffle: bool = True,
        seed: int = 0,
    ):
        # The wrapped sampler takes the place of the dataset in DistributedSampler,
        # so ``self.dataset`` refers to the sampler from here on.
        super().__init__(
            sampler,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
        )

    def __iter__(self):
        """Yield this rank's contiguous share of the wrapped sampler's indices."""
        indices = list(self.dataset)[: self.total_size]

        # Add extra samples to make it evenly divisible. Repeat the list before
        # slicing so the padding also works when more padding is needed than
        # there are indices (e.g. a tiny sampler spread over many replicas).
        padding_size = self.total_size - len(indices)
        if padding_size > 0 and indices:
            indices += (indices * (padding_size // len(indices) + 1))[:padding_size]
        assert len(indices) == self.total_size, f"{len(indices)} != {self.total_size}"

        # Subsample: each rank takes one contiguous chunk of ``num_samples`` indices.
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples, f"{len(indices)} != {self.num_samples}"

        return iter(indices)

    def set_epoch(self, epoch):
        """Record ``epoch`` and propagate it to the wrapped sampler.

        If the wrapped sampler defines ``set_epoch``, it is called. Otherwise, if it has a
        ``generator`` attribute, that attribute is replaced with a new ``torch.Generator``
        seeded with ``self.seed + epoch``.
        """
        super().set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
        elif hasattr(self.dataset, "generator"):
            self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch)

    def state_dict(self):
        """Return the wrapped sampler's ``state_dict()`` (the sampler must provide one)."""
        return self.dataset.state_dict()

    def load_state_dict(self, state_dict):
        """Forward ``state_dict`` to the wrapped sampler's ``load_state_dict``."""
        self.dataset.load_state_dict(state_dict)
|
|
|
|
|
|
# pylint: disable=protected-access
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """Warm-up / inverse-square-root ("Noam") learning-rate schedule.

    With ``s = max(last_epoch, 1)`` and ``w = warmup_steps``, every parameter group gets
    ``base_lr * w**0.5 * min(s * w**-1.5, s**-0.5)``. That is ``base_lr * s / w`` while
    ``s <= w`` and ``base_lr * (w / s) ** 0.5`` afterwards, so the peak (``base_lr``) is
    reached at ``s == w``.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Length of the warm-up phase; stored as a float. Must be non-zero,
            since it is raised to a negative power.
        last_epoch: Index of the last step taken; ``-1`` starts a fresh schedule.
    """

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
        # Set before the base constructor, which takes an initial step that calls get_lr().
        self.warmup_steps = float(warmup_steps)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        current = max(self.last_epoch, 1)
        warmup = self.warmup_steps

        def _scale(initial_lr):
            # Same association as the reference formula: (lr * sqrt(w)) * min(ramp, decay).
            return initial_lr * warmup**0.5 * min(current * warmup**-1.5, current**-0.5)

        return list(map(_scale, self.base_lrs))
|
|
|
|
|
|
# pylint: disable=protected-access
class StepwiseGradualLR(torch.optim.lr_scheduler._LRScheduler):
    """Hardcoded step-wise learning rate scheduling.

    Necessary for CapacitronVAE.

    ``gradual_learning_rates`` is a sequence of ``(step_threshold, lr)`` pairs; the logic assumes the
    thresholds are in ascending order. For ``s = max(last_epoch, 1)`` the rate of the last pair whose
    threshold is ``<= s`` is used. While ``s`` is below the second threshold (or below every threshold)
    the first rate is used, and past the last threshold the last rate is used. Every parameter group
    receives the same rate; the optimizer's own ``base_lrs`` only determine how many values are returned.

    Args:
        optimizer: Wrapped optimizer.
        gradual_learning_rates: Non-empty sequence of ``(step_threshold, lr)`` pairs.
        last_epoch: Index of the last step taken; ``-1`` starts a fresh schedule.

    Raises:
        ValueError: If ``gradual_learning_rates`` is empty.
    """

    def __init__(self, optimizer, gradual_learning_rates, last_epoch=-1):
        if len(gradual_learning_rates) == 0:
            raise ValueError("gradual_learning_rates must contain at least one (step, lr) pair")
        # Set before the base constructor, which takes an initial step that calls get_lr().
        self.gradual_learning_rates = gradual_learning_rates
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        step_thresholds = [values[0] for values in self.gradual_learning_rates]
        rates = [values[1] for values in self.gradual_learning_rates]

        # Rate of the last threshold already reached. If the step is still below
        # every threshold, fall back to the first (initial) rate instead of failing.
        reached = np.flatnonzero(np.less_equal(step_thresholds, step))
        lr = rates[int(reached[-1])] if reached.size else rates[0]

        # Return last lr if step is above the last threshold
        if step > step_thresholds[-1]:
            lr = rates[-1]
        # Return first lr if step is below the second threshold - first is initial lr
        if len(step_thresholds) > 1 and step < step_thresholds[1]:
            lr = rates[0]

        # One identical rate per parameter group, as plain floats.
        return [float(lr)] * len(self.base_lrs)
|