import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models import create_model
from timm.models.vision_transformer import VisionTransformer

class TransformerHead(nn.Module):
    expansion = 1

    def __init__(self, token_dim, num_patches=196, num_classes=1000, stride=1):
        super(TransformerHead, self).__init__()

        self.token_dim = token_dim
        self.num_patches = num_patches
        self.num_classes = num_classes

        # Residual block over the patch tokens (reshaped into a 2-D feature map):
        # two 3x3 conv+BN layers, with the stride applied in the first conv
        self.conv1 = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.token_dim)
        self.conv2 = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.token_dim)

        # Projection shortcut, needed only when the stride changes the spatial size
        # (expansion is 1, so the channel condition never triggers)
        self.shortcut = nn.Sequential()
        if stride != 1 or self.token_dim != self.expansion * self.token_dim:
            self.shortcut = nn.Sequential(
                nn.Conv2d(self.token_dim, self.expansion * self.token_dim, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * self.token_dim)
            )

        # Linear projection applied to the CLS token before fusing it with the pooled patch features
        self.token_fc = nn.Linear(self.token_dim, self.token_dim)

    def forward(self, x):
        """
            x : (B, num_patches + 1, D) -> (B, D)
        """
        cls_token, patch_tokens = x[:, 0], x[:, 1:]
        size = int(math.sqrt(patch_tokens.shape[1]))  # 14 for 196 patch tokens

        # Reshape the patch tokens into a square feature map and apply the residual block
        patch_tokens = rearrange(patch_tokens, 'b (h w) d -> b d h w', h=size, w=size)  # B, D, H, W
        features = F.relu(self.bn1(self.conv1(patch_tokens)))
        features = self.bn2(self.conv2(features))
        features = features + self.shortcut(patch_tokens)
        features = F.relu(features)

        # Global average pooling over the spatial map, then fuse with the projected CLS token
        patch_tokens = F.adaptive_avg_pool2d(features, 1).view(-1, self.token_dim)
        cls_token = self.token_fc(cls_token)

        out = patch_tokens + cls_token

        return out
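
# A minimal shape sketch (an illustrative note, not part of the original pipeline):
# for DeiT-tiny the token dimension is 192 and a 224x224 input yields 14x14 = 196
# patch tokens, so a TransformerHead maps a (B, 197, 192) token sequence to a
# (B, 192) feature vector that is later passed through the shared classification head:
#
#   head = TransformerHead(token_dim=192)
#   tokens = torch.randn(2, 197, 192)   # CLS token + 196 patch tokens
#   feats = head(tokens)                # -> torch.Size([2, 192])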

class VisionTransformer_hierarchical(VisionTransformer):
    def __init__(self, num_classes, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
        # Backbone: DeiT-tiny pretrained on ImageNet, with the classification head
        # replaced and fine-tuned weights loaded for the target dataset
        checkpoints = {
            10:  '../results/BlackBox/cifar10_deit_weights.pth',   # CIFAR-10
            100: '../results/BlackBox/cifar100_deit_weights.pth',  # CIFAR-100
            200: '../results/BlackBox/tiny_deit_weights.pth',      # Tiny-ImageNet
        }
        self.deit = create_model('deit_tiny_patch16_224', pretrained=True)
        self.n_inputs = self.deit.head.in_features
        self.deit.head = torch.nn.Linear(self.n_inputs, num_classes)
        if num_classes in checkpoints:
            self.deit.load_state_dict(torch.load(checkpoints[num_classes]))
            
        # One auxiliary TransformerHead per intermediate block (blocks 0-10 of DeiT-tiny's 12 blocks)
        self.transformerheads = nn.Sequential(*[
            TransformerHead(self.deit.embed_dim)
            for _ in range(11)])

    def forward_features(self, x):
        B, nc, w, h = x.shape
        x = self.deit.patch_embed(x)

        cls_tokens = self.deit.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.deit.pos_embed
        x = self.deit.pos_drop(x)

        # Store transformer outputs
        transformerheads_outputs = []

        for idx, blk in enumerate(self.deit.blocks):
            x = blk(x)
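            # The first 11 blocks (idx 0-10) each feed an auxiliary TransformerHead
            # after the final LayerNorm; the last block's output is classified only
            # through the standard CLS-token path below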
            if idx <= 10:
                out = self.deit.norm(x)
                out = self.transformerheads[idx](out)
                transformerheads_outputs.append(out)

        x = self.deit.norm(x)
        return x, transformerheads_outputs

    def forward(self, x):
        x, transformerheads_outputs = self.forward_features(x)
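        # One set of logits per auxiliary head plus the final CLS-token logits,
        # all produced by the same classification layer self.deit.head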
        output = []
        for y in transformerheads_outputs:
            output.append(self.deit.head(y))
        output.append(self.deit.head(x[:, 0]))
        return output
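

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of ViTGuard's training or attack
    # scripts): it assumes the fine-tuned DeiT checkpoint exists at
    # ../results/BlackBox/cifar10_deit_weights.pth. The forward pass returns a
    # list of 12 logit tensors: one per auxiliary head (blocks 0-10) plus the
    # final CLS-token prediction.
    model = VisionTransformer_hierarchical(num_classes=10).eval()
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        outputs = model(dummy)
    print(len(outputs), outputs[0].shape)  # expected: 12 torch.Size([1, 10])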