from functools import partial
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import reduce, rearrange
from timm.models.registry import register_model
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models import create_model


class TransformerHead(nn.Module):
    expansion = 1

    def __init__(self, token_dim, num_patches=196, num_classes=1000, stride=1):
        super(TransformerHead, self).__init__()

        self.token_dim = token_dim
        self.num_patches = num_patches
        self.num_classes = num_classes

        # Residual block (ResNet BasicBlock style) that processes the patch tokens
        # as a 2D feature map. The two convolutions need distinct modules; reusing
        # one conv/bn pair would discard the stride argument and share BatchNorm
        # statistics across both stages.
        self.conv1 = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.token_dim)
        self.conv2 = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.token_dim)

        self.shortcut = nn.Sequential()
        if stride != 1 or self.token_dim != self.expansion * self.token_dim:
            self.shortcut = nn.Sequential(
                nn.Conv2d(self.token_dim, self.expansion * self.token_dim, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * self.token_dim)
            )

        # Linear projection applied to the class token before fusion.
        self.token_fc = nn.Linear(self.token_dim, self.token_dim)

    def forward(self, x):
        """
        x : (B, num_patches + 1, D) -> (B, D)
        """
        cls_token, patch_tokens = x[:, 0], x[:, 1:]
        size = int(math.sqrt(patch_tokens.shape[1]))
        patch_tokens = rearrange(patch_tokens, 'b (h w) d -> b d h w', h=size, w=size)  # B, D, H, W

        features = F.relu(self.bn1(self.conv1(patch_tokens)))
        features = self.bn2(self.conv2(features))
        features += self.shortcut(patch_tokens)
        features = F.relu(features)

        # Global average pooling over the spatial grid, then fuse with the class token.
        patch_tokens = F.avg_pool2d(features, size).view(-1, self.token_dim)
        cls_token = self.token_fc(cls_token)

        out = patch_tokens + cls_token
        return out


class VisionTransformer_hierarchical(VisionTransformer):
    def __init__(self, num_classes, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Backbone: DeiT-tiny fine-tuned on the target dataset, selected by class count.
        if num_classes == 100:    # CIFAR-100
            self.deit = create_model('deit_tiny_patch16_224', pretrained=True)
            self.n_inputs = self.deit.head.in_features
            self.deit.head = torch.nn.Linear(self.n_inputs, num_classes)
            self.deit.load_state_dict(torch.load('../results/BlackBox/cifar100_deit_weights.pth'))
        elif num_classes == 10:   # CIFAR-10
            self.deit = create_model('deit_tiny_patch16_224', pretrained=True)
            self.n_inputs = self.deit.head.in_features
            self.deit.head = torch.nn.Linear(self.n_inputs, num_classes)
            self.deit.load_state_dict(torch.load('../results/BlackBox/cifar10_deit_weights.pth'))
        elif num_classes == 200:  # Tiny-ImageNet
            self.deit = create_model('deit_tiny_patch16_224', pretrained=True)
            self.n_inputs = self.deit.head.in_features
            self.deit.head = torch.nn.Linear(self.n_inputs, num_classes)
            self.deit.load_state_dict(torch.load('../results/BlackBox/tiny_deit_weights.pth'))

        # One intermediate TransformerHead per block 0-10 (the final block uses the DeiT head).
        self.transformerheads = nn.Sequential(*[
            TransformerHead(self.deit.embed_dim)
            for i in range(11)])

    def forward_features(self, x):
        B, nc, w, h = x.shape
        x = self.deit.patch_embed(x)

        cls_tokens = self.deit.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.deit.pos_embed
        x = self.deit.pos_drop(x)

        # Store transformer head outputs for the first 11 blocks
        transformerheads_outputs = []

        for idx, blk in enumerate(self.deit.blocks):
            x = blk(x)
            if idx <= 10:
                out = self.deit.norm(x)
                out = self.transformerheads[idx](out)
                transformerheads_outputs.append(out)

        x = self.deit.norm(x)
        return x, transformerheads_outputs

    def forward(self, x):
        x, transformerheads_outputs = self.forward_features(x)
        output = []
        for y in transformerheads_outputs:
            output.append(self.deit.head(y))
        output.append(self.deit.head(x[:, 0]))
        return output
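

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of building the hierarchical model and running a forward pass.
# It assumes the fine-tuned DeiT checkpoint referenced above is present on disk and
# that the default timm VisionTransformer arguments are acceptable for the (unused)
# parent backbone. The forward pass returns a list of 12 logit tensors: one per
# intermediate TransformerHead (blocks 0-10) plus the final DeiT head.
if __name__ == '__main__':
    model = VisionTransformer_hierarchical(num_classes=100)  # loads cifar100_deit_weights.pth
    model.eval()
    with torch.no_grad():
        logits_list = model(torch.randn(1, 3, 224, 224))
    for i, logits in enumerate(logits_list):
        print(f'head {i}: {tuple(logits.shape)}')  # each entry is (1, 100)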