from typing import Callable, Iterable, Optional, Tuple

import math

import torch
from torch.optim import Optimizer


class AdamW(Optimizer):
    """Implements the AdamW algorithm (Adam with decoupled weight decay).

    See "Adam: A Method for Stochastic Optimization" (https://arxiv.org/abs/1412.6980)
    and "Decoupled Weight Decay Regularization" (https://arxiv.org/abs/1711.05101).
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[1])
            )
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            correct_bias=correct_bias,
        )
        super().__init__(params, defaults)

    def step(self, closure: Optional[Callable] = None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )

                # Per-parameter state (moments and step count) is stored in this dictionary.
                state = self.state[p]

                # Access hyperparameters from the `group` dictionary.
                alpha: float = group["lr"]  # These are copied by value.
                beta1: float = group["betas"][0]
                beta2: float = group["betas"][1]
                eps: float = group["eps"]
                lambda_: float = group["weight_decay"]
                correct_bias: bool = group["correct_bias"]

                # Update the first and second moments of the gradients.
                # Initialize the moments lazily on the first step.
                if "m" not in state:
                    state["m"] = torch.zeros_like(grad)
                if "v" not in state:
                    state["v"] = torch.zeros_like(grad)
                state["m"] = (beta1 * state["m"]) + ((1 - beta1) * grad)
                state["v"] = (beta2 * state["v"]) + ((1 - beta2) * torch.square(grad))

                # Bias correction.
                # Note that we are using the "efficient version" given in
                # https://arxiv.org/abs/1412.6980: the correction is folded into
                # the step size rather than applied to the moment estimates.
                if "step" not in state:
                    state["step"] = 0
                state["step"] += 1
                t = state["step"]  # This is copied by value.
                if correct_bias:
                    alpha_t = (
                        alpha * math.sqrt(1 - math.pow(beta2, t)) / (1 - math.pow(beta1, t))
                    )
                else:
                    alpha_t = alpha

                # Update parameters.
                # Weight decay is decoupled: it is added after the main
                # gradient-based update and scaled by the base learning rate,
                # not by the bias-corrected step size.
                p.data = p.data - (
                    (alpha_t * torch.div(state["m"], (torch.sqrt(state["v"]) + eps)))
                    + (alpha * lambda_ * p.data)
                )

        return loss
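

# --- Usage sketch (illustrative, not part of the optimizer) ---
# A minimal example of driving the AdamW class above on a toy regression
# problem. The model, data, and hyperparameter values below are assumptions
# made purely for demonstration; only the AdamW construction and the
# zero_grad()/backward()/step() loop reflect how the optimizer is used.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(10, 1)
    optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)
    loss_fn = torch.nn.MSELoss()

    for _ in range(100):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
    print(f"final loss: {loss.item():.4f}")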