import torch
import torch.nn as nn
from timm.models import create_model
class ModifiedDeiT(nn.Module):
    """DeiT-Tiny that returns per-block logits from the [CLS] token."""

    # Fine-tuned checkpoints, keyed by the number of classes in each dataset.
    CHECKPOINTS = {
        10: '../results/BlackBox/cifar10_deit_weights.pth',    # CIFAR-10
        100: '../results/BlackBox/cifar100_deit_weights.pth',  # CIFAR-100
        200: '../results/BlackBox/tiny_deit_weights.pth',      # Tiny-ImageNet
    }

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        if num_classes not in self.CHECKPOINTS:
            raise ValueError(f'Unsupported num_classes: {num_classes}')
        self.deit = create_model('deit_tiny_patch16_224', pretrained=pretrained)
        # Swap the ImageNet head for one sized to the target dataset.
        self.n_inputs = self.deit.head.in_features
        self.deit.head = nn.Linear(self.n_inputs, num_classes)
        # Load the fine-tuned weights; map_location keeps loading device-agnostic.
        self.deit.load_state_dict(
            torch.load(self.CHECKPOINTS[num_classes], map_location='cpu'))
    def forward(self, x):
        # Forward through DeiT up to the transformer blocks.
        x = self.deit.patch_embed(x)
        cls_tokens = self.deit.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.deit.pos_embed
        x = self.deit.pos_drop(x)
        # Collect one set of logits after every transformer block.
        outputs = []
        for blk in self.deit.blocks:
            x = blk(x)
            # Classify the [CLS] token. timm's DeiT applies a final
            # LayerNorm before the head, so mirror that here.
            cls_token_output = self.deit.norm(x)[:, 0, :]
            # Alternative design: a separate classifier per block
            # (self.classifiers[i]) instead of the shared head.
            out = self.deit.head(cls_token_output)
            outputs.append(out)
        return outputs
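

# Minimal usage sketch: assumes the checkpoint paths above exist and that
# inputs follow the standard 224x224 DeiT pipeline; the batch size is
# illustrative.
if __name__ == '__main__':
    model = ModifiedDeiT(num_classes=100).eval()
    dummy = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width)
    with torch.no_grad():
        per_block_logits = model(dummy)
    # One logits tensor per transformer block: 12 tensors of shape (2, 100)
    # for deit_tiny_patch16_224.
    print(len(per_block_logits), per_block_logits[0].shape)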