from functools import partial
import torch
import torch.nn as nn
import math
from einops import reduce, rearrange
from timm.models.registry import register_model
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models import create_model
import torch.nn.functional as F
class TransformerHead(nn.Module):
expansion = 1
def __init__(self, token_dim, num_patches=196, num_classes=1000, stride=1):
super(TransformerHead, self).__init__()
self.token_dim = token_dim
self.num_patches = num_patches
self.num_classes = num_classes
# To process patches
self.conv = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn = nn.BatchNorm2d(self.token_dim)
self.conv = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=1, padding=1, bias=False)
self.bn = nn.BatchNorm2d(self.token_dim)
self.shortcut = nn.Sequential()
if stride != 1 or self.token_dim != self.expansion * self.token_dim:
self.shortcut = nn.Sequential(
nn.Conv2d(self.token_dim, self.expansion * self.token_dim, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * self.token_dim)
)
self.token_fc = nn.Linear(self.token_dim, self.token_dim)
def forward(self, x):
"""
x : (B, num_patches + 1, D) -> (B, C=num_classes)
"""
cls_token, patch_tokens = x[:, 0], x[:, 1:]
size = int(math.sqrt(x.shape[1]))
patch_tokens = rearrange(patch_tokens, 'b (h w) d -> b d h w', h=size, w=size) # B, D, H, W
features = F.relu(self.bn(self.conv(patch_tokens)))
features = self.bn(self.conv(features))
features += self.shortcut(patch_tokens)
features = F.relu(features)
patch_tokens = F.avg_pool2d(features, 14).view(-1, self.token_dim)
cls_token = self.token_fc(cls_token)
out = patch_tokens + cls_token
return out
class VisionTransformer_hierarchical(VisionTransformer):
def __init__(self, num_classes, *args, **kwargs):
super().__init__(*args, **kwargs)
if num_classes == 100: #cifar100
self.deit = create_model('deit_tiny_patch16_224', pretrained=True)
self.n_inputs = self.deit.head.in_features
self.deit.head = torch.nn.Linear(self.n_inputs, num_classes)
self.deit.load_state_dict(torch.load('../results/BlackBox/cifar100_deit_weights.pth'))
elif num_classes == 10:
self.deit = create_model('deit_tiny_patch16_224', pretrained=True)
self.n_inputs = self.deit.head.in_features
self.deit.head = torch.nn.Linear(self.n_inputs, num_classes)
self.deit.load_state_dict(torch.load('../results/BlackBox/cifar10_deit_weights.pth'))
elif num_classes == 200:
self.deit = create_model('deit_tiny_patch16_224', pretrained=True)
self.n_inputs = self.deit.head.in_features
self.deit.head = torch.nn.Linear(self.n_inputs, num_classes)
self.deit.load_state_dict(torch.load('../results/BlackBox/tiny_deit_weights.pth'))
# Transformer heads
self.transformerheads = nn.Sequential(*[
TransformerHead(self.deit.embed_dim)
for i in range(11)])
def forward_features(self, x):
B, nc, w, h = x.shape
x = self.deit.patch_embed(x)
cls_tokens = self.deit.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.deit.pos_embed
x = self.deit.pos_drop(x)
# Store transformer outputs
transformerheads_outputs = []
for idx, blk in enumerate(self.deit.blocks):
x = blk(x)
if idx <= 10:
out = self.deit.norm(x)
out = self.transformerheads[idx](out)
transformerheads_outputs.append(out)
x = self.deit.norm(x)
return x, transformerheads_outputs
def forward(self, x):
x, transformerheads_outputs = self.forward_features(x)
output = []
for y in transformerheads_outputs:
output.append(self.deit.head(y))
output.append(self.deit.head(x[:, 0]))
return output