import torch
import torch.nn as nn
import numpy as np


def clamp(X, lower_limit, upper_limit):
    # Element-wise clamp of X into [lower_limit, upper_limit].
    return torch.max(torch.min(X, upper_limit), lower_limit)


def PCGrad(atten_grad, ce_grad, sim, shape):
    '''
    PCGrad-style gradient surgery on conflicting gradients.

    @Parameter atten_grad, ce_grad: should be 2D tensors with shape [batch_size, -1]
    @Parameter sim: per-sample similarity between the two gradients (negative = conflicting)
    @Parameter shape: original shape to restore before returning
    '''
    # Only adjust samples whose attention gradient conflicts with the CE gradient.
    pcgrad = atten_grad[sim < 0]
    temp_ce_grad = ce_grad[sim < 0]
    # Dot product scaled by the norm of the CE gradient.
    dot_prod = torch.mul(pcgrad, temp_ce_grad).sum(dim=-1)
    dot_prod = dot_prod / torch.norm(temp_ce_grad, dim=-1)
    # Remove the conflicting component along the CE-gradient direction.
    pcgrad = pcgrad - dot_prod.view(-1, 1) * temp_ce_grad
    atten_grad[sim < 0] = pcgrad
    atten_grad = atten_grad.view(shape)
    return atten_grad


def backward_hook(gamma):
    # Implement SGM: scale the gradient flowing back through ReLU by gamma.
    def _backward_hook(module, grad_in, grad_out):
        if isinstance(module, nn.ReLU):
            return (gamma * grad_in[0],)
    return _backward_hook


def backward_hook_norm(module, grad_in, grad_out):
    # Normalize the gradient to avoid gradient explosion or vanishing.
    std = torch.std(grad_in[0])
    return (grad_in[0] / std,)


def register_hook_for_resnet(model, arch, gamma):
    backward_hook_sgm = backward_hook(gamma)
    for name, module in model.named_modules():
        # Attach the SGM hook to ReLU modules, skipping the first block of each
        # layer group ('*.0.relu').
        if 'relu' in name and '0.relu' not in name:
            module.register_backward_hook(backward_hook_sgm)
        # Attach the gradient-normalization hook to each residual block ('layerX.Y').
        if len(name.split('.')) >= 2 and 'layer' in name.split('.')[-2]:
            module.register_backward_hook(backward_hook_norm)
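

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original utilities): one way these helpers
# might be wired together for an SGM-style attack on a torchvision ResNet.
# The model choice (resnet50), gamma=0.5, the L_inf budget of 8/255, the step
# size 2/255, and the toy PCGrad shapes below are illustrative assumptions,
# not values prescribed by this file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torchvision

    model = torchvision.models.resnet50().eval()
    register_hook_for_resnet(model, arch='resnet50', gamma=0.5)  # assumed gamma

    # Dummy batch in [0, 1]; real data would be preprocessed ImageNet images.
    x = torch.rand(4, 3, 224, 224, requires_grad=True)
    y = torch.randint(0, 1000, (4,))
    epsilon = torch.tensor(8.0 / 255.0)  # assumed perturbation budget
    alpha = torch.tensor(2.0 / 255.0)    # assumed step size

    # Backward pass triggers the registered SGM / normalization hooks.
    loss = nn.CrossEntropyLoss()(model(x), y)
    grad = torch.autograd.grad(loss, x)[0]

    # One FGSM/PGD-style step, kept inside the epsilon ball and the valid range.
    x_nat = x.detach()
    x_adv = x_nat + alpha * grad.sign()
    x_adv = clamp(x_adv, x_nat - epsilon, x_nat + epsilon)
    x_adv = clamp(x_adv, torch.zeros_like(x_adv), torch.ones_like(x_adv))

    # PCGrad demo on toy flattened gradients; here sim is taken to be the
    # per-sample cosine similarity, which is one natural choice for the mask.
    g_atten = torch.randn(4, 3 * 224 * 224)
    g_ce = torch.randn(4, 3 * 224 * 224)
    sim = nn.functional.cosine_similarity(g_atten, g_ce, dim=-1)
    g_atten_surgery = PCGrad(g_atten, g_ce, sim, shape=(4, 3, 224, 224))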