# ViTGuard/detection/utils.py
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

# Attention rollout: combine per-layer attention weights (shape
# [num_layers, num_heads, N, num_tokens, num_tokens], as produced by get_attn)
# and return the [CLS]-to-patch attention at layer_index as an (N, num_patches) array.
def vis_from_attn(att_mat, layer_index=-1):
    att_mat = att_mat.to('cpu')
    # Average the attention weights across all heads.
    att_mat = torch.mean(att_mat, dim=1)

    # Add the identity to account for residual connections, then re-normalize each row.
    residual_att = torch.eye(att_mat.size(-1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)  # [num_layers, N, num_tokens, num_tokens]
    
    # Multiply the augmented attention matrices across layers (attention rollout).
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]
    
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])

    # Attention from the [CLS] token to the patch tokens.
    v = joint_attentions[layer_index]
    att_matrix = v[:, 0, 1:].detach().numpy()
    return att_matrix  # (N, num_patches), e.g. 196 patches for 224x224 inputs with a 16x16 patch size

# Get the rollout attention maps (flattened per sample) from the specified transformer layer
def get_attn(dataloader, model, device, layer_index):
    attn_all = []
    for x_batch, y_batch in dataloader:
        x_batch = x_batch.to(device)
        preds = model(x_batch, output_attentions=True)
        # preds[-1] is a tuple of per-layer attention maps, each [N, num_heads, num_tokens, num_tokens].
        att_mat = [t.unsqueeze(0) for t in preds[-1]]
        att_mat = torch.cat(att_mat, dim=0)
        # Swap the batch and head dimensions: [num_layers, num_heads, N, num_tokens, num_tokens].
        att_mat = torch.transpose(att_mat, 1, 2).detach().cpu()
        del preds
        att_matrix = vis_from_attn(att_mat, layer_index)
        attn_all.append(att_matrix)
    attn_all = np.vstack(attn_all)  # (num_samples, num_patches)
    return attn_all

# Return the indices of successful adversarial samples: the clean sample is
# classified correctly while its adversarial counterpart is misclassified.
def get_success_adv_index(test_loader, advLoader, model, device):
    # Indices of clean samples the model classifies correctly.
    correct_index = []
    offset = 0
    for x_batch, y_batch in test_loader:
        x_batch = x_batch.to(device)
        output = model(x_batch)
        pred = torch.argmax(output, 1).detach().cpu()
        index_temp = np.where(pred.numpy() == y_batch.numpy())[0]
        for i in index_temp:
            correct_index.append(i + offset)
        offset += len(y_batch)

    # Indices of adversarial samples the model misclassifies.
    adv_index = []
    offset = 0
    for x_batch, y_batch in advLoader:
        x_batch = x_batch.to(device)
        output = model(x_batch)
        pred = torch.argmax(output, 1).detach().cpu()
        index_temp = np.where(pred.numpy() != y_batch.numpy())[0]
        for i in index_temp:
            adv_index.append(i + offset)
        offset += len(y_batch)

    detect_index = np.intersect1d(np.asarray(correct_index), np.asarray(adv_index))
    return detect_index

# Per-sample L2 distance between two feature matrices of shape (N, d).
def l2_distance(matrix1, matrix2):
    dis = np.linalg.norm(matrix1 - matrix2, axis=1)
    return dis

# drop_rate is the target false positive rate (FPR) used to set the detection threshold.
def get_threshold(error_adv_all, drop_rate):
    # Guard against num == 0, where marks[-0] would silently return the smallest error.
    num = max(int(len(error_adv_all) * drop_rate), 1)
    marks = np.sort(error_adv_all)
    thrs = marks[-num]
    return thrs, marks
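
# For example, with 1,000 reference errors and drop_rate=0.05 in get_threshold above,
# num = 50 and thrs is the 50th-largest error, so roughly 5% of the reference samples
# exceed the threshold.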

# Get the CLS representations from the specified transformer layer
def get_cls(test_loader, model, device, layer_index):
    features = {}
    def get_features(name):
        def hook(_, inputs, output):
            # output[0] holds the hidden states [N, num_tokens, dim]; keep the [CLS] token,
            # keyed by device index so DataParallel replicas do not overwrite each other.
            features[output[0].get_device()] = output[0][:, 0].detach().cpu()
        return hook
    handle = model.module.vit.encoder.layer[layer_index].register_forward_hook(get_features('cls'))

    cls_outputs = []
    for x_batch, y_batch in test_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        preds = model(x_batch)
        # Concatenate per-device chunks in device order to preserve sample order.
        for index in sorted(features.keys()):
            cls_outputs.append(features[index])
        del preds
    handle.remove()  # avoid stacking duplicate hooks across repeated calls
    cls_outputs = np.vstack(cls_outputs)
    return cls_outputs

# Simple in-memory Dataset wrapping pre-built data tensors and labels.
class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
        
# Drop sample pairs whose adversarial data or labels contain NaNs, then rebuild
# aligned clean and adversarial DataLoaders (batch_size=64, no shuffling).
def remove_nan_from_dataset(dataset_adv, dataset_clean):
    adv_data = []
    adv_labels = []

    clean_data = []
    clean_labels = []

    # Keep only pairs whose adversarial data and labels are NaN-free, so the
    # clean and adversarial loaders stay aligned index by index.
    for (data, label), (data_c, label_c) in zip(dataset_adv, dataset_clean):
        for i in range(len(data)):
            if not torch.isnan(data[i]).any() and not torch.isnan(label[i]).any():
                adv_data.append(data[i])
                adv_labels.append(label[i])
                clean_data.append(data_c[i])
                clean_labels.append(label_c[i])
    adv_dataset = CustomDataset(torch.stack(adv_data), torch.stack(adv_labels))
    clean_dataset = CustomDataset(torch.stack(clean_data), torch.stack(clean_labels))
    test_loader = DataLoader(clean_dataset, batch_size=64, shuffle=False, pin_memory=True)
    advLoader = DataLoader(adv_dataset, batch_size=64, shuffle=False, pin_memory=True)
    
    return advLoader, test_loader
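
# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module).
# `model`, `device`, `test_loader`, `advLoader`, and the reconstructed-input
# loaders `recon_loader` / `recon_advLoader` are placeholders the caller would
# provide; only the helpers defined above are real.
#
#   layer_index = 0
#
#   # Samples classified correctly when clean but misclassified when attacked.
#   detect_index = get_success_adv_index(test_loader, advLoader, model, device)
#
#   # Rollout attention maps (or get_cls for CLS features) of the original
#   # inputs and of their reconstructed counterparts.
#   attn_clean     = get_attn(test_loader, model, device, layer_index)
#   attn_clean_rec = get_attn(recon_loader, model, device, layer_index)
#   attn_adv       = get_attn(advLoader, model, device, layer_index)
#   attn_adv_rec   = get_attn(recon_advLoader, model, device, layer_index)
#
#   # Per-sample L2 errors; the clean errors set a threshold at the desired FPR,
#   # which is then applied to the errors of the (suspect) adversarial inputs.
#   errors_clean = l2_distance(attn_clean, attn_clean_rec)
#   errors_adv   = l2_distance(attn_adv, attn_adv_rec)
#   thrs, _ = get_threshold(errors_clean, drop_rate=0.05)
#   detected = errors_adv > thrs
# ---------------------------------------------------------------------------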