from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
import math
from os.path import join as pjoin
import torch
import torch.nn as nn
import numpy as np
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
import TransformerConfigs_pretrain as configs
# from TransformerResNet import ResNetV2
logger = logging.getLogger(__name__)
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu}
class ViTEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
"""
def __init__(self, config):
super().__init__()
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.patch_embeddings = ViTPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values)
# add the [CLS] token to the embedded patch tokens
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
# add positional encoding to each token
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings #[N, #tokens, hidden_size]
class ViTPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = (image_size, image_size)
patch_size = (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
f" Expected {self.num_channels} but got {num_channels}."
)
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return embeddings
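# Shape example (illustrative only; the actual values come from the config in
# TransformerConfigs_pretrain): with image_size=224, patch_size=16 and hidden_size=768,
# the Conv2d projection produces a 14x14 grid of patches, i.e.
# num_patches = (224 // 16) ** 2 = 196, so ViTPatchEmbeddings maps
# (N, 3, 224, 224) -> (N, 196, 768) and ViTEmbeddings prepends the [CLS] token
# to give the (N, 197, 768) sequence consumed by the encoder.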
class ViTSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, output_attentions = False):
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # Keep the raw (pre-softmax) attention scores so they can optionally be returned below.
        attention_probs_presoftmax = attention_scores
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
attention_probs_presoftmax = self.dropout(attention_probs_presoftmax)
# print(attention_probs.shape) #[64, 12, 197, 197]
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
# outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
outputs = (context_layer, attention_probs_presoftmax) if output_attentions else (context_layer,)
        return outputs  # context_layer: [N, #tokens, all_head_size]
                        # attention scores: [N, num_attention_heads, #tokens, #tokens]
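# The computation above is standard scaled dot-product attention,
#     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_head)) V,
# applied independently per head, with d_head = hidden_size / num_attention_heads.
# Note that when output_attentions=True this module returns the *pre-softmax*
# scores (after dropout), not the normalized attention probabilities.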
class ViTSelfOutput(nn.Module):
"""
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class ViTAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = ViTSelfAttention(config)
self.output = ViTSelfOutput(config)
self.pruned_heads = set()
def forward(self, hidden_states, output_attentions = False):
self_outputs = self.attention(hidden_states, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class ViTIntermediate(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ViTOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class ViTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config):
super().__init__()
self.attention = ViTAttention(config)
self.intermediate = ViTIntermediate(config)
self.output = ViTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor, output_attentions = False):
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
# first residual connection
hidden_states = attention_output + hidden_states
# in ViT, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
# second residual connection is done here
layer_output = self.output(layer_output, hidden_states)
outputs = (layer_output,) + outputs
return outputs
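# ViTLayer implements a pre-norm transformer block; each forward pass computes
#     x = x + Attention(LayerNorm(x))        # first residual connection
#     x = x + MLP(LayerNorm(x))              # second residual connection
# where the MLP is ViTIntermediate (expansion + activation) followed by ViTOutput
# (projection + dropout + residual add).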
class ViTEncoder(nn.Module):
    '''
    A stack of `config.num_hidden_layers` ViTLayer blocks applied sequentially.
    '''
def __init__(self, config) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(self, hidden_states, output_attentions = False, output_hidden_states = False, return_dict = True):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # Gradient checkpointing (torch.utils.checkpoint) is intentionally not used here;
            # the layer is always called directly so that `layer_outputs` is defined on every
            # iteration, regardless of whether hidden states are being collected.
            layer_outputs = layer_module(hidden_states, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states: #skip
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict: #skip
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
# return BaseModelOutput(
# last_hidden_state=hidden_states,
# hidden_states=all_hidden_states,
# attentions=all_self_attentions,
# )
return (hidden_states, all_hidden_states, all_self_attentions)
#all_self_attentions: tuple(#layers)
#each attention: [N, #heads, #tokens, #tokens]
class ViTPooler(nn.Module):
    '''
    Pool the sequence by taking the hidden state of the [CLS] token and passing it
    through a Linear layer followed by a Tanh activation.
    '''
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output #[N, hidden_size]
class ViTModel_custom(nn.Module):
    def __init__(self, config, add_pooling_layer=True, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
super().__init__()
self.config = config
self.embeddings = ViTEmbeddings(config)
self.encoder = ViTEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = ViTPooler(config) if add_pooling_layer else None
self.matrixMean = torch.ones(3, config.image_size, config.image_size)
self.matrixMean[0] = self.matrixMean[0]*mean[0] # [224, 224]
self.matrixMean[1] = self.matrixMean[1]*mean[1]
self.matrixMean[2] = self.matrixMean[2]*mean[2]
self.matrixStd = torch.ones(3, config.image_size, config.image_size)
self.matrixStd[0] = self.matrixStd[0]*std[0]
self.matrixStd[1] = self.matrixStd[1]*std[1]
self.matrixStd[2] = self.matrixStd[2]*std[2]
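        # matrixMean/matrixStd implement per-channel normalization, pixel = (pixel - mean) / std,
        # equivalent to torchvision's Normalize(mean, std); with the default 0.5/0.5 values,
        # inputs in [0, 1] are mapped to [-1, 1]. They are plain tensors (not registered buffers),
        # so forward() moves them to the input's device explicitly.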
# self.matrixMean= self.matrixMean.cuda()
# self.matrixStd = self.matrixStd.cuda()
# # Initialize weights and apply final processing
# self.post_init()
def get_input_embeddings(self) -> ViTPatchEmbeddings:
return self.embeddings.patch_embeddings
def forward(
self,
pixel_values,
output_attentions = False,
output_hidden_states = False,
return_dict = True):
"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# output_hidden_states = (
# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
# )
# return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
# expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
# if pixel_values.dtype != expected_dtype:
# pixel_values = pixel_values.to(expected_dtype)
device_temp = pixel_values.device
self.matrixMean= self.matrixMean.to(device_temp)
self.matrixStd = self.matrixStd.to(device_temp)
pixel_values = (pixel_values-self.matrixMean)/self.matrixStd
embedding_output = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
embedding_output,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
return head_outputs + encoder_outputs[1:]
# return BaseModelOutputWithPooling(
# last_hidden_state=sequence_output,
# pooler_output=pooled_output,
# hidden_states=encoder_outputs.hidden_states,
# attentions=encoder_outputs.attentions,
# )
# return (sequence_output, pooled_output, encoder_outputs.hidden_states, encoder_outputs.attentions)
        if output_attentions:
return (sequence_output, pooled_output, encoder_outputs[-1])
else:
return (sequence_output, pooled_output)
#sequence_output:[N, #tokens, hidden_size]
#pooled_output:[N, hidden_size] / CLS
#encoder_outputs[-1]: attention map
# CONFIGS = {
# 'ViT-B_16': configs.get_b16_config(),
# 'ViT-B_32': configs.get_b32_config(),
# 'ViT-L_16': configs.get_l16_config(),
# 'ViT-L_32': configs.get_l32_config(),
# 'ViT-H_14': configs.get_h14_config(),
# 'R50-ViT-B_16': configs.get_r50_b16_config(),
# 'testing': configs.get_testing(),
# }
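# Minimal smoke-test sketch (assumption: the concrete hyperparameter values below are
# illustrative ViT-B/16-style defaults, not read from TransformerConfigs_pretrain;
# swap in e.g. configs.get_b16_config() if that helper is available, as the commented
# CONFIGS mapping above suggests).
if __name__ == "__main__":
    from types import SimpleNamespace

    # Only the attributes actually read by the modules above are provided.
    demo_config = SimpleNamespace(
        image_size=224,
        patch_size=16,
        num_channels=3,
        hidden_size=768,
        num_hidden_layers=2,          # kept small so the sketch runs quickly
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        qkv_bias=True,
        layer_norm_eps=1e-12,
    )
    model = ViTModel_custom(demo_config)
    model.eval()
    with torch.no_grad():
        pixel_values = torch.rand(2, 3, 224, 224)  # values in [0, 1]; normalized inside forward()
        sequence_output, pooled_output, attentions = model(pixel_values, output_attentions=True)
    print(sequence_output.shape)                 # torch.Size([2, 197, 768])
    print(pooled_output.shape)                   # torch.Size([2, 768])
    print(len(attentions), attentions[0].shape)  # 2 torch.Size([2, 12, 197, 197])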