# Source: ai-content-maker/.venv/Lib/site-packages/thinc/optimizers.py
# (356 lines, 12 KiB, Python; snapshot dated 2024-05-03 04:18:51 +03:00)
import math
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union, cast
from .backends import get_array_ops
from .config import registry
from .types import FloatsXd, Generator
# Key identifying a single parameter: an (int, str) pair. The docstring of
# Optimizer.__call__ describes it as "usually the node ID and parameter name".
KeyT = Tuple[int, str]
# A float hyper-parameter given either as a constant, as a list of values, or
# as a generator. Optimizer._set_attr_or_schedule treats anything that is not
# a plain float/bool/int as a schedule and pulls values from it with next().
FloatOrSeq = Union[float, List[float], Generator]
# Integer counterpart of FloatOrSeq (not referenced in the visible part of
# this module).
IntOrSeq = Union[int, List[int], Generator]
# Default hyper-parameters used by the SGD factory.
SGD_DEFAULTS: Dict[str, Union[float, bool, int]] = {
    "L2": 0.0,
    "L2_is_weight_decay": True,
    "grad_clip": 1.0,
}

# Default hyper-parameters used by the Adam / RAdam factories and by
# Optimizer.__init__. The regularisation and clipping entries are taken
# straight from SGD_DEFAULTS so that both tables stay in agreement.
ADAM_DEFAULTS: Dict[str, Union[float, bool, int]] = {
    "learn_rate": 1e-3,
    "beta1": 0.9,
    "beta2": 0.999,
    "eps": 1e-8,
    **{
        shared: SGD_DEFAULTS[shared]
        for shared in ("L2", "grad_clip", "L2_is_weight_decay")
    },
}
@registry.optimizers("RAdam.v1")
def RAdam(
    learn_rate: FloatOrSeq = ADAM_DEFAULTS["learn_rate"],
    *,
    beta1: FloatOrSeq = ADAM_DEFAULTS["beta1"],
    beta2: FloatOrSeq = ADAM_DEFAULTS["beta2"],
    eps: FloatOrSeq = ADAM_DEFAULTS["eps"],
    L2: FloatOrSeq = ADAM_DEFAULTS["L2"],
    L2_is_weight_decay: bool = cast(bool, ADAM_DEFAULTS["L2_is_weight_decay"]),
    grad_clip: FloatOrSeq = ADAM_DEFAULTS["grad_clip"],
    use_averages: bool = True,
):
    """Create an Optimizer with `use_radam=True`.

    Registered in `registry.optimizers` under the name "RAdam.v1". Every
    argument is forwarded unchanged to `Optimizer`; unspecified ones fall
    back to the values in ADAM_DEFAULTS.

    learn_rate (FloatOrSeq): The learning rate, or a schedule for it.
    beta1 (FloatOrSeq): Forwarded as `beta1`.
    beta2 (FloatOrSeq): Forwarded as `beta2`.
    eps (FloatOrSeq): Forwarded as `eps`.
    L2 (FloatOrSeq): Forwarded as `L2`.
    L2_is_weight_decay (bool): Forwarded as `L2_is_weight_decay`.
    grad_clip (FloatOrSeq): Forwarded as `grad_clip`.
    use_averages (bool): Forwarded as `use_averages`.
    RETURNS (Optimizer): The configured optimizer.
    """
    settings = {
        "beta1": beta1,
        "beta2": beta2,
        "eps": eps,
        "grad_clip": grad_clip,
        "L2_is_weight_decay": L2_is_weight_decay,
        "L2": L2,
        "use_averages": use_averages,
    }
    return Optimizer(learn_rate, use_radam=True, **settings)
@registry.optimizers("Adam.v1")
def Adam(
    learn_rate: FloatOrSeq = ADAM_DEFAULTS["learn_rate"],
    *,
    L2: FloatOrSeq = ADAM_DEFAULTS["L2"],
    beta1: FloatOrSeq = ADAM_DEFAULTS["beta1"],
    beta2: FloatOrSeq = ADAM_DEFAULTS["beta2"],
    eps: FloatOrSeq = ADAM_DEFAULTS["eps"],
    grad_clip: FloatOrSeq = ADAM_DEFAULTS["grad_clip"],
    L2_is_weight_decay: bool = cast(bool, ADAM_DEFAULTS["L2_is_weight_decay"]),
    use_averages: bool = True,
):
    """Create an Optimizer with `use_radam=False` and Adam hyper-parameters.

    Registered in `registry.optimizers` under the name "Adam.v1". Every
    argument is forwarded unchanged to `Optimizer`; unspecified ones fall
    back to the values in ADAM_DEFAULTS.

    learn_rate (FloatOrSeq): The learning rate, or a schedule for it.
    L2 (FloatOrSeq): Forwarded as `L2`.
    beta1 (FloatOrSeq): Forwarded as `beta1`.
    beta2 (FloatOrSeq): Forwarded as `beta2`.
    eps (FloatOrSeq): Forwarded as `eps`.
    grad_clip (FloatOrSeq): Forwarded as `grad_clip`.
    L2_is_weight_decay (bool): Forwarded as `L2_is_weight_decay`.
    use_averages (bool): Forwarded as `use_averages`.
    RETURNS (Optimizer): The configured optimizer.
    """
    settings = {
        "L2": L2,
        "beta1": beta1,
        "beta2": beta2,
        "eps": eps,
        "grad_clip": grad_clip,
        "L2_is_weight_decay": L2_is_weight_decay,
        "use_averages": use_averages,
    }
    return Optimizer(learn_rate, use_radam=False, **settings)
@registry.optimizers("SGD.v1")
def SGD(
    learn_rate: FloatOrSeq,
    *,
    L2: FloatOrSeq = SGD_DEFAULTS["L2"],
    grad_clip: FloatOrSeq = SGD_DEFAULTS["grad_clip"],
    L2_is_weight_decay: bool = cast(bool, SGD_DEFAULTS["L2_is_weight_decay"]),
    use_averages: bool = True,
):
    """Create an Optimizer with both momentum terms switched off.

    Registered in `registry.optimizers` under the name "SGD.v1". `beta1` and
    `beta2` are fixed at 0.0 and `use_radam` is left at its default, so
    `Optimizer.__call__` takes its plain `weights -= lr * gradient` branch.

    learn_rate (FloatOrSeq): The learning rate, or a schedule for it.
        Required; there is no default.
    L2 (FloatOrSeq): Forwarded as `L2`.
    grad_clip (FloatOrSeq): Forwarded as `grad_clip`.
    L2_is_weight_decay (bool): Forwarded as `L2_is_weight_decay`.
    use_averages (bool): Forwarded as `use_averages`.
    RETURNS (Optimizer): The configured optimizer.
    """
    no_momentum = 0.0
    optimizer = Optimizer(
        learn_rate,
        beta1=no_momentum,
        beta2=no_momentum,
        L2=L2,
        grad_clip=grad_clip,
        L2_is_weight_decay=L2_is_weight_decay,
        use_averages=use_averages,
    )
    return optimizer
class Optimizer(object):
    """Do various flavours of stochastic gradient descent, with first and
    second order momentum. Currently support 'vanilla' SGD, Adam, and RAdam.

    Each hyper-parameter may be a constant or a schedule (a list or a
    generator). The current value is always stored as a plain attribute;
    `step_schedules` advances every schedule by one value.
    """

    # Per-parameter first / second moment buffers (filled by _adam / _radam).
    mom1: Dict[KeyT, FloatsXd]
    mom2: Dict[KeyT, FloatsXd]
    # Per-parameter running averages, or None when use_averages is False.
    averages: Optional[Dict[KeyT, FloatsXd]]
    # Hyper-parameter name -> iterator producing its upcoming values.
    schedules: Dict[str, Generator]
    # Number of times __call__ has processed each key.
    nr_update: Dict[KeyT, int]
    # Initialised to an empty defaultdict(int); not read anywhere in this class.
    last_seen: Dict[KeyT, int]
    grad_clip: float
    learn_rate: float
    b1: float
    b2: float
    eps: float
    L2: float
    use_radam: bool
    L2_is_weight_decay: bool
    # Ten slots of [step, N_sma, step_size], indexed by step % 10, caching the
    # scalar RAdam terms so they are computed once per step value.
    _radam_buffer: List[List[Optional[FloatsXd]]]

    # This "locks" the class, so we get an error if you try to assign to
    # an unexpected variable.
    __slots__ = [
        "mom1",
        "mom2",
        "averages",
        "schedules",
        "nr_update",
        "last_seen",
        "grad_clip",
        "learn_rate",
        "b1",
        "b2",
        "eps",
        "L2",
        "use_radam",
        "L2_is_weight_decay",
        "_radam_buffer",
    ]

    def __init__(
        self,
        learn_rate: FloatOrSeq,
        *,
        L2: FloatOrSeq = ADAM_DEFAULTS["L2"],
        beta1: FloatOrSeq = ADAM_DEFAULTS["beta1"],
        beta2: FloatOrSeq = ADAM_DEFAULTS["beta2"],
        eps: FloatOrSeq = ADAM_DEFAULTS["eps"],
        grad_clip: FloatOrSeq = ADAM_DEFAULTS["grad_clip"],
        use_averages: bool = True,
        use_radam: bool = False,
        L2_is_weight_decay: bool = True,
    ):
        """
        Initialize an optimizer.

        Arguments typed FloatOrSeq accept a constant, a list, or a generator;
        non-constant values are registered as schedules and their first value
        becomes the initial setting.

        learn_rate (float): The initial learning rate.
        L2 (float): The L2 regularization term.
        beta1 (float): First-order momentum.
        beta2 (float): Second-order momentum.
        eps (float): Epsilon term for Adam etc.
        grad_clip (float): Gradient clipping.
        use_averages (bool): Whether to track moving averages of the parameters.
        use_radam (bool): Whether to use the RAdam optimizer.
        L2_is_weight_decay (bool): Whether to interpret the L2 parameter as a
            weight decay term, in the style of the AdamW optimizer.
        RAISES (ValueError): If a schedule cannot produce a first value.
        """
        self.mom1 = {}
        self.mom2 = {}
        if use_averages:
            self.averages = {}
        else:
            self.averages = None
        self.schedules = {}
        self.nr_update = defaultdict(int)
        self.last_seen = defaultdict(int)
        self._set_attr_or_schedule("grad_clip", grad_clip)
        self._set_attr_or_schedule("learn_rate", learn_rate)
        self._set_attr_or_schedule("b1", beta1)
        self._set_attr_or_schedule("b2", beta2)
        self._set_attr_or_schedule("eps", eps)
        self._set_attr_or_schedule("L2", L2)
        self.use_radam = use_radam
        self.L2_is_weight_decay = L2_is_weight_decay
        self._radam_buffer = [[None, None, None] for _ in range(10)]

    def _set_attr_or_schedule(self, name, value):
        """Store a hyper-parameter, registering it as a schedule if needed.

        name (str): Attribute name to set on the optimizer.
        value: A float/bool/int constant, a list, or an iterator. Lists are
            wrapped in `iter()`. Non-constant values are saved in
            `self.schedules` and their first item becomes the attribute value.
        RAISES (ValueError): If the schedule is empty or is not an iterator.
        """
        if isinstance(value, (float, bool, int)):
            setattr(self, name, value)
            return
        if isinstance(value, list):
            value = iter(value)
        self.schedules[name] = value
        try:
            setattr(self, name, next(value))
        except (StopIteration, TypeError) as e:
            err = f"Invalid schedule for '{name}' ({type(value)})\n{e}"
            # Chain explicitly so the original failure is reported as the cause.
            raise ValueError(err) from e

    def step_schedules(self):
        """Advance every registered schedule by one value.

        A schedule that is exhausted keeps the attribute at its last value.
        """
        for key, schedule in self.schedules.items():
            try:
                value = next(schedule)
            except StopIteration:  # schedule exhausted, use last value
                value = getattr(self, key)
            setattr(self, key, value)

    def __call__(
        self,
        key: Tuple[int, str],
        weights: FloatsXd,
        gradient: FloatsXd,
        *,
        lr_scale: float = 1.0,
    ):
        """Call the optimizer with weights and a gradient. The key is the
        identifier for the parameter, usually the node ID and parameter name.

        key (Tuple[int, str]): Identifier used to look up per-parameter state.
        weights (FloatsXd): The parameter values; updated in place by the
            SGD and weight-decay steps.
        gradient (FloatsXd): The gradient for `weights`.
        lr_scale (float): Multiplier applied to the learning rate for this
            update.
        RETURNS: A `(weights, gradient)` tuple. An empty gradient is returned
            untouched together with the unchanged weights.
        RAISES (NotImplementedError): If `use_radam` is False, `b1` is not
            positive and `b2` is positive.
        """
        if len(gradient) < 1:
            return weights, gradient
        ops = get_array_ops(weights)
        self.nr_update[key] += 1
        nr_upd = self.nr_update[key]
        # Classic L2 regularisation is folded into the gradient before
        # clipping; decoupled weight decay is applied after the update below.
        if self.L2 != 0 and not self.L2_is_weight_decay:
            gradient += self.L2 * weights
        if self.grad_clip:
            gradient = ops.clip_gradient(gradient, self.grad_clip)
        if self.use_radam:
            weights, gradient = self._radam(
                ops, weights, gradient, lr_scale, key, nr_upd
            )
        elif self.b1 > 0.0 and self.b2 > 0.0:
            weights, gradient = self._adam(
                ops, weights, gradient, lr_scale, key, nr_upd
            )
        elif self.b2 > 0.0:  # pragma: no cover
            raise NotImplementedError(
                "Optimizer updates with beta1 <= 0 and beta2 > 0 are not "
                "implemented"
            )
        else:
            weights -= lr_scale * self.learn_rate * gradient
        gradient *= 0
        if self.L2 != 0 and self.L2_is_weight_decay:
            weights -= lr_scale * self.learn_rate * self.L2 * weights
        if self.averages is not None:
            if key not in self.averages:
                self.averages[key] = ops.alloc(weights.shape, dtype="float32")
            ops.update_averages(self.averages[key], weights, nr_upd)
        return weights, gradient

    def _radam(self, ops, weights, grad, lr_scale, key, nr_upd):
        """Apply one RAdam update, ported from the pytorch implementation.

        ops: Array backend returned by `get_array_ops`.
        weights (FloatsXd): The parameter values.
        grad (FloatsXd): The gradient for `weights`.
        lr_scale (float): Multiplier applied to the learning rate.
        key (KeyT): Identifier used to look up the moment buffers.
        nr_upd (int): Accepted for signature parity with `_adam`; the step is
            read from `self.nr_update[key]` instead.
        RETURNS: `(weights, grad)` reshaped to their original shapes.
        """
        if key not in self.mom1:
            self.mom1[key] = ops.alloc1f(weights.size)
        if key not in self.mom2:
            self.mom2[key] = ops.alloc1f(weights.size)

        weights_1D = ops.reshape1f(weights, weights.size)
        gradient_1D = ops.reshape1f(grad, grad.size)

        # While we port from the pytorch implementation, keep some of the same
        # naming
        exp_avg = self.mom1[key]
        exp_avg_sq = self.mom2[key]
        beta1, beta2 = self.b1, self.b2
        # Honour lr_scale like the SGD and Adam paths do. With the default
        # lr_scale of 1.0 the product is exactly self.learn_rate.
        lr = self.learn_rate * lr_scale

        # exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
        exp_avg_sq *= beta2
        exp_avg_sq += (1 - beta2) * (gradient_1D**2)
        # exp_avg.mul_(beta1).add_(1 - beta1, grad)
        exp_avg *= beta1
        exp_avg += (1 - beta1) * gradient_1D

        # NOTE(review): __call__ has already incremented nr_update[key], so
        # this step runs one ahead of nr_upd (the first update uses step 2).
        # Kept as is to preserve existing numerics -- confirm it is intended.
        step = self.nr_update[key] + 1
        buffered = self._radam_buffer[int(step % 10)]
        if step == buffered[0]:
            N_sma, step_size = buffered[1], buffered[2]
        else:
            buffered[0] = step
            beta2_t = beta2**step
            N_sma_max = 2 / (1 - beta2) - 1
            N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
            buffered[1] = N_sma

            # more conservative since it's an approximated value
            if N_sma >= 5:
                step_size = math.sqrt(
                    (1 - beta2_t)
                    * (N_sma - 4)
                    / (N_sma_max - 4)
                    * (N_sma - 2)
                    / N_sma
                    * N_sma_max
                    / (N_sma_max - 2)
                ) / (1 - beta1**step)
            else:
                # Fall back to a bias-corrected momentum step.
                step_size = 1.0 / (1 - beta1**step)
            buffered[2] = step_size

        # No weight decay here: decoupled decay is applied by __call__ after
        # this method returns.
        # more conservative since it's an approximated value
        if N_sma >= 5:
            denom = ops.xp.sqrt(exp_avg_sq) + self.eps
            weights_1D += -step_size * lr * (exp_avg / denom)
        elif step_size > 0:
            weights_1D += -step_size * lr * exp_avg
        return (
            ops.reshape_f(weights_1D, weights.shape),
            ops.reshape_f(gradient_1D, grad.shape),
        )

    def _adam(self, ops, weights, gradient, lr_scale, key, nr_upd):
        """Apply one Adam update via `ops.adam`.

        ops: Array backend returned by `get_array_ops`.
        weights (FloatsXd): The parameter values.
        gradient (FloatsXd): The gradient for `weights`.
        lr_scale (float): Multiplier applied to the learning rate.
        key (KeyT): Identifier used to look up the moment buffers.
        nr_upd (int): Update count for `key`, used for bias correction.
        RETURNS: `(weights, gradient)` reshaped to their original shapes.
        """
        weights_1D = ops.reshape1f(weights, weights.size)
        gradient_1D = ops.reshape1f(gradient, gradient.size)
        if key not in self.mom1:
            self.mom1[key] = ops.alloc1f(weights.size)
        if key not in self.mom2:
            self.mom2[key] = ops.alloc1f(weights.size)
        mom1 = self.mom1[key]
        mom2 = self.mom2[key]
        b1 = self.b1
        b2 = self.b2
        # Bias-correction terms are folded into the learning rate here, so
        # ops.adam receives an already-corrected rate.
        fix1 = 1.0 - (b1**nr_upd)
        fix2 = 1.0 - (b2**nr_upd)
        lr = self.learn_rate * fix2**0.5 / fix1
        eps = self.eps
        # needs to be 1D going into the adam function
        weights_1D, gradient_1D, mom1, mom2 = ops.adam(
            weights_1D, gradient_1D, mom1, mom2, b1, b2, eps, lr * lr_scale
        )
        self.mom1[key] = mom1
        self.mom2[key] = mom2
        return (
            ops.reshape_f(weights_1D, weights.shape),
            ops.reshape_f(gradient_1D, gradient.shape),
        )
# Public names exported by a star-import of this module.
__all__ = [
    "Adam",
    "RAdam",
    "SGD",
    "Optimizer",
    "ADAM_DEFAULTS",
    "SGD_DEFAULTS",
]