# ai-content-maker/.venv/Lib/site-packages/trainer/torch.py

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler


class DistributedSamplerWrapper(DistributedSampler):
    """Wrapper over Sampler for distributed training. It allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with `torch.nn.parallel.DistributedDataParallel`. In such a case,
    each process can pass a `DistributedSamplerWrapper` instance as a `torch.utils.data.DataLoader` sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        sampler: Sampler used for subsampling.
        num_replicas (int, optional): Number of processes participating in distributed training. By default,
            world_size is retrieved from the current distributed group.
        rank (int, optional): Rank of the current process within num_replicas. By default, rank is retrieved
            from the current distributed group.
        shuffle (bool, optional): If True, sampler will shuffle the indices. Default: True.
        seed (int, optional): Random seed used to shuffle the sampler if shuffle=True. This number should be
            identical across all processes in the distributed group. Default: 0.

    Reference: https://github.com/pytorch/pytorch/issues/23430
    """

    def __init__(
        self,
        sampler,
        num_replicas: int = None,
        rank: int = None,
        shuffle: bool = True,
        seed: int = 0,
    ):
        # The wrapped sampler is passed to the parent class in place of the
        # dataset, so it is the sampler's indices that get sharded across
        # the distributed replicas.
        super().__init__(
            sampler,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
        )

    def __iter__(self):
        # `self.dataset` is the wrapped sampler; iterating it yields the
        # sampler's indices into the underlying dataset.
        indices = list(self.dataset)[: self.total_size]

        # Add extra samples to make the index list evenly divisible across replicas
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size, f"{len(indices)} != {self.total_size}"

        # Subsample a contiguous, rank-exclusive block of indices
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples, f"{len(indices)} != {self.num_samples}"

        return iter(indices)

    def set_epoch(self, epoch):
        # Keep the parent's shuffling and the wrapped sampler's own state
        # in sync with the current epoch.
        super().set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
        elif hasattr(self.dataset, "generator"):
            self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch)

    def state_dict(self):
        return self.dataset.state_dict()

    def load_state_dict(self, state_dict):
        self.dataset.load_state_dict(state_dict)
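

# A minimal usage sketch (not part of this module), assuming a distributed
# process group is already initialized and that `train_dataset` and
# per-sample `weights` exist; it wraps a WeightedRandomSampler so each rank
# draws an exclusive shard of the weighted indices:
#
#     from torch.utils.data import DataLoader, WeightedRandomSampler
#
#     sampler = WeightedRandomSampler(weights, num_samples=len(train_dataset))
#     dist_sampler = DistributedSamplerWrapper(sampler)
#     loader = DataLoader(train_dataset, batch_size=32, sampler=dist_sampler)
#     for epoch in range(num_epochs):
#         dist_sampler.set_epoch(epoch)  # reseed the shuffle each epoch
#         for batch in loader:
#             ...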


# pylint: disable=protected-access
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """Noam learning rate schedule from "Attention Is All You Need": the lr
    increases linearly for `warmup_steps` steps, then decays proportionally
    to the inverse square root of the step number."""

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
        self.warmup_steps = float(warmup_steps)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        return [
            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]
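

# A minimal usage sketch (hypothetical numbers): with this parameterization
# the lr climbs linearly from ~0 to base_lr at step == warmup_steps, then
# decays as base_lr * (warmup_steps / step) ** 0.5.
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     scheduler = NoamLR(optimizer, warmup_steps=4000)
#     for step in range(total_steps):
#         ...  # forward/backward
#         optimizer.step()
#         scheduler.step()  # call once per optimizer step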


# pylint: disable=protected-access
class StepwiseGradualLR(torch.optim.lr_scheduler._LRScheduler):
    """Hardcoded step-wise learning rate scheduling.

    Necessary for CapacitronVAE.
    """

    def __init__(self, optimizer, gradual_learning_rates, last_epoch=-1):
        # `gradual_learning_rates` is a sequence of (step_threshold, lr) pairs.
        self.gradual_learning_rates = gradual_learning_rates
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        step_thresholds = [threshold for threshold, _ in self.gradual_learning_rates]
        rates = [rate for _, rate in self.gradual_learning_rates]

        # Pick the lr of the last threshold the current step has reached;
        # fall back to the first lr when no threshold has been reached yet.
        reached = np.where(np.less_equal(step_thresholds, step))[0]
        lr = rates[reached[-1]] if reached.size > 0 else rates[0]

        # Return the last lr once the step is past the final threshold
        lr = rates[-1] if step > step_thresholds[-1] else lr
        # Return the first lr while the step is below the second threshold - first is initial lr
        lr = rates[0] if step < step_thresholds[1] else lr

        return np.tile(lr, len(self.base_lrs))  # one lr per parameter group