import torch
import torch.nn as nn
from timm.models import create_model


class ModifiedDeiT(nn.Module):
    """DeiT-Tiny wrapper that returns a prediction after every transformer
    block (early-exit style), with all exits sharing the fine-tuned head."""

    # Fine-tuned checkpoints, keyed by number of classes:
    # CIFAR-10 (10), CIFAR-100 (100), Tiny-ImageNet (200).
    WEIGHT_PATHS = {
        10: '../results/BlackBox/cifar10_deit_weights.pth',
        100: '../results/BlackBox/cifar100_deit_weights.pth',
        200: '../results/BlackBox/tiny_deit_weights.pth',
    }

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        if num_classes not in self.WEIGHT_PATHS:
            raise ValueError(f'Unsupported num_classes: {num_classes}')
        self.deit = create_model('deit_tiny_patch16_224', pretrained=pretrained)
        # Replace the ImageNet head with one sized for the target dataset.
        self.n_inputs = self.deit.head.in_features
        self.deit.head = nn.Linear(self.n_inputs, num_classes)
        # Load on CPU first so the model also loads on GPU-less machines.
        state_dict = torch.load(self.WEIGHT_PATHS[num_classes], map_location='cpu')
        self.deit.load_state_dict(state_dict)

    def forward(self, x):
        # Forward through DeiT up to the transformer blocks.
        x = self.deit.patch_embed(x)
        cls_tokens = self.deit.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.deit.pos_embed
        x = self.deit.pos_drop(x)

        # Store the output of the shared head after each block; the class
        # token is fed to the head directly, without the final LayerNorm.
        outputs = []
        for blk in self.deit.blocks:
            x = blk(x)
            cls_token_output = x[:, 0, :]
            out = self.deit.head(cls_token_output)
            outputs.append(out)
        return outputs
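

# Minimal usage sketch (not part of the original module): builds the
# CIFAR-10 variant and runs a dummy batch to show the per-block outputs.
# It assumes the checkpoint at ../results/BlackBox/cifar10_deit_weights.pth
# exists; inputs must be 224x224 to match deit_tiny_patch16_224.
if __name__ == '__main__':
    model = ModifiedDeiT(num_classes=10)
    model.eval()
    dummy = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
    with torch.no_grad():
        exits = model(dummy)
    # One logits tensor per transformer block (12 for DeiT-Tiny),
    # each of shape (2, 10).
    print(len(exits), exits[0].shape)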