from __future__ import absolute_import from __future__ import division from __future__ import print_function import copy import logging import math from os.path import join as pjoin import torch import torch.nn as nn import numpy as np from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm, Softmax from torch.nn.modules.utils import _pair from scipy import ndimage import TransformerConfigs_pretrain as configs # from TransformerResNet import ResNetV2 import math logger = logging.getLogger(__name__) ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu} class ViTEmbeddings(nn.Module): """ Construct the CLS token, position and patch embeddings. Optionally, also the mask token. """ def __init__(self, config): super().__init__() self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.patch_embeddings = ViTPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.config = config def forward(self, pixel_values): batch_size, num_channels, height, width = pixel_values.shape embeddings = self.patch_embeddings(pixel_values) # add the [CLS] token to the embedded patch tokens cls_tokens = self.cls_token.expand(batch_size, -1, -1) embeddings = torch.cat((cls_tokens, embeddings), dim=1) # add positional encoding to each token embeddings = embeddings + self.position_embeddings embeddings = self.dropout(embeddings) return embeddings #[N, #tokens, hidden_size] class ViTPatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a Transformer. """ def __init__(self, config): super().__init__() image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size image_size = (image_size, image_size) patch_size = (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values): batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." f" Expected {self.num_channels} but got {num_channels}." ) if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size[0]}*{self.image_size[1]})." ) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) return embeddings class ViTSelfAttention(nn.Module): def __init__(self, config): super().__init__() self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward(self, hidden_states, output_attentions = False): key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(self.query(hidden_states)) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) # Normalize the attention scores to probabilities. attention_probs = nn.functional.softmax(attention_scores, dim=-1) attention_probs_presoftmax = attention_scores # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) attention_probs_presoftmax = self.dropout(attention_probs_presoftmax) # print(attention_probs.shape) #[64, 12, 197, 197] context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) # outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) outputs = (context_layer, attention_probs_presoftmax) if output_attentions else (context_layer,) return outputs #context_layer:[N, #tokens, #hidden_size] #attention_probs:[N, #heads, #tokens, #tokens] #heads??? class ViTSelfOutput(nn.Module): """ The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the layernorm applied before each block. """ def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states class ViTAttention(nn.Module): def __init__(self, config): super().__init__() self.attention = ViTSelfAttention(config) self.output = ViTSelfOutput(config) self.pruned_heads = set() def forward(self, hidden_states, output_attentions = False): self_outputs = self.attention(hidden_states, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs class ViTIntermediate(nn.Module): def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class ViTOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = hidden_states + input_tensor return hidden_states class ViTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config): super().__init__() self.attention = ViTAttention(config) self.intermediate = ViTIntermediate(config) self.output = ViTOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states: torch.Tensor, output_attentions = False): self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention output_attentions=output_attentions, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights # first residual connection hidden_states = attention_output + hidden_states # in ViT, layernorm is also applied after self-attention layer_output = self.layernorm_after(hidden_states) layer_output = self.intermediate(layer_output) # second residual connection is done here layer_output = self.output(layer_output, hidden_states) outputs = (layer_output,) + outputs return outputs class ViTEncoder(nn.Module): ''' multiple of ViTlayer ''' def __init__(self, config) -> None: super().__init__() self.config = config self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward(self, hidden_states, output_attentions = False, output_hidden_states = False, return_dict = True): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): if output_hidden_states: #skip all_hidden_states = all_hidden_states + (hidden_states,) # if self.gradient_checkpointing and self.training: # def create_custom_forward(module): # def custom_forward(*inputs): # return module(*inputs, output_attentions) # return custom_forward # layer_outputs = torch.utils.checkpoint.checkpoint( # create_custom_forward(layer_module), # hidden_states, # layer_head_mask, # ) else: layer_outputs = layer_module(hidden_states, output_attentions) hidden_states = layer_outputs[0] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if output_hidden_states: #skip all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: #skip return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) # return BaseModelOutput( # last_hidden_state=hidden_states, # hidden_states=all_hidden_states, # attentions=all_self_attentions, # ) return (hidden_states, all_hidden_states, all_self_attentions) #all_self_attentions: tuple(#layers) #each attention: [N, #heads, #tokens, #tokens] class ViTPooler(nn.Module): ''' extract the CLS ''' def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output #[N, hidden_size] class ViTModel_custom(nn.Module): def __init__(self, config, add_pooling_layer=True, mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]): super().__init__() self.config = config self.embeddings = ViTEmbeddings(config) self.encoder = ViTEncoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pooler = ViTPooler(config) if add_pooling_layer else None self.matrixMean = torch.ones(3, config.image_size, config.image_size) self.matrixMean[0] = self.matrixMean[0]*mean[0] # [224, 224] self.matrixMean[1] = self.matrixMean[1]*mean[1] self.matrixMean[2] = self.matrixMean[2]*mean[2] self.matrixStd = torch.ones(3, config.image_size, config.image_size) self.matrixStd[0] = self.matrixStd[0]*std[0] self.matrixStd[1] = self.matrixStd[1]*std[1] self.matrixStd[2] = self.matrixStd[2]*std[2] # self.matrixMean= self.matrixMean.cuda() # self.matrixStd = self.matrixStd.cuda() # # Initialize weights and apply final processing # self.post_init() def get_input_embeddings(self) -> ViTPatchEmbeddings: return self.embeddings.patch_embeddings def forward( self, pixel_values, output_attentions = False, output_hidden_states = False, return_dict = True): """ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). """ # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # output_hidden_states = ( # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # ) # return_dict = return_dict if return_dict is not None else self.config.use_return_dict # # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) # expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype # if pixel_values.dtype != expected_dtype: # pixel_values = pixel_values.to(expected_dtype) device_temp = pixel_values.device self.matrixMean= self.matrixMean.to(device_temp) self.matrixStd = self.matrixStd.to(device_temp) pixel_values = (pixel_values-self.matrixMean)/self.matrixStd embedding_output = self.embeddings(pixel_values) encoder_outputs = self.encoder( embedding_output, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) return head_outputs + encoder_outputs[1:] # return BaseModelOutputWithPooling( # last_hidden_state=sequence_output, # pooler_output=pooled_output, # hidden_states=encoder_outputs.hidden_states, # attentions=encoder_outputs.attentions, # ) # return (sequence_output, pooled_output, encoder_outputs.hidden_states, encoder_outputs.attentions) if output_attentions==True: return (sequence_output, pooled_output, encoder_outputs[-1]) else: return (sequence_output, pooled_output) #sequence_output:[N, #tokens, hidden_size] #pooled_output:[N, hidden_size] / CLS #encoder_outputs[-1]: attention map # CONFIGS = { # 'ViT-B_16': configs.get_b16_config(), # 'ViT-B_32': configs.get_b32_config(), # 'ViT-L_16': configs.get_l16_config(), # 'ViT-L_32': configs.get_l32_config(), # 'ViT-H_14': configs.get_h14_config(), # 'R50-ViT-B_16': configs.get_r50_b16_config(), # 'testing': configs.get_testing(), # }