# import logging
import torch
from torch import nn
import numpy as np
import collections.abc
import math
from copy import deepcopy
# logger = logging.get_logger(__name__)
# Maps the activation names accepted as a string `config.hidden_act` to their functional implementation.
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu}
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Build fixed 2D sine/cosine position embeddings for a square grid of patches.

    Args:
        embed_dim (`int`):
            Embedding dimension.
        grid_size (`int`):
            Number of patches along each side of the grid.
        add_cls_token (`bool`, *optional*, defaults to `False`):
            Whether to prepend an all-zero row for the classification (CLS) token.

    Returns:
        `np.ndarray` of shape `(grid_size * grid_size, embed_dim)`, or
        `(1 + grid_size * grid_size, embed_dim)` when `add_cls_token` is set.
    """
    axis_coords = np.arange(grid_size, dtype=np.float32)
    # Same coordinates for both axes; with meshgrid's default indexing the width coordinate comes first.
    coords = np.meshgrid(axis_coords, axis_coords)
    grid = np.stack(coords, axis=0).reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)  # (grid_size**2, embed_dim)
    if not add_cls_token:
        return pos_embed

    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, pos_embed], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a stacked pair of coordinate grids into `embed_dim`-wide sine/cosine embeddings.

    The first half of the output channels encodes `grid[0]`, the second half encodes `grid[1]`.

    Raises:
        ValueError: if `embed_dim` is odd.
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be even")

    half_dim = embed_dim // 2
    halves = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[0]),  # (H*W, D/2)
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[1]),  # (H*W, D/2)
    ]
    return np.concatenate(halves, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode a set of scalar positions with sine/cosine features.

    Args:
        embed_dim: output dimension for each position (must be even).
        pos: array of positions to encode; it is flattened to size (M,).

    Returns:
        Array of shape (M, embed_dim): the sine features followed by the cosine features.

    Raises:
        ValueError: if `embed_dim` is odd.
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be even")

    # Frequencies 1 / 10000**(i / (D/2)) for i in [0, D/2).
    exponents = np.arange(embed_dim // 2, dtype=float) / (embed_dim / 2.0)
    frequencies = 1.0 / 10000**exponents  # (D/2,)

    angles = np.outer(pos.reshape(-1), frequencies)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
class ViTMAEEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings.

    The forward pass embeds the image into patch tokens, adds the fixed sin-cos position embeddings, drops a
    subset of the tokens (randomly, or according to a caller-supplied mask) and prepends the CLS token.
    """

    def __init__(self, config):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = ViTMAEPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches
        # fixed sin-cos embedding
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
        )
        self.config = config
        self.initialize_weights()

    def initialize_weights(self):
        """Fill the frozen position embeddings and initialize the patch projection and CLS token."""
        # initialize (and freeze) position embeddings by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(
            self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True
        )
        self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embeddings.projection.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)

    def random_masking(self, sequence, noise=None, random_seed=None, preset_mask=False):
        """
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
            random_seed (`int`, *optional*):
                Only used when `noise` is not given. NOTE: this reseeds the *global* torch RNG through
                `torch.manual_seed`, which also affects every later random op in the process.
            preset_mask (`bool`, *optional*, defaults to `False`):
                When truthy, the returned binary mask is inverted.

        Returns:
            `sequence_unmasked` of shape `(batch_size, len_keep, dim)`, `mask` of shape
            `(batch_size, sequence_length)` (0 is keep, 1 is remove) and `ids_restore` of shape
            `(batch_size, sequence_length)`.
        """
        batch_size, seq_length, dim = sequence.shape  # seq_length is #tokens
        len_keep = int(seq_length * (1 - self.config.mask_ratio))
        if noise is None:
            if random_seed is not None:
                torch.manual_seed(random_seed)
            noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        if preset_mask:
            mask = 1 - mask
        return sequence_unmasked, mask, ids_restore

    def salient_masking(self, sequence, noise=None, random_seed=None, preset_mask=False):
        """
        Mask the tokens selected by a caller-supplied binary mask instead of a random one.

        Args:
            sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
            noise, random_seed: unused; accepted so the signature matches `random_masking`.
            preset_mask (array-like of shape `(batch_size, sequence_length)`):
                0 marks a token to keep, any other value a token to remove. Every row must contain exactly
                `int(sequence_length * (1 - config.mask_ratio))` zeros so the kept tokens can be batched.

        Returns:
            `sequence_unmasked` of shape `(batch_size, len_keep, dim)` (kept tokens in their original order),
            `mask` (the preset mask as a tensor on the device of `sequence`) and `ids_restore` of shape
            `(batch_size, sequence_length)`, the permutation that undoes "kept tokens first, removed tokens last".

        Raises:
            ValueError: if `preset_mask` has the wrong shape or a row keeps the wrong number of tokens.
        """
        batch_size, seq_length, dim = sequence.shape  # seq_length is #tokens
        len_keep = int(seq_length * (1 - self.config.mask_ratio))

        mask = torch.as_tensor(preset_mask).to(sequence.device)
        if mask.shape != (batch_size, seq_length):
            raise ValueError(
                f"preset_mask must have shape ({batch_size}, {seq_length}), got {tuple(mask.shape)}."
            )
        keep = mask == 0
        if not bool((keep.sum(dim=1) == len_keep).all()):
            raise ValueError(
                f"Every row of preset_mask must mark exactly {len_keep} tokens to keep (value 0)."
            )

        # Kept tokens sort before removed ones; the token index breaks ties so the order inside each group
        # is the original token order and the sort result does not depend on sort stability.
        token_index = torch.arange(seq_length, device=sequence.device)
        sort_key = (~keep).to(torch.long) * seq_length + token_index
        ids_shuffle = torch.argsort(sort_key, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
        return sequence_unmasked, mask, ids_restore

    def forward(self, pixel_values, noise=None, random_seed=None, preset_mask=False):
        """
        Embed `pixel_values`, mask part of the patch tokens and prepend the CLS token.

        `preset_mask` left at `False` (or `None`) selects random masking; any other value is forwarded to
        `salient_masking` as the explicit mask.

        Returns:
            `embeddings` of shape `(batch_size, 1 + len_keep, hidden_size)`, `mask` of shape
            `(batch_size, num_patches)` and `ids_restore` of shape `(batch_size, num_patches)`.
        """
        embeddings = self.patch_embeddings(pixel_values)
        # add position embeddings w/o cls token
        embeddings = embeddings + self.position_embeddings[:, 1:, :]
        # masking: length -> length * (1 - config.mask_ratio)
        if preset_mask is None or preset_mask is False:
            embeddings, mask, ids_restore = self.random_masking(embeddings, noise, random_seed, preset_mask=False)
        else:
            embeddings, mask, ids_restore = self.salient_masking(
                embeddings, noise, random_seed, preset_mask=preset_mask
            )
        # prepend cls token (with its own position embedding)
        cls_token = self.cls_token + self.position_embeddings[:, :1, :]
        cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
        return embeddings, mask, ids_restore
class ViTMAEPatchEmbeddings(nn.Module):
    """
    Turn `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states`
    (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size = config.image_size
        patch_size = config.patch_size
        # Scalars describe a square image/patch; expand them to (height, width) pairs.
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)

        patches_along_height = image_size[0] // patch_size[0]
        patches_along_width = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = patches_along_width * patches_along_height
        # A strided convolution whose kernel equals the patch size embeds every patch independently.
        self.projection = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values):
        """Project the image to patch embeddings after checking its channel count and spatial size."""
        _, channels, height, width = pixel_values.shape
        if channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        expected_height, expected_width = self.image_size[0], self.image_size[1]
        if height != expected_height or width != expected_width:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({expected_height}*{expected_width})."
            )
        feature_map = self.projection(pixel_values)  # (batch, hidden, H', W')
        return feature_map.flatten(2).transpose(1, 2)  # (batch, H'*W', hidden)
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE
class ViTMAESelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (without the output projection)."""

    def __init__(self, config) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape `(batch, tokens, all_head_size)` into `(batch, heads, tokens, head_size)`."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        """
        Args:
            hidden_states: tensor of shape `(batch, tokens, hidden_size)`.
            head_mask: optional tensor multiplied into the attention probabilities (must broadcast against
                `(batch, heads, tokens, tokens)`).
            output_attentions: whether to also return the attention probabilities.

        Returns:
            `(context_layer,)` or `(context_layer, attention_probs)`, where `context_layer` has shape
            `(batch, tokens, all_head_size)` and `attention_probs` has shape `(batch, heads, tokens, tokens)`.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the heads back: (batch, heads, tokens, head_size) -> (batch, tokens, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return (context_layer, attention_probs) if output_attentions else (context_layer,)
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
    """
    Output projection of the attention block.

    The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to
    the layernorm applied before each block; `input_tensor` is therefore accepted but not used.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Apply the linear projection followed by dropout."""
        projected = self.dense(hidden_states)
        return self.dropout(projected)
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
class ViTMAEAttention(nn.Module):
    """Self-attention followed by its output projection."""

    def __init__(self, config) -> None:
        super().__init__()
        self.attention = ViTMAESelfAttention(config)
        self.output = ViTMAESelfOutput(config)
        self.pruned_heads = set()

    def forward(self, hidden_states, head_mask, output_attentions):
        """
        Returns a tuple whose first element is the projected attention output; any further elements produced by
        the self-attention module (the attention probabilities, when requested) are appended unchanged.
        """
        attention_result = self.attention(hidden_states, head_mask, output_attentions)
        projected = self.output(attention_result[0], hidden_states)
        return (projected, *attention_result[1:])
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
class ViTMAEIntermediate(nn.Module):
    """First half of the feed-forward block: expansion to `intermediate_size` plus activation."""

    def __init__(self, config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string names an entry of ACT2FN; anything else is used directly as the activation callable.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the linear expansion followed by the activation function."""
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)
# Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->ViTMAE
class ViTMAEOutput(nn.Module):
    """Second half of the feed-forward block: projection back to `hidden_size` plus the residual connection."""

    def __init__(self, config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, apply dropout and add the residual `input_tensor`."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE
class ViTMAELayer(nn.Module):
    """Pre-norm Transformer block; this corresponds to the Block class in the timm implementation."""

    def __init__(self, config) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViTMAEAttention(config)
        self.intermediate = ViTMAEIntermediate(config)
        self.output = ViTMAEOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        """
        Returns a tuple: the block output of shape `(batch, tokens, hidden_size)`, followed by whatever extra
        elements the attention module returned (the attention probabilities when `output_attentions` is set).
        """
        # Attention sub-block: layernorm is applied *before* self-attention.
        normed_input = self.layernorm_before(hidden_states)
        attention_result = self.attention(normed_input, head_mask, output_attentions=output_attentions)
        # First residual connection.
        attended = attention_result[0] + hidden_states

        # Feed-forward sub-block: layernorm again, then the MLP; `self.output` adds the second residual.
        mlp_hidden = self.intermediate(self.layernorm_after(attended))
        block_output = self.output(mlp_hidden, attended)

        return (block_output,) + attention_result[1:]
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
class ViTMAEEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` ViTMAELayer blocks."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Run every layer in sequence.

        Args:
            hidden_states: input tensor of shape `(batch, tokens, hidden_size)`.
            head_mask: optional per-layer head mask, indexed as `head_mask[i]` for layer `i`.
            output_attentions: collect the attention probabilities of every layer.
            output_hidden_states: collect the input of every layer plus the final output.
            return_dict: when `False`, entries that are `None` are dropped from the returned tuple.

        Returns:
            `(hidden_states, all_hidden_states, all_self_attentions)`; the last two are `None` unless requested.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            # The layer must run regardless of `output_hidden_states`; collecting states is only bookkeeping.
            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return (hidden_states, all_hidden_states, all_self_attentions)
class ViTMAEModel_custom(nn.Module):
    """ViT-MAE encoder: masked patch embeddings, Transformer encoder and a final layer norm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = ViTMAEEmbeddings(config)
        self.encoder = ViTMAEEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def get_head_mask(self, head_mask, num_hidden_layers):
        """
        Expand a head mask to 5 dimensions so it can be indexed per layer and broadcast against the attention
        probabilities of shape `(batch, heads, tokens, tokens)`.

        Args:
            head_mask: `None`, a tensor of shape `(num_heads,)` (same mask for every layer), a tensor of shape
                `(num_hidden_layers, num_heads)`, or an already 5-dimensional tensor.
            num_hidden_layers: number of encoder layers.

        Returns:
            A list of `None` (one per layer) when `head_mask` is `None`, otherwise a 5-dimensional tensor.

        Raises:
            ValueError: if `head_mask` cannot be brought to 5 dimensions.
        """
        if head_mask is None:
            return [None] * num_hidden_layers
        if head_mask.dim() == 1:
            head_mask = head_mask[None, None, :, None, None].expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask[:, None, :, None, None]
        if head_mask.dim() != 5:
            raise ValueError(f"head_mask must have 1, 2 or 5 dimensions, got {head_mask.dim()}")
        return head_mask

    def forward(
        self,
        pixel_values,
        noise=None,
        random_seed=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        preset_mask=False,
    ):
        """
        Args:
            pixel_values: image batch of shape `(batch, channels, height, width)`.
            noise, random_seed, preset_mask: forwarded to the embeddings module to control the masking.
            head_mask: optional mask over attention heads (1.0 keeps a head); see `get_head_mask`.
            output_attentions, output_hidden_states, return_dict: fall back to the values in `self.config` when
                `None`.

        Returns:
            With `return_dict` set: `(sequence_output, mask, ids_restore, all_hidden_states, all_self_attentions)`.
            Otherwise `(sequence_output, mask, ids_restore)` followed by whatever non-`None` extras the encoder
            returned.

        Raises:
            ValueError: if `pixel_values` is `None`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output, mask, ids_restore = self.embeddings(
            pixel_values, noise=noise, random_seed=random_seed, preset_mask=preset_mask
        )

        # Prepare head mask if needed: 1.0 in head_mask indicates we keep the head.
        # Input shape [num_heads] or [num_hidden_layers x num_heads] is expanded to
        # [num_hidden_layers x batch x num_heads x seq_length x seq_length] (broadcastable).
        if head_mask is not None:
            head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
            head_mask = head_mask.to(dtype=embedding_output.dtype, device=embedding_output.device)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.layernorm(encoder_outputs[0])
        if not return_dict:
            return (sequence_output, mask, ids_restore) + encoder_outputs[1:]
        return (sequence_output, mask, ids_restore, encoder_outputs[1], encoder_outputs[2])
class ViTMAEDecoder_custom(nn.Module):
    """
    MAE decoder: re-inserts mask tokens at the removed positions, runs a small Transformer and predicts the
    pixel values of every patch.
    """

    def __init__(self, config, num_patches):
        super().__init__()
        self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)  # encoder to decoder
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False
        )  # fixed sin-cos embedding
        # The decoder layers reuse ViTMAELayer with the decoder-specific sizes swapped into a config copy.
        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )
        self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
        self.decoder_pred = nn.Linear(
            config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
        )  # decoder hidden state to patch pixels
        self.gradient_checkpointing = False
        self.config = config
        self.initialize_weights(num_patches)

    def initialize_weights(self, num_patches):
        """Fill the frozen decoder position embeddings and initialize the mask token."""
        # initialize (and freeze) position embeddings by sin-cos embedding
        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)

    def forward(self, hidden_states, ids_restore, output_attentions=False, output_hidden_states=False, return_dict=True):
        """
        Args:
            hidden_states: encoder output of shape `(batch, 1 + len_keep, hidden_size)` (CLS token first).
            ids_restore: tensor of shape `(batch, num_patches)` that restores the original patch order.
            output_attentions: collect the attention probabilities of every decoder layer.
            output_hidden_states: collect the input of every decoder layer plus the final output.
            return_dict: when `False`, entries that are `None` are dropped from the returned tuple.

        Returns:
            `(logits, all_hidden_states, all_self_attentions)` where `logits` has shape
            `(batch, num_patches, patch_size**2 * num_channels)`.
        """
        # embed tokens
        x = self.decoder_embed(hidden_states)
        # append mask tokens to sequence (one per removed patch)
        num_masked = ids_restore.shape[1] + 1 - x.shape[1]
        mask_tokens = self.mask_token.repeat(x.shape[0], num_masked, 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # prepend cls token again
        # add pos embed
        hidden_states = x + self.decoder_pos_embed
        # apply Transformer layers (blocks)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for layer_module in self.decoder_layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # The layer must run regardless of `output_hidden_states`; collecting states is only bookkeeping.
            layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        hidden_states = self.decoder_norm(hidden_states)
        # predictor projection
        logits = self.decoder_pred(hidden_states)
        # remove cls token
        logits = logits[:, 1:, :]
        if not return_dict:
            return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
        return (logits, all_hidden_states, all_self_attentions)
class ViTMAEForPreTraining_salient(nn.Module):
    """ViT-MAE encoder plus decoder trained to reconstruct the pixels of the masked patches."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.vit = ViTMAEModel_custom(config)
        self.decoder = ViTMAEDecoder_custom(config, num_patches=self.vit.embeddings.num_patches)

    def get_input_embeddings(self):
        return self.vit.embeddings.patch_embeddings

    def patchify(self, pixel_values):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Patchified pixel values.

        Raises:
            ValueError: if the image is not square, its side is not divisible by the patch size, or the channel
                count differs from the configuration.
        """
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        # sanity checks
        if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
            raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
        if pixel_values.shape[1] != num_channels:
            raise ValueError(
                "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
            )
        # patchify
        batch_size = pixel_values.shape[0]
        num_patches_one_direction = pixel_values.shape[2] // patch_size
        patchified_pixel_values = pixel_values.reshape(
            batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
        )
        # (n, c, h, p, w, q) -> (n, h, w, p, q, c): patch grid first, then the pixels of a patch, channels last
        patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
        )
        return patchified_pixel_values

    def unpatchify(self, patchified_pixel_values):
        """
        Inverse of `patchify`.

        Args:
            patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
                Patchified pixel values.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
                Pixel values.

        Raises:
            ValueError: if the number of patches is not a perfect square.
        """
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
        # sanity check
        if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
            raise ValueError("Make sure that the number of patches can be squared")
        # unpatchify
        batch_size = patchified_pixel_values.shape[0]
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_patches_one_direction,
            num_patches_one_direction,
            patch_size,
            patch_size,
            num_channels,
        )
        patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
        pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_channels,
            num_patches_one_direction * patch_size,
            num_patches_one_direction * patch_size,
        )
        return pixel_values

    def forward_loss(self, pixel_values, pred, mask):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.
            pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
                Predicted pixel values.
            mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).

        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss, averaged over the masked patches only.
        """
        target = self.patchify(pixel_values)
        if self.config.norm_pix_loss:
            # Normalize every target patch to zero mean / unit variance.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, pixel_values, noise=None, random_seed=None,
                head_mask=None, output_attentions=False,
                output_hidden_states=False, return_dict=True, preset_mask=False):
        """
        Encode the visible patches, decode all patches and compute the reconstruction loss.

        Args:
            pixel_values: image batch of shape `(batch, channels, height, width)`.
            noise, random_seed, preset_mask: control the masking (see the encoder's embeddings module).
            head_mask, output_attentions, output_hidden_states: forwarded to the encoder.
            return_dict: accepted for interface compatibility; the result is always a plain tuple.

        Returns:
            `(loss, logits, mask, ids_restore, encoder_hidden_states, encoder_attentions)`; the last two are
            `None` unless requested. `logits` has shape `(batch, num_patches, patch_size**2 * num_channels)`.
        """
        # Always ask the encoder for its fixed five-element layout so the positions below are well defined.
        outputs = self.vit(
            pixel_values,
            noise=noise,
            random_seed=random_seed,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            preset_mask=preset_mask,
        )
        latent = outputs[0]
        mask = outputs[1]
        ids_restore = outputs[2]
        decoder_outputs = self.decoder(latent, ids_restore)
        logits = decoder_outputs[0]
        loss = self.forward_loss(pixel_values, logits, mask)
        # outputs[3] / outputs[4] are the encoder hidden states / attentions (outputs[1:3] are mask, ids_restore).
        return (loss, logits, mask, ids_restore, outputs[3], outputs[4])
class ViTMAEForPreTraining_classify(nn.Module):
    """ViT-MAE encoder with a linear classification head on top of the CLS token."""

    def __init__(self, config, num_classes=200):
        """
        Args:
            config: model configuration; `config.hidden_size` sets the input width of the classifier.
            num_classes (`int`, *optional*, defaults to 200): number of output classes.
        """
        super().__init__()
        self.config = config
        self.vit = ViTMAEModel_custom(config)
        # Sized from the config so the head matches the encoder for any hidden size.
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(self, pixel_values, noise=None, random_seed=None,
                head_mask=None, output_attentions=False,
                output_hidden_states=False, return_dict=True, preset_mask=False):
        """
        Encode `pixel_values` and classify the CLS token.

        All arguments are forwarded to the encoder; `return_dict` falls back to `config.use_return_dict` when
        `None`.

        Returns:
            Class logits of shape `(batch, num_classes)`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.vit(
            pixel_values,
            noise=noise,
            random_seed=random_seed,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            preset_mask=preset_mask,
        )
        cls_representation = outputs[0][:, 0]  # the CLS token is the first token of the sequence output
        return self.classifier(cls_representation)