import torch
import torch.nn as nn
import numpy as np

def clamp(X, lower_limit, upper_limit):
    # Element-wise clamp of X into [lower_limit, upper_limit]; the limits are
    # expected to be tensors broadcastable to X.
    return torch.max(torch.min(X, upper_limit), lower_limit)
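
# A minimal usage sketch (the tensor names are illustrative, not defined in
# this module): clamp is typically used to project an adversarial example back
# into an L-infinity ball of radius epsilon around the clean input and into
# the valid input range.
#
#   delta = clamp(x_adv - x_clean, -epsilon, epsilon)
#   x_adv = clamp(x_clean + delta, lower_limit, upper_limit)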

def PCGrad(atten_grad, ce_grad, sim, shape):
    '''
    PCGrad-style gradient surgery between the attention gradient and the
    cross-entropy gradient.

    @Parameter atten_grad, ce_grad: 2-D tensors with shape [batch_size, -1]
    @Parameter sim: per-sample similarity between atten_grad and ce_grad;
                    samples with sim < 0 are treated as conflicting
    @Parameter shape: original shape the corrected gradient is reshaped to
    '''
    # Only samples whose gradients conflict (negative similarity) are modified.
    pcgrad = atten_grad[sim < 0]
    temp_ce_grad = ce_grad[sim < 0]
    # Adjust the attention gradient along the CE gradient, scaled by their
    # dot product normalized by the CE-gradient norm.
    dot_prod = torch.mul(pcgrad, temp_ce_grad).sum(dim=-1)
    dot_prod = dot_prod / torch.norm(temp_ce_grad, dim=-1)
    pcgrad = pcgrad - dot_prod.view(-1, 1) * temp_ce_grad
    # Write the corrected rows back and restore the original gradient shape.
    atten_grad[sim < 0] = pcgrad
    atten_grad = atten_grad.view(shape)
    return atten_grad
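
# A minimal usage sketch (all names are illustrative): sim is typically the
# per-sample cosine similarity between the flattened attention gradient and
# the flattened cross-entropy gradient, and shape is the original gradient
# shape to restore.
#
#   sim = torch.cosine_similarity(atten_grad_flat, ce_grad_flat, dim=-1)
#   atten_grad = PCGrad(atten_grad_flat, ce_grad_flat, sim, shape=grad_shape)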

def backward_hook(gamma):
    # Implement SGM (Skip Gradient Method): scale the gradient that flows
    # back through every hooked ReLU by the factor gamma.
    def _backward_hook(module, grad_in, grad_out):
        if isinstance(module, nn.ReLU):
            return (gamma * grad_in[0],)
    return _backward_hook

def backward_hook_norm(module, grad_in, grad_out):
    # Normalize the gradient by its standard deviation to avoid gradient
    # explosion or vanishing.
    std = torch.std(grad_in[0])
    return (grad_in[0] / std,)

def register_hook_for_resnet(model, arch, gamma):
    # Note: `arch` is currently unused. register_backward_hook is deprecated
    # in recent PyTorch releases in favor of register_full_backward_hook.
    backward_hook_sgm = backward_hook(gamma)
    for name, module in model.named_modules():
        # Scale gradients flowing through ReLUs, skipping those whose name
        # contains '0.relu' (e.g. 'layer1.0.relu').
        if 'relu' in name and '0.relu' not in name:
            module.register_backward_hook(backward_hook_sgm)
        # Normalize gradients at the residual-block level (modules such as
        # 'layer1.1', 'layer4.2').
        if len(name.split('.')) >= 2 and 'layer' in name.split('.')[-2]:
            module.register_backward_hook(backward_hook_norm)
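
if __name__ == "__main__":
    # Minimal self-test sketch, assuming torchvision is installed; gamma=0.5
    # is an illustrative value, not a recommendation. Register the hooks on a
    # randomly initialized ResNet-50 and back-propagate a cross-entropy loss.
    import torchvision.models as models

    model = models.resnet50().eval()
    register_hook_for_resnet(model, arch='resnet50', gamma=0.5)

    x = torch.rand(1, 3, 224, 224, requires_grad=True)
    loss = nn.CrossEntropyLoss()(model(x), torch.tensor([0]))
    loss.backward()
    print('mean |grad| w.r.t. the input:', x.grad.abs().mean().item())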