# ViTGuard/target_models/run/source_model/deit_ensemble.py
import torch
import torch.nn as nn
from timm.models import create_model


class ModifiedDeiT(nn.Module):
    """DeiT-Tiny whose shared classification head is applied to the CLS token
    after every transformer block, producing one prediction per block."""

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        # Map the number of classes to the fine-tuned checkpoint for each dataset.
        weight_paths = {
            100: '../results/BlackBox/cifar100_deit_weights.pth',  # CIFAR-100
            10: '../results/BlackBox/cifar10_deit_weights.pth',    # CIFAR-10
            200: '../results/BlackBox/tiny_deit_weights.pth',      # Tiny-ImageNet
        }
        if num_classes not in weight_paths:
            raise ValueError(f'Unsupported num_classes: {num_classes}')

        self.deit = create_model('deit_tiny_patch16_224', pretrained=pretrained)
        # Replace the ImageNet head with one sized for the target dataset,
        # then load the fine-tuned weights.
        self.n_inputs = self.deit.head.in_features
        self.deit.head = nn.Linear(self.n_inputs, num_classes)
        self.deit.load_state_dict(torch.load(weight_paths[num_classes]))
        
    def forward(self, x):
        # Forward through DeiT up to the transformer blocks
        x = self.deit.patch_embed(x)
        cls_tokens = self.deit.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.deit.pos_embed
        x = self.deit.pos_drop(x)
        
        # Collect one prediction per transformer block
        outputs = []

        for blk in self.deit.blocks:
            x = blk(x)
            # Apply the shared classification head to this block's CLS token
            cls_token_output = x[:, 0, :]
            out = self.deit.head(cls_token_output)
            outputs.append(out)

        return outputs
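

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): assumes the
    # CIFAR-10 checkpoint referenced above exists at
    # ../results/BlackBox/cifar10_deit_weights.pth.
    model = ModifiedDeiT(num_classes=10)
    model.eval()

    # DeiT-Tiny expects 224x224 RGB inputs.
    dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        per_block_logits = model(dummy)

    # One logit tensor per transformer block (12 for deit_tiny_patch16_224),
    # each of shape (batch, num_classes).
    print(len(per_block_logits), per_block_logits[0].shape)

    # One possible ensemble readout (an assumption, not prescribed by the
    # source): average the per-block logits, then take the argmax.
    ensemble_logits = torch.stack(per_block_logits).mean(dim=0)
    preds = ensemble_logits.argmax(dim=1)
    print(preds)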