from typing import Callable, Iterable, Tuple
import math
import torch
from torch.optim import Optimizer
class AdamW(Optimizer):
    """Implements AdamW (Adam with decoupled weight decay).

    Reference: Loshchilov & Hutter, "Decoupled Weight Decay Regularization",
    https://arxiv.org/abs/1711.05101
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[1])
            )
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            correct_bias=correct_bias,
        )
        super().__init__(params, defaults)
    def step(self, closure: Callable = None):
        """Performs a single optimization step.

        Arguments:
            closure (Callable, optional): a closure that re-evaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )

                # Per-parameter state (moments and step count) lives in this dictionary.
                state = self.state[p]

                # Access hyperparameters from the `group` dictionary.
                alpha: float = group["lr"]  # These are copied by value.
                beta1: float = group["betas"][0]
                beta2: float = group["betas"][1]
                eps: float = group["eps"]
                lambda_: float = group["weight_decay"]
                correct_bias: bool = group["correct_bias"]

                # Update the first and second moments of the gradients.
                # Initialize the moments on the first step for this parameter.
                if "m" not in state:
                    state["m"] = torch.zeros_like(grad)
                if "v" not in state:
                    state["v"] = torch.zeros_like(grad)
                state["m"] = (beta1 * state["m"]) + ((1 - beta1) * grad)
                state["v"] = (beta2 * state["v"]) + ((1 - beta2) * torch.square(grad))

                # Bias correction.
                # We use the "efficient version" from https://arxiv.org/abs/1412.6980:
                # folding both corrections into the step size,
                #     alpha_t = alpha * sqrt(1 - beta2^t) / (1 - beta1^t),
                # is equivalent to dividing m by (1 - beta1^t) and v by (1 - beta2^t).
                if "step" not in state:
                    state["step"] = 0
                state["step"] += 1
                t = state["step"]  # This is copied by value.
                if correct_bias:
                    alpha_t = (
                        alpha
                        * math.sqrt(1 - math.pow(beta2, t))
                        / (1 - math.pow(beta1, t))
                    )
                else:
                    alpha_t = alpha

                # Update parameters.
                # Decoupled weight decay (AdamW): applied directly to the parameters
                # after the gradient-based update, scaled by the learning rate rather
                # than added to the gradient.
                p.data = p.data - (
                    (alpha_t * torch.div(state["m"], (torch.sqrt(state["v"]) + eps)))
                    + (alpha * lambda_ * p.data)
                )

        return loss
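

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original class).
# The tiny linear model and random data below are hypothetical; they simply
# show how this AdamW optimizer is constructed and stepped in a training loop.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    optimizer = AdamW(model.parameters(), lr=1e-2, weight_decay=0.01)
    inputs = torch.randn(16, 4)
    targets = torch.randn(16, 1)
    for _ in range(200):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
    print(f"final loss: {loss.item():.4f}")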