# Copied from another repo, but I can't remember exactly which one.
from collections.abc import Iterable

import torch


class EMAModuleWrapper:
    """Keeps an exponential moving average (EMA) of a set of model parameters.

    The EMA copies can live on a different device than the model parameters
    (e.g. offloaded to the CPU) to save accelerator memory.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        update_step_interval: int = 1,
        device: torch.device | None = None,
    ):
        parameters = list(parameters)
        self.ema_parameters = [p.clone().detach().to(device) for p in parameters]
        self.temp_stored_parameters: list[torch.Tensor] | None = None
        self.decay = decay
        self.update_step_interval = update_step_interval
        self.device = device

    def get_current_decay(self, optimization_step: int) -> float:
        """Decay warm-up: (1 + t) / (10 + t) ramps from 0.1 at step 0 towards 1, capped at self.decay."""
        return min((1 + optimization_step) / (10 + optimization_step), self.decay)

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter], optimization_step: int) -> None:
        """Blend the EMA parameters towards the current model parameters."""
        if (optimization_step + 1) % self.update_step_interval != 0:
            return
        parameters = list(parameters)
        one_minus_decay = 1 - self.get_current_decay(optimization_step)
        for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True):
            if parameter.requires_grad:
                if ema_parameter.device == parameter.device:
                    # ema = ema + (1 - decay) * (param - ema), computed in place.
                    ema_parameter.add_(one_minus_decay * (parameter - ema_parameter))
                else:
                    # EMA parameters live on another device: move a detached copy over and
                    # do the update in place to avoid extra allocations.
                    parameter_copy = parameter.detach().to(ema_parameter.device)
                    parameter_copy.sub_(ema_parameter)
                    parameter_copy.mul_(one_minus_decay)
                    ema_parameter.add_(parameter_copy)
                    del parameter_copy

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> None:
        """Move the EMA parameters to the given device; dtype is applied only to floating-point tensors."""
        self.device = device
        self.ema_parameters = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.ema_parameters
        ]

    @torch.no_grad()
    def sync_with_model(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Force the EMA parameters to be a direct copy of the given model parameters.
        This is used to create a snapshot for the rollout policy.
        """
        parameters = list(parameters)
        for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True):
            ema_parameter.data.copy_(parameter.data)

    def copy_ema_to(
        self, parameters: Iterable[torch.nn.Parameter], store_temp: bool = True, grad: bool = False
    ) -> None:
        """Copy the EMA weights into the given parameters, optionally stashing the current weights first."""
        # Materialize the iterable up front so it is not exhausted by the temp-storage pass below.
        parameters = list(parameters)
        if store_temp:
            if grad:
                self.temp_stored_parameters = [parameter.data.clone() for parameter in parameters]
            else:
                # copy=True guarantees a real backup even when the parameter already lives on the CPU.
                self.temp_stored_parameters = [parameter.detach().to("cpu", copy=True) for parameter in parameters]
        for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True):
            parameter.data.copy_(ema_parameter.to(parameter.device).data)

    def copy_temp_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """Restore the parameters stashed by copy_ema_to(store_temp=True) and drop the stash."""
        for temp_parameter, parameter in zip(self.temp_stored_parameters, parameters, strict=True):
            # Move the stored tensor to the parameter's device before copying it back.
            parameter.data.copy_(temp_parameter.to(parameter.device))
        self.temp_stored_parameters = None

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore EMA state previously produced by state_dict()."""
        # Keep the currently configured decay if one is set; otherwise fall back to the stored value.
        self.decay = self.decay or state_dict.get("decay", self.decay)
        self.ema_parameters = state_dict["ema_parameters"]
        self.to(self.device)

    def state_dict(self) -> dict:
        """Return a serializable snapshot of the EMA state."""
        return {
            "decay": self.decay,
            "ema_parameters": self.ema_parameters,
        }
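

# A minimal usage sketch, assuming a small hypothetical model and training loop (the `model`,
# `optimizer`, and dummy loss below are illustrative, not part of the original module). It shows
# the intended call pattern: update the EMA after each optimizer step, then temporarily swap the
# EMA weights in for evaluation and restore the live weights afterwards.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    ema = EMAModuleWrapper(model.parameters(), decay=0.999, update_step_interval=1)

    for step in range(10):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.step(model.parameters(), optimization_step=step)

    # Evaluate with the EMA weights, then restore the live (training) weights.
    ema.copy_ema_to(model.parameters(), store_temp=True)
    with torch.no_grad():
        _ = model(torch.randn(8, 4))
    ema.copy_temp_to(model.parameters())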