# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

import math
import warnings
from functools import partial
from typing import Callable, Iterable, Optional, Tuple, Union

import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
from .trainer_utils import SchedulerType
from .utils import logging
from .utils.versions import require_version


logger = logging.get_logger(__name__)


def _get_constant_lambda(_=None):
    """LR multiplier for a constant schedule: always 1, so the optimizer's own lr is used unchanged."""
    return 1


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)


def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
    """
    Create a schedule with a constant learning rate that decreases when a metric has stopped improving.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        kwargs (`dict`, *optional*):
            Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
            for possible parameters.

    Return:
        `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
    """
    return ReduceLROnPlateau(optimizer, **kwargs)


def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
    """Multiplier that ramps linearly from 0 to 1 over `num_warmup_steps`, then stays at 1."""
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1.0, num_warmup_steps))
    return 1.0


def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
    """
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
    """Multiplier with linear warmup from 0 to 1, then linear decay to 0 at `num_training_steps` (clamped at 0)."""
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0,
    after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    lr_lambda = partial(
        _get_linear_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
):
    """
    Multiplier for a cosine schedule with linear warmup.

    After warmup, the cosine factor `0.5 * (1 + cos(2 * pi * num_cycles * progress))`, which lies in [0, 1], is
    rescaled to [`min_lr_rate`, 1]. With the default `min_lr_rate=0.0` this is the plain cosine decay used by
    [`get_cosine_schedule_with_warmup`]. [`get_cosine_with_min_lr_schedule_with_warmup`] passes a non-zero floor.
    This is the single shared definition; it previously existed twice, with the second copy silently shadowing
    the first.
    """
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    # Linearly map [0, 1] -> [min_lr_rate, 1] so the lr never decays below the requested floor.
    factor = factor * (1 - min_lr_rate) + min_lr_rate
    return max(0.0, factor)


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    lr_lambda = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int
):
    """
    Multiplier with linear warmup, then `num_cycles` half-cosine decays, each restarting from 1.

    Returns 0 once `num_training_steps` has been reached.
    """
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    if progress >= 1.0:
        return 0.0
    # The modulo restarts the half-cosine from its peak at the start of every cycle.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))


def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    lr_lambda = partial(
        _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_polynomial_decay_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    lr_end: float,
    power: float,
    lr_init: float,
):
    """
    Multiplier with linear warmup, then polynomial decay from `lr_init` to `lr_end`.

    The returned value is a ratio relative to `lr_init`, because `LambdaLR` multiplies it by the initial lr.
    """
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    decay_steps = num_training_steps - num_warmup_steps
    # Past the end of training, or with an empty decay window (num_training_steps <= num_warmup_steps, which
    # previously divided by zero when current_step == num_training_steps), the decay is finished.
    if current_step > num_training_steps or decay_steps <= 0:
        return lr_end / lr_init  # as LambdaLR multiplies by lr_init
    lr_range = lr_init - lr_end
    pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
    decay = lr_range * pct_remaining**power + lr_end
    return decay / lr_init  # as LambdaLR multiplies by lr_init


def get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Note: the initial lr is read from `optimizer.defaults["lr"]`, not from the individual parameter groups.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: if `lr_end` is not strictly smaller than the optimizer's default lr.
    """
    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    lr_lambda = partial(
        _get_polynomial_decay_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        lr_end=lr_end,
        power=power,
        lr_init=lr_init,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: Optional[int] = None):
    """
    Multiplier with linear warmup, then decay proportional to `1 / sqrt((current_step + shift) / timescale)`.

    `timescale` must be an actual number here; the public factory [`get_inverse_sqrt_schedule`] resolves its
    default before binding it.
    """
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    # The shift makes the decay equal exactly 1.0 at the end of warmup, so the curve is continuous.
    shift = timescale - num_warmup_steps
    decay = 1.0 / math.sqrt((current_step + shift) / timescale)
    return decay


def get_inverse_sqrt_schedule(
    optimizer: Optimizer, num_warmup_steps: int, timescale: Optional[int] = None, last_epoch: int = -1
):
    """
    Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
    warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        timescale (`int`, *optional*, defaults to `num_warmup_steps`, or 10,000 when `num_warmup_steps` is 0):
            Time scale.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Note: this implementation is adapted from
    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930

    if timescale is None:
        timescale = num_warmup_steps or 10_000

    lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale)
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def get_cosine_with_min_lr_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    min_lr: Optional[float] = None,
    min_lr_rate: Optional[float] = None,
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and
    the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to
            `min_lr` following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
        min_lr (`float`, *optional*):
            The minimum learning rate to reach after the cosine schedule.
        min_lr_rate (`float`, *optional*):
            The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: if both or neither of `min_lr` / `min_lr_rate` are given.
    """
    if min_lr is not None and min_lr_rate is not None:
        raise ValueError("Only one of min_lr or min_lr_rate should be set")
    elif min_lr is not None:
        # Convert the absolute floor into a ratio of the optimizer's default lr, since LambdaLR works in ratios.
        min_lr_rate = min_lr / optimizer.defaults["lr"]
    elif min_lr_rate is None:
        raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")

    lr_lambda = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
        min_lr_rate=min_lr_rate,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
    SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
    SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
}


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    scheduler_specific_kwargs: Optional[dict] = None,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        scheduler_specific_kwargs (`dict`, *optional*):
            Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
            parameters will cause the scheduler function to raise a TypeError.

    Raises:
        ValueError: if the scheduler type requires `num_warmup_steps` / `num_training_steps` and it is unset.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
    # recursively call `get_scheduler` to get the proper schedulers on each parameter
    if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
        optimizer_dict = optimizer.optimizer_dict
        scheduler_dict = {}

        for param in optimizer_dict.keys():
            scheduler_dict[param] = get_scheduler(
                name,
                optimizer=optimizer_dict[param],
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                # Forward the extra kwargs too; without them, per-parameter schedulers that need them
                # (e.g. cosine_with_min_lr) fail or silently fall back to their defaults.
                scheduler_specific_kwargs=scheduler_specific_kwargs,
            )

        def scheduler_hook(param):
            # Since the optimizer hook has been already attached we only need to
            # attach the scheduler hook
            if param.grad is not None:
                scheduler_dict[param].step()

        for param in optimizer_dict.keys():
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(scheduler_hook)

        return LayerWiseDummyScheduler()

    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    if scheduler_specific_kwargs is None:
        scheduler_specific_kwargs = {}

    if name == SchedulerType.REDUCE_ON_PLATEAU:
        return schedule_func(optimizer, **scheduler_specific_kwargs)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    if name == SchedulerType.INVERSE_SQRT:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    return schedule_func(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        **scheduler_specific_kwargs,
    )


class AdamW(Optimizer):
    """
    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
    Regularization](https://arxiv.org/abs/1711.05101).

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.001):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
            Adam's betas parameters (b1, b2).
        eps (`float`, *optional*, defaults to 1e-06):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.0):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
            A flag used to disable the deprecation warning (set to `True` to disable the warning).
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        no_deprecation_warning: bool = False,
    ):
        if not no_deprecation_warning:
            warnings.warn(
                "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
                " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
                " warning",
                FutureWarning,
            )
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.

        Returns:
            The loss returned by `closure`, or `None` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:
                    # Adam bias correction; skipped when correct_bias=False (as in the original BERT TF code)
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

        return loss


class Adafactor(Optimizer):
    """
    AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
    this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
    `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
    `relative_step=False`.

    Arguments:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*):
            The external learning rate.
        eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
            Regularization constants for square gradient and parameter scale respectively
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of root mean square of final gradient update
        decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of square
        beta1 (`float`, *optional*):
            Coefficient used for computing running averages of gradient
        weight_decay (`float`, *optional*, defaults to 0.0):
            Weight decay (L2 penalty)
        scale_parameter (`bool`, *optional*, defaults to `True`):
            If True, learning rate is scaled by root mean square
        relative_step (`bool`, *optional*, defaults to `True`):
            If True, time-dependent learning rate is computed instead of external learning rate
        warmup_init (`bool`, *optional*, defaults to `False`):
            Time-dependent learning rate computation depends on whether warm-up initialization is being used

    This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.

    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):

        - Training without LR warmup or clip_threshold is not recommended.

           - use scheduled LR warm-up to fixed LR
           - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
        - Disable relative updates
        - Use scale_parameter=False
        - Additional optimizer operations like gradient clipping should not be used alongside Adafactor

    Example:

    ```python
    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
    ```

    Others reported the following combination to work well:

    ```python
    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    ```

    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
    scheduler as following:

    ```python
    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
    ```

    Usage:

    ```python
    # replace AdamW with Adafactor
    optimizer = Adafactor(
        model.parameters(),
        lr=1e-3,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    ```"""

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
    ):
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr is not None and relative_step:
            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
        if warmup_init and not relative_step:
            raise ValueError("`warmup_init=True` requires `relative_step=True`")

        defaults = {
            "lr": lr,
            "eps": eps,
            "clip_threshold": clip_threshold,
            "decay_rate": decay_rate,
            "beta1": beta1,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step": relative_step,
            "warmup_init": warmup_init,
        }
        super().__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        """
        Compute the effective step size for one parameter.

        With `relative_step`, the group's `lr` is replaced by `min(min_step, 1/sqrt(step))`. With
        `scale_parameter`, the result is multiplied by `max(eps[1], RMS)`. Reads `param_state["step"]` and/or
        `param_state["RMS"]` depending on those flags.
        """
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        """Return `(factored, use_first_moment)`: factor the second moment for tensors of rank >= 2; keep a first moment iff `beta1` is set."""
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        """Root mean square of all elements of `tensor`."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        """Reconstruct the inverse square root of the second moment from its row/column factors."""
        # Adapted from fairseq's Adafactor implementation; earlier version of this code:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss returned by `closure`, or `None` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                # Do the arithmetic in fp32 for half-precision gradients.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    # Keep existing state on the same device/dtype as the (possibly upcast) gradient.
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                p_data_fp32 = p
                if p.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Scale the update down so its RMS does not exceed clip_threshold.
                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

                p_data_fp32.add_(-update)

                # For half-precision params the update was applied to an fp32 copy; write it back.
                if p.dtype in {torch.float16, torch.bfloat16}:
                    p.copy_(p_data_fp32)

        return loss


class AdafactorSchedule(LambdaLR):
    """
    Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g.,
    for logging), this class creates a proxy object that retrieves the current lr values from the optimizer.

    It returns `initial_lr` during startup and the actual `lr` during stepping.
    """

    def __init__(self, optimizer, initial_lr=0.0):
        def lr_lambda(_):
            return initial_lr

        # LambdaLR reads "initial_lr" from each group to set base_lrs; it is removed again right after
        # so it does not leak into the optimizer's param groups.
        for group in optimizer.param_groups:
            group["initial_lr"] = initial_lr
        super().__init__(optimizer, lr_lambda)
        for group in optimizer.param_groups:
            del group["initial_lr"]

    def get_lr(self):
        """Return the lr Adafactor would use for each group whose first parameter has a gradient, else `base_lrs`."""
        opt = self.optimizer
        lrs = [
            opt._get_lr(group, opt.state[group["params"][0]])
            for group in opt.param_groups
            if group["params"][0].grad is not None
        ]
        if len(lrs) == 0:
            lrs = self.base_lrs  # if called before stepping
        return lrs


def get_adafactor_schedule(optimizer, initial_lr=0.0):
    """
    Get a proxy schedule for [`~optimization.Adafactor`]

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        initial_lr (`float`, *optional*, defaults to 0.0):
            Initial lr

    Return:
        [`~optimization.Adafactor`] proxy schedule object.
    """
    return AdafactorSchedule(optimizer, initial_lr)