import torch


class MyDataSet(torch.utils.data.Dataset):
    """Map-style dataset pairing an input tensor with a target tensor.

    Args:
        x_tensor: Inputs, indexed along dim 0.
        y_tensor: Targets, indexed along dim 0 in lockstep with ``x_tensor``.
        transforms: Optional callable applied to each input sample (never to
            the target) when it is fetched.
    """

    def __init__(self, x_tensor, y_tensor, transforms=None):
        self.x = x_tensor
        self.y = y_tensor
        self.transforms = transforms

    def __getitem__(self, index):
        sample = self.x[index]
        if self.transforms is not None:
            sample = self.transforms(sample)
        return (sample, self.y[index])

    def __len__(self):
        # Length comes from the inputs; y is assumed to be the same length.
        return len(self.x)


def TensorToDataLoader(xData, yData, transforms=None, batchSize=None, randomizer=None):
    """Wrap paired x / y tensors in a ``DataLoader``.

    Args:
        xData: Input tensor, samples along dim 0.
        yData: Target tensor, samples along dim 0.
        transforms: Optional per-input transform forwarded to ``MyDataSet``.
        batchSize: Batch size. ``None`` puts the whole dataset in one batch.
        randomizer: Used only as a flag. When not ``None``, samples are drawn
            through a ``RandomSampler`` (shuffled order). The value itself is
            otherwise ignored.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``MyDataSet(xData, yData)``.
    """
    if batchSize is None:
        # One batch holding everything. Clamp to 1 so an empty tensor still
        # gives a valid (empty) loader; DataLoader rejects batch_size=0.
        batchSize = max(xData.shape[0], 1)
    dataset = MyDataSet(xData, yData, transforms)
    # sampler=None with shuffle=False falls back to sequential order.
    sampler = None if randomizer is None else torch.utils.data.RandomSampler(dataset)
    return torch.utils.data.DataLoader(
        dataset=dataset, batch_size=batchSize, sampler=sampler, shuffle=False
    )


def GetOutputShape(dataLoader):
    """Return the shape of one input sample from the first batch.

    Expects batches to be ``(inputs, targets)`` pairs. Returns ``None`` when
    the loader yields no batches.
    """
    for inputs, _ in dataLoader:
        return inputs[0].shape
    return None


def _unwrap(model):
    """Return ``model.module`` for wrapped models (e.g. DataParallel), else ``model``."""
    return getattr(model, "module", model)


def _allocate_buffers(data_loader):
    """Allocate zero-filled image and label buffers sized for the whole dataset.

    Raises:
        ValueError: if the loader yields no batches, so no sample shape exists.
    """
    num_samples = len(data_loader.dataset)
    sample_shape = GetOutputShape(data_loader)
    if sample_shape is None:
        raise ValueError("data_loader yielded no batches; cannot infer sample shape")
    images = torch.zeros(num_samples, *sample_shape)
    labels = torch.zeros(num_samples, dtype=torch.long)
    return images, labels


def _mae_pass(model, x_batch, random_seed, preset_mask):
    """Run one MAE forward pass and return ``(reconstruction, pixel_mask)`` on CPU.

    Relies on the model interface used throughout this file:
    ``model(x, random_seed=..., preset_mask=...)`` returns a sequence where
    ``outputs[1]`` is the patch logits and ``outputs[2]`` is the per-patch
    mask (1 = removed, 0 = kept). The per-patch mask is repeated over each
    patch's ``patch_size**2 * 3`` values so it can be unpatchified to image
    shape. The ``* 3`` assumes 3-channel images.
    """
    core = _unwrap(model)
    outputs = model(x_batch, random_seed=random_seed, preset_mask=preset_mask)
    recon = core.unpatchify(outputs[1].detach().cpu())
    mask = outputs[2].detach().cpu()
    mask = mask.unsqueeze(-1).repeat(1, 1, core.config.patch_size ** 2 * 3)
    return recon, core.unpatchify(mask)


def _store(images, labels, start, batch_images, batch_labels):
    """Copy a batch into the buffers starting at ``start``; return the next free index.

    Rows that would overflow the buffers are dropped.
    """
    count = min(len(batch_images), images.shape[0] - start)
    images[start:start + count] = batch_images[:count]
    labels[start:start + count] = batch_labels[:count]
    return start + count


def _reconstruct_and_paste(model, data_loader, device, random_seed, salient_index):
    """Shared loop: reconstruct masked patches and paste visible input patches back.

    When ``salient_index`` is None the model is called with
    ``preset_mask=False``. Otherwise the preset mask is
    ``1 - salient_index[batch rows]``. Rows of ``salient_index`` are consumed
    in loader order, so the loader must not shuffle.
    """
    torch.manual_seed(2)
    model.eval()
    images, labels = _allocate_buffers(data_loader)
    num_samples = images.shape[0]
    filled = 0
    tracker = 0
    with torch.no_grad():  # inference only; avoid building autograd graphs
        for x_batch, y_batch in data_loader:
            batch_len = len(x_batch)
            if salient_index is None:
                preset_mask = False
            else:
                # In MAE, 1 is removing and 0 is keeping, so invert the index.
                preset_mask = 1 - salient_index[tracker:tracker + batch_len]
            tracker += batch_len
            x_cpu = x_batch.detach().cpu()
            recon, mask = _mae_pass(model, x_batch.to(device), random_seed, preset_mask)
            # Visible input patches are kept; removed patches come from the reconstruction.
            pasted = torch.clamp(x_cpu * (1 - mask) + recon * mask, 0, 1)
            filled = _store(images, labels, filled, pasted, y_batch)
            if filled == num_samples:
                break
    return TensorToDataLoader(
        images, labels, transforms=None, batchSize=data_loader.batch_size, randomizer=None
    )


def get_reconstructed_dataset(model, data_loader, device, random_seed=None):
    """Build a dataset of MAE reconstructions with the visible patches pasted back.

    The model is called with ``preset_mask=False``. Presumably this lets it
    choose its own mask; that is not visible here.

    Side effects: calls ``torch.manual_seed(2)`` (resetting the global RNG)
    and puts ``model`` in eval mode.

    Args:
        model: MAE model (optionally wrapped, e.g. in DataParallel). See
            ``_mae_pass`` for the assumed call/return interface.
        data_loader: Loader yielding ``(images, labels)`` batches.
        device: Device the inputs are moved to for the forward pass.
        random_seed: Forwarded to the model unchanged.

    Returns:
        A non-shuffled DataLoader over the reconstructed images (clamped to
        [0, 1]) and their labels, using ``data_loader.batch_size``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    return _reconstruct_and_paste(model, data_loader, device, random_seed, None)


def recoverall(model, data_loader, device, salient_index, random_seed=None):
    """Apply the MAE twice so that every patch of the output is reconstructed.

    Pass 1 uses preset mask ``1 - salient_index`` and pass 2 uses
    ``salient_index``. Each pass keeps only the reconstruction on the patches
    its mask marks as removed (mask == 1), and the two results are summed.
    No original input pixels appear in the output. Assumes ``salient_index``
    holds per-patch 0/1 values whose rows align with the loader's
    (unshuffled) sample order.

    Side effects: calls ``torch.manual_seed(2)`` and puts ``model`` in eval mode.

    Returns:
        A non-shuffled DataLoader over the combined reconstructions (clamped
        to [0, 1]) and their labels, using ``data_loader.batch_size``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    torch.manual_seed(2)
    model.eval()
    images, labels = _allocate_buffers(data_loader)
    num_samples = images.shape[0]
    filled = 0
    tracker = 0
    with torch.no_grad():  # inference only; avoid building autograd graphs
        for x_batch, y_batch in data_loader:
            batch_len = len(x_batch)
            salient = salient_index[tracker:tracker + batch_len]
            tracker += batch_len
            x_dev = x_batch.to(device)
            # Pass 1: in MAE, 1 is removing, so this reconstructs where salient == 0.
            recon, mask = _mae_pass(model, x_dev, random_seed, 1 - salient)
            partial = torch.clamp(recon * mask, 0, 1)
            # Pass 2: reconstruct where salient == 1 and add it to pass 1.
            recon, mask = _mae_pass(model, x_dev, random_seed, salient)
            combined = torch.clamp(partial + recon * mask, 0, 1)
            filled = _store(images, labels, filled, combined, y_batch)
            if filled == num_samples:
                break
    return TensorToDataLoader(
        images, labels, transforms=None, batchSize=data_loader.batch_size, randomizer=None
    )


def get_reconstructed_dataset_salient(model, data_loader, device, salient_index, random_seed=None):
    """Like ``get_reconstructed_dataset`` but with a preset mask from ``salient_index``.

    The preset mask is ``1 - salient_index`` per batch (in MAE, 1 is removing
    and 0 is keeping). Patches where ``salient_index`` is 1 are therefore
    kept from the input, and the rest come from the reconstruction. Rows of
    ``salient_index`` are consumed in loader order, so the loader must not
    shuffle.

    Side effects: calls ``torch.manual_seed(2)`` and puts ``model`` in eval mode.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    return _reconstruct_and_paste(model, data_loader, device, random_seed, salient_index)