import torch
import torch.nn.functional as F

from .losses import Loss


class VICReg(Loss):
    """VICReg loss: weighted sum of invariance, variance, and covariance terms."""

    def __init__(self, sim_weight=25.0, var_weight=25.0, cov_weight=1.0, eps=1e-4):
        super().__init__()
        self.sim_weight = sim_weight
        self.var_weight = var_weight
        self.cov_weight = cov_weight
        self.eps = eps

    @staticmethod
    def invariance_loss(h1, h2):
        # Invariance term: mean squared error between the two views.
        return F.mse_loss(h1, h2)

    def variance_loss(self, h1, h2):
        # Variance term: hinge loss pushing each embedding dimension's std above 1.
        std_z1 = torch.sqrt(h1.var(dim=0) + self.eps)
        std_z2 = torch.sqrt(h2.var(dim=0) + self.eps)
        return torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2))

    @staticmethod
    def covariance_loss(h1, h2):
        # Covariance term: penalize squared off-diagonal entries of each view's
        # covariance matrix to decorrelate the embedding dimensions.
        num_nodes, hidden_dim = h1.size()
        h1 = h1 - h1.mean(dim=0)
        h2 = h2 - h2.mean(dim=0)
        cov_z1 = (h1.T @ h1) / (num_nodes - 1)
        cov_z2 = (h2.T @ h2) / (num_nodes - 1)
        off_diag = ~torch.eye(hidden_dim, dtype=torch.bool, device=h1.device)
        return (cov_z1[off_diag].pow(2).sum() + cov_z2[off_diag].pow(2).sum()) / hidden_dim

    def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs) -> torch.FloatTensor:
        # pos_mask and neg_mask are unused: VICReg does not rely on negative samples.
        sim_loss = self.invariance_loss(anchor, sample)
        var_loss = self.variance_loss(anchor, sample)
        cov_loss = self.covariance_loss(anchor, sample)
        loss = self.sim_weight * sim_loss + self.var_weight * var_loss + self.cov_weight * cov_loss
        return loss.mean()
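

# Minimal usage sketch (illustrative only). It assumes the `Loss` base class in
# `.losses` can be instantiated without extra required arguments and that `compute`
# is called directly with two embedding matrices of shape [num_nodes, hidden_dim];
# the mask arguments are passed as None because VICReg ignores them. Because of the
# relative import above, this would need to be run as a module inside its package
# (package name not shown here).
if __name__ == "__main__":
    torch.manual_seed(0)
    h1 = torch.randn(256, 128)  # anchor view embeddings
    h2 = torch.randn(256, 128)  # positive view embeddings
    loss_fn = VICReg(sim_weight=25.0, var_weight=25.0, cov_weight=1.0)
    loss = loss_fn.compute(h1, h2, pos_mask=None, neg_mask=None)
    print(f"VICReg loss: {loss.item():.4f}")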