# ViTGuard / DataManagerPytorch.py
import torch 

#Class to help with converting between dataloader and pytorch tensor 
class MyDataSet(torch.utils.data.Dataset):
    """Map-style dataset over a pair of index-aligned tensors.

    Item ``i`` is the tuple ``(x_tensor[i], y_tensor[i])``. When a
    ``transforms`` callable is supplied it is applied to the input element
    only (the label is returned untouched) each time the item is fetched.
    The dataset length is the length of ``x_tensor``.
    """

    def __init__(self, x_tensor, y_tensor, transforms=None):
        self.x = x_tensor
        self.y = y_tensor
        self.transforms = transforms

    def __getitem__(self, index):
        sample = self.x[index]
        if self.transforms is not None:
            # Transform the input only; labels pass through as-is.
            sample = self.transforms(sample)
        return sample, self.y[index]

    def __len__(self):
        return len(self.x)

#Convert X and Y tensors into a dataloader
#No transforms are applied unless `transforms` is given; if it is, it is applied to X only
def TensorToDataLoader(xData, yData, transforms=None, batchSize=None, randomizer=None):
    """Wrap index-aligned input/label tensors in a ``DataLoader``.

    Args:
        xData: Input tensor; its first dimension indexes samples.
        yData: Label tensor aligned with ``xData`` along the first dimension.
        transforms: Optional callable applied to each input sample when it is
            fetched (labels are returned unchanged). ``None`` applies nothing.
        batchSize: Samples per batch. ``None`` puts the whole dataset into a
            single batch.
        randomizer: ``None`` keeps the original sample order. Any other value
            makes the loader draw samples in random order through a
            ``RandomSampler``; if the value is a ``torch.Generator`` it is used
            as that sampler's RNG so the order is reproducible.

    Returns:
        A ``torch.utils.data.DataLoader`` over
        ``MyDataSet(xData, yData, transforms)``.
    """
    if batchSize is None:
        # One batch holding everything. DataLoader rejects batch_size=0, so
        # clamp to 1: an empty tensor then yields an empty (zero-batch) loader
        # instead of raising.
        batchSize = max(1, xData.shape[0])
    dataset = MyDataSet(xData, yData, transforms)
    if randomizer is None:  # keep the original order
        return torch.utils.data.DataLoader(dataset=dataset, batch_size=batchSize, shuffle=False)
    # Random order requested; reuse the caller's generator when one is given.
    generator = randomizer if isinstance(randomizer, torch.Generator) else None
    train_sampler = torch.utils.data.RandomSampler(dataset, generator=generator)
    # shuffle must stay False: DataLoader forbids shuffle=True together with a sampler.
    return torch.utils.data.DataLoader(dataset=dataset, batch_size=batchSize, sampler=train_sampler, shuffle=False)

#Get the output shape from the dataloader
def GetOutputShape(dataLoader):
    """Return the shape of one input sample drawn from ``dataLoader``.

    Only the first batch is read. It must unpack into an
    ``(inputs, targets)`` pair. The result is ``inputs[0].shape``, i.e. the
    per-sample shape without the batch dimension.

    Returns:
        The sample shape, or ``None`` if the loader yields no batches.
    """
    firstBatch = next(iter(dataLoader), None)
    if firstBatch is None:  # empty loader
        return None
    inputs, _targets = firstBatch
    return inputs[0].shape

#Get the dataset reconstructed by ViTMAE
def get_reconstructed_dataset(model, data_loader, device, random_seed=None):
    """Reconstruct every sample in ``data_loader`` with a ViTMAE model.

    Each batch goes through ``model`` with ``preset_mask=False``, so the
    model chooses which patches to drop. In the pixel-space mask
    (``outputs[2]`` expanded and unpatchified), 1 means removed and
    0 means kept. Removed patches are replaced by the unpatchified
    reconstruction ``outputs[1]``. Kept patches keep their original pixels.
    The pasted images are clamped to ``[0, 1]``.

    Args:
        model: Called as ``model(x, random_seed=..., preset_mask=...)``.
            The code reads ``model.module.unpatchify`` and
            ``model.module.config.patch_size``, so it is presumably a
            ``DataParallel``/DDP-style wrapper (confirm with callers).
        data_loader: Loader yielding ``(x_batch, y_batch)`` pairs.
        device: Device each batch is moved to before the forward pass.
        random_seed: Forwarded unchanged to ``model``.

    Returns:
        An unshuffled ``DataLoader`` with ``data_loader.batch_size`` over the
        reconstructed images (CPU tensors) and their labels (int64).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # NOTE: reseeds the *global* torch RNG with a fixed value on every call,
    # independently of ``random_seed``.
    torch.manual_seed(2)
    model.eval()

    batchSize = data_loader.batch_size
    numSamples = len(data_loader.dataset)
    xShape = GetOutputShape(data_loader)
    if xShape is None:
        raise ValueError("data_loader yielded no batches")
    xRe = torch.zeros(numSamples, *xShape)
    yClean = torch.zeros(numSamples).type(torch.LongTensor)
    # Per-patch mask value repeated p*p*3 times -> (N, H*W, p*p*3), matching unpatchify's input.
    patchDim = model.module.config.patch_size ** 2 * 3

    advSampleIndex = 0
    # Inference only: no autograd graph is needed (all outputs are detached anyway).
    with torch.no_grad():
        for x_batch, y_batch in data_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            outputs = model(x_batch, random_seed=random_seed, preset_mask=False)
            # outputs[1] ("logits") holds per-patch predictions; unpatchify turns them into images.
            recon = model.module.unpatchify(outputs[1].detach().cpu())

            # outputs[2] is the per-patch mask: 1 is removing, 0 is keeping.
            mask = outputs[2].detach().cpu().unsqueeze(-1).repeat(1, 1, patchDim)
            mask = model.module.unpatchify(mask)
            # MAE reconstruction pasted with the visible (kept) patches.
            im_paste = torch.clamp(x_batch.detach().cpu() * (1 - mask) + recon * mask, 0, 1)

            # Copy the batch's real length (the last batch may be short) and
            # never write past the preallocated buffers.
            count = min(len(im_paste), numSamples - advSampleIndex)
            xRe[advSampleIndex:advSampleIndex + count] = im_paste[:count]
            yClean[advSampleIndex:advSampleIndex + count] = y_batch[:count].cpu()
            advSampleIndex += count
            del outputs, recon, mask, im_paste
            if advSampleIndex == numSamples:
                break

    return TensorToDataLoader(xRe, yClean, transforms=None, batchSize=batchSize, randomizer=None)

#Apply ViTMAE twice and get the reconstructed dataset
def recoverall(model, data_loader, device, salient_index, random_seed=None):
    """Reconstruct every sample from two complementary preset-mask ViTMAE passes.

    For each batch, the rows of ``salient_index`` at the batch's position in
    the loader (``rows``) drive two forward passes. In MAE, 1 means removing
    and 0 means keeping.

    1. ``preset_mask = 1 - rows``: keep only the reconstruction of the patches
       the model reports as removed (``outputs[2] == 1``), clamped to [0, 1].
    2. ``preset_mask = rows``: add the reconstruction of the patches removed
       in this pass, then clamp the sum to [0, 1].

    No original pixels are pasted back; the output is built from model
    reconstructions only.

    Args:
        model: Called as ``model(x, random_seed=..., preset_mask=...)``.
            The code reads ``model.module.unpatchify`` and
            ``model.module.config.patch_size``, so it is presumably a
            ``DataParallel``/DDP-style wrapper (confirm with callers).
        data_loader: Loader yielding ``(x_batch, y_batch)`` pairs.
        device: Device each batch is moved to before the forward passes.
        salient_index: Numeric per-sample patch indicators with at least
            ``len(data_loader.dataset)`` rows.
            NOTE(review): rows are sliced with a running offset in iteration
            order, so this assumes an unshuffled loader whose i-th yielded
            sample matches row i. Confirm with callers.
        random_seed: Forwarded unchanged to ``model``.

    Returns:
        An unshuffled ``DataLoader`` with ``data_loader.batch_size`` over the
        reconstructed images (CPU tensors) and their labels (int64).

    Raises:
        ValueError: if ``salient_index`` has fewer rows than the dataset has
            samples, or if ``data_loader`` yields no batches.
    """
    # NOTE: reseeds the *global* torch RNG with a fixed value on every call,
    # independently of ``random_seed``.
    torch.manual_seed(2)
    model.eval()

    batchSize = data_loader.batch_size
    numSamples = len(data_loader.dataset)
    if len(salient_index) < numSamples:
        # A short index would silently hand the model truncated/empty masks.
        raise ValueError(
            f"salient_index has {len(salient_index)} rows but the dataset has {numSamples} samples"
        )
    xShape = GetOutputShape(data_loader)
    if xShape is None:
        raise ValueError("data_loader yielded no batches")
    xRe = torch.zeros(numSamples, *xShape)
    yClean = torch.zeros(numSamples).type(torch.LongTensor)
    # Per-patch mask value repeated p*p*3 times -> (N, H*W, p*p*3), matching unpatchify's input.
    patchDim = model.module.config.patch_size ** 2 * 3

    def reconstruct(x_batch, preset_mask):
        """Run one preset-mask pass; return (reconstructed images, pixel mask) on CPU."""
        outputs = model(x_batch, random_seed=random_seed, preset_mask=preset_mask)
        recon = model.module.unpatchify(outputs[1].detach().cpu())  # outputs[1] is "logits"
        mask = outputs[2].detach().cpu().unsqueeze(-1).repeat(1, 1, patchDim)
        return recon, model.module.unpatchify(mask)

    tracker = 0
    advSampleIndex = 0
    # Inference only: no autograd graph is needed (all outputs are detached anyway).
    with torch.no_grad():
        for x_batch, y_batch in data_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            batch_size = len(x_batch)
            rows = salient_index[tracker:tracker + batch_size]
            tracker += batch_size

            # Pass 1: remove the complement of `rows`; keep only its reconstruction.
            recon, mask = reconstruct(x_batch, 1 - rows)
            im_paste = torch.clamp(recon * mask, 0, 1)
            # Pass 2: remove the patches flagged in `rows`; add their reconstruction.
            recon, mask = reconstruct(x_batch, rows)
            im_paste_2 = torch.clamp(im_paste + recon * mask, 0, 1)

            # Never write past the preallocated buffers.
            count = min(batch_size, numSamples - advSampleIndex)
            xRe[advSampleIndex:advSampleIndex + count] = im_paste_2[:count]
            yClean[advSampleIndex:advSampleIndex + count] = y_batch[:count].cpu()
            advSampleIndex += count
            del recon, mask, im_paste, im_paste_2
            if advSampleIndex == numSamples:
                break

    return TensorToDataLoader(xRe, yClean, transforms=None, batchSize=batchSize, randomizer=None)


def get_reconstructed_dataset_salient(model, data_loader, device, salient_index, random_seed=None):
    """Reconstruct every sample with ViTMAE, using ``salient_index`` to choose the mask.

    For each batch, the rows of ``salient_index`` at the batch's position in
    the loader (``rows``) are turned into a preset mask ``1 - rows``. In MAE,
    0 is keep and 1 is remove, so patches where ``rows`` is 0 are
    reconstructed. Removed patches are replaced by the unpatchified
    reconstruction ``outputs[1]``, and kept patches keep their original
    pixels. The result is clamped to [0, 1].

    Args:
        model: Called as ``model(x, random_seed=..., preset_mask=...)``.
            The code reads ``model.module.unpatchify`` and
            ``model.module.config.patch_size``, so it is presumably a
            ``DataParallel``/DDP-style wrapper (confirm with callers).
        data_loader: Loader yielding ``(x_batch, y_batch)`` pairs.
        device: Device each batch is moved to before the forward pass.
        salient_index: Numeric per-sample patch indicators with at least
            ``len(data_loader.dataset)`` rows.
            NOTE(review): rows are sliced with a running offset in iteration
            order, so this assumes an unshuffled loader whose i-th yielded
            sample matches row i. Confirm with callers.
        random_seed: Forwarded unchanged to ``model``.

    Returns:
        An unshuffled ``DataLoader`` with ``data_loader.batch_size`` over the
        reconstructed images (CPU tensors) and their labels (int64).

    Raises:
        ValueError: if ``salient_index`` has fewer rows than the dataset has
            samples, or if ``data_loader`` yields no batches.
    """
    # NOTE: reseeds the *global* torch RNG with a fixed value on every call,
    # independently of ``random_seed``.
    torch.manual_seed(2)
    model.eval()

    batchSize = data_loader.batch_size
    numSamples = len(data_loader.dataset)
    if len(salient_index) < numSamples:
        # A short index would silently hand the model truncated/empty masks.
        raise ValueError(
            f"salient_index has {len(salient_index)} rows but the dataset has {numSamples} samples"
        )
    xShape = GetOutputShape(data_loader)
    if xShape is None:
        raise ValueError("data_loader yielded no batches")
    xRe = torch.zeros(numSamples, *xShape)
    yClean = torch.zeros(numSamples).type(torch.LongTensor)
    # Per-patch mask value repeated p*p*3 times -> (N, H*W, p*p*3), matching unpatchify's input.
    patchDim = model.module.config.patch_size ** 2 * 3

    tracker = 0
    advSampleIndex = 0
    # Inference only: no autograd graph is needed (all outputs are detached anyway).
    with torch.no_grad():
        for x_batch, y_batch in data_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            batch_size = len(x_batch)
            # In MAE, 0 is keep and 1 is remove.
            preset = 1 - salient_index[tracker:tracker + batch_size]
            tracker += batch_size

            outputs = model(x_batch, random_seed=random_seed, preset_mask=preset)
            # outputs[1] ("logits") holds per-patch predictions; unpatchify turns them into images.
            recon = model.module.unpatchify(outputs[1].detach().cpu())
            # outputs[2] is the per-patch mask: 1 is removing, 0 is keeping.
            mask = outputs[2].detach().cpu().unsqueeze(-1).repeat(1, 1, patchDim)
            mask = model.module.unpatchify(mask)
            # MAE reconstruction pasted with the visible (kept) patches.
            im_paste = torch.clamp(x_batch.detach().cpu() * (1 - mask) + recon * mask, 0, 1)

            # Never write past the preallocated buffers.
            count = min(batch_size, numSamples - advSampleIndex)
            xRe[advSampleIndex:advSampleIndex + count] = im_paste[:count]
            yClean[advSampleIndex:advSampleIndex + count] = y_batch[:count].cpu()
            advSampleIndex += count
            del outputs, recon, mask, im_paste
            if advSampleIndex == numSamples:
                break

    return TensorToDataLoader(xRe, yClean, transforms=None, batchSize=batchSize, randomizer=None)