# The following code has been modified based on the source provided in the referenced link:
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/vit_mae/modeling_vit_mae.py
# import logging
import torch
from torch import nn
import numpy as np
import collections.abc
import math
from copy import deepcopy

# logger = logging.get_logger(__name__)

ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu}


def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):  # embed_dim=768, grid_size=14
    """
    Create 2D sin/cos positional embeddings.

    Args:
        embed_dim (`int`):
            Embedding dimension.
        grid_size (`int`):
            The grid height and width.
        add_cls_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add a classification (CLS) token.

    Returns:
        (`torch.FloatTensor` of shape `(grid_size*grid_size, embed_dim)` or `(1+grid_size*grid_size, embed_dim)`):
            the position embeddings (with or without classification token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)  # [2, grid_size, grid_size]

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)  # [num_tokens, hidden_size]
    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)  # (192,)

    pos = pos.reshape(-1)  # (M,)  196
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class ViTMAEEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings.
    """

    def __init__(self, config):
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = ViTMAEPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches
        # fixed sin-cos embedding
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
        )
        self.config = config
        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(
            self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True
        )
        self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embeddings.projection.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)

    def random_masking(self, sequence, noise=None, random_seed=None, preset_mask=False):
        """
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
        """
        batch_size, seq_length, dim = sequence.shape  # seq_length is #tokens
        len_keep = int(seq_length * (1 - self.config.mask_ratio))

        if noise is None:
            if random_seed is not None:
                torch.manual_seed(random_seed)
            noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]  # [N, #tokens*ratio]
        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))  # index: [7, 49, 768]

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        if preset_mask:
            mask = 1 - mask

        return sequence_unmasked, mask, ids_restore
        # sequence_unmasked: [N, #tokens*0.75, hidden_size]
        # mask: [N, #tokens]
        # ids_restore: [N, #tokens]

    def forward(self, pixel_values, noise=None, random_seed=None, preset_mask=False):
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values)

        # add position embeddings w/o cls token
        embeddings = embeddings + self.position_embeddings[:, 1:, :]  # [N, #tokens, hidden_size]

        # masking: length -> length * config.mask_ratio
        embeddings, mask, ids_restore = self.random_masking(embeddings, noise, random_seed, preset_mask=preset_mask)
        # embeddings.shape: [N, #tokens*0.75, hidden_size]

        # append cls token
        cls_token = self.cls_token + self.position_embeddings[:, :1, :]  # [1, 1, hidden_size]
        cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)  # [N, #tokens*0.75+1, hidden_size]

        return embeddings, mask, ids_restore
        # mask: [N, #tokens]
        # ids_restore: [N, #tokens]


class ViTMAEPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
""" def __init__(self, config): super().__init__() image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values): batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." ) if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." ) x = self.projection(pixel_values).flatten(2).transpose(1, 2) return x # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE class ViTMAESelfAttention(nn.Module): def __init__(self, config) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " f"heads {config.num_attention_heads}." ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward(self, hidden_states, head_mask = None, output_attentions = False): mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) # Normalize the attention scores to probabilities. attention_probs = nn.functional.softmax(attention_scores, dim=-1) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
        attention_probs = self.dropout(attention_probs)  # [N, #heads, #tokens, #tokens]

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
    """
    The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
class ViTMAEAttention(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.attention = ViTMAESelfAttention(config)
        self.output = ViTMAESelfOutput(config)
        self.pruned_heads = set()

    def forward(self, hidden_states, head_mask, output_attentions):
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
        # outputs[0]: [N, #tokens, hidden_size]
        # outputs[1]: [N, #heads, #tokens, #tokens]


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMAE
class ViTMAEIntermediate(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->ViTMAE
class ViTMAEOutput(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE
class ViTMAELayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViTMAEAttention(config)
        self.intermediate = ViTMAEIntermediate(config)
        self.output = ViTMAEOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
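    # Pre-LayerNorm block, as in the timm/MAE reference implementation:
    #   hidden_states = hidden_states + Attention(LayerNorm(hidden_states))
    #   hidden_states = hidden_states + MLP(LayerNorm(hidden_states))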
    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViTMAE, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViTMAE, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs
        # outputs[0]: [N, #tokens, hidden_size]
        # outputs[1]: [N, #heads, #tokens, #tokens]


# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
class ViTMAEEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(self, hidden_states, head_mask=None, output_attentions=False, output_hidden_states=False,
                return_dict=True):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return (hidden_states, all_hidden_states, all_self_attentions)
        # hidden_states: [N, #tokens, hidden_size]
        # all_self_attentions: [#layers, N, #heads, #tokens, #tokens]


class ViTMAEModel_custom(nn.Module):
    def __init__(self, config):
        # super().__init__(config)
        super().__init__()
        self.config = config

        self.embeddings = ViTMAEEmbeddings(config)
        self.encoder = ViTMAEEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # # Initialize weights and apply final processing
        # self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def forward(self, pixel_values, noise=None, random_seed=None, head_mask=None, output_attentions=None,
                output_hidden_states=None, return_dict=None, preset_mask=False):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x 1 x num_heads x 1 x 1]
        if head_mask is not None:
            # This class is a plain nn.Module, so transformers' get_head_mask is not available here;
            # expand the mask so head_mask[i] broadcasts against [N, #heads, #tokens, #tokens] attention probs.
            if head_mask.dim() == 1:
                head_mask = head_mask[None, None, :, None, None].expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask[:, None, :, None, None]
            head_mask = head_mask.to(dtype=pixel_values.dtype)
        embedding_output, mask, ids_restore = self.embeddings(
            pixel_values, noise=noise, random_seed=random_seed, preset_mask=preset_mask
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            return (sequence_output, mask, ids_restore) + encoder_outputs[1:]

        return (sequence_output, mask, ids_restore, encoder_outputs[1], encoder_outputs[2])
        # sequence_output: [N, #tokens*0.75, hidden_size]
        # mask: [N, #tokens] 0/1
        # ids_restore: [N, #tokens]
        # encoder_outputs[1]: all_hidden_states
        # encoder_outputs[2]: all_self_attentions [#layers, N, #heads, #tokens*0.75, #tokens*0.75]


class ViTMAEDecoder_custom(nn.Module):
    def __init__(self, config, num_patches):
        super().__init__()
        self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False
        )  # fixed sin-cos embedding

        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )

        self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
        self.decoder_pred = nn.Linear(
            config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
        )  # encoder to decoder
        self.gradient_checkpointing = False
        self.config = config
        self.initialize_weights(num_patches)

    def initialize_weights(self, num_patches):
        # initialize (and freeze) position embeddings by sin-cos embedding
        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)

    def forward(self, hidden_states, ids_restore, output_attentions=False, output_hidden_states=False,
                return_dict=True):
        # embed tokens
        x = self.decoder_embed(hidden_states)  # [N, #tokens*0.75, decoder_hidden_size]

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)  # [N, #masked_tokens, decoder_hidden_size]
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token  # [N, #tokens, decoder_hidden_size]
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        hidden_states = x + self.decoder_pos_embed

        # apply Transformer layers (blocks)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.decoder_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.decoder_norm(hidden_states)

        # predictor projection
        logits = self.decoder_pred(hidden_states)

        # remove cls token
        logits = logits[:, 1:, :]

        if not return_dict:
            return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
        return (logits, all_hidden_states, all_self_attentions)


class ViTMAEForPreTraining_custom(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.vit = ViTMAEModel_custom(config)
        self.decoder = ViTMAEDecoder_custom(config, num_patches=self.vit.embeddings.num_patches)

        # # Initialize weights and apply final processing
        # self.post_init()

    def get_input_embeddings(self):
        return self.vit.embeddings.patch_embeddings

    def patchify(self, pixel_values):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Patchified pixel values.
        """
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        # sanity checks
        if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
            raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
        if pixel_values.shape[1] != num_channels:
            raise ValueError(
                "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
            )

        # patchify
        batch_size = pixel_values.shape[0]
        num_patches_one_direction = pixel_values.shape[2] // patch_size
        patchified_pixel_values = pixel_values.reshape(
            batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
        )
        patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
        )
        return patchified_pixel_values

    def unpatchify(self, patchified_pixel_values):
        """
        Args:
            patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
                Patchified pixel values.
        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
                Pixel values.
        """
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
        # sanity check
        if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
            raise ValueError("Make sure that the number of patches can be squared")

        # unpatchify
        batch_size = patchified_pixel_values.shape[0]
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_patches_one_direction,
            num_patches_one_direction,
            patch_size,
            patch_size,
            num_channels,
        )
        patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
        pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_channels,
            num_patches_one_direction * patch_size,
            num_patches_one_direction * patch_size,
        )
        return pixel_values

    def forward_loss(self, pixel_values, pred, mask):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.
            pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
                Predicted pixel values.
            mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).

        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss.
        """
        target = self.patchify(pixel_values)
        if self.config.norm_pix_loss:  # False
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, pixel_values, noise=None, random_seed=None, head_mask=None, output_attentions=False,
                output_hidden_states=False, return_dict=True, preset_mask=False):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            noise=noise,
            random_seed=random_seed,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            preset_mask=preset_mask,
        )

        latent = outputs[0]
        mask = outputs[1]
        ids_restore = outputs[2]

        decoder_outputs = self.decoder(latent, ids_restore)
        # logits = decoder_outputs.logits  # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
        logits = decoder_outputs[0]  # [N, #tokens, patch_size**2 * num_channels]

        loss = self.forward_loss(pixel_values, logits, mask)

        return (loss, logits, mask, ids_restore) + outputs[3:]
        # outputs[3]: all_hidden_states (None unless output_hidden_states=True)
        # outputs[4]: all_self_attentions (None unless output_attentions=True)


class ViTMAEForPreTraining_classify(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.vit = ViTMAEModel_custom(config)
        self.classifier = nn.Linear(768, 200)  # 768 = hidden_size of ViT-Base; 200 output classes are task-specific

    def forward(self, pixel_values, noise=None, random_seed=None, head_mask=None, output_attentions=False,
                output_hidden_states=False, return_dict=True, preset_mask=False):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            noise=noise,
            random_seed=random_seed,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            preset_mask=preset_mask,
        )

        latent = outputs[0][:, 0]  # CLS token representation
        logits = self.classifier(latent)
        return logits
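# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the referenced HuggingFace source). It
# assumes `transformers.ViTMAEConfig` is available; any config object exposing
# the attributes used above (hidden_size, num_attention_heads, mask_ratio,
# decoder_* sizes, norm_pix_loss, use_return_dict, ...) would work the same
# way. Batch size, image size, and the random seed below are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import ViTMAEConfig

    config = ViTMAEConfig(image_size=224, patch_size=16, mask_ratio=0.75)
    model = ViTMAEForPreTraining_custom(config)
    model.eval()

    pixel_values = torch.randn(2, config.num_channels, config.image_size, config.image_size)

    # patchify/unpatchify are exact inverses (pure reshape/permute)
    patches = model.patchify(pixel_values)
    assert torch.allclose(model.unpatchify(patches), pixel_values)

    with torch.no_grad():
        loss, logits, mask, ids_restore, hidden_states, attentions = model(pixel_values, random_seed=0)

    print(loss.item())                     # scalar reconstruction loss on the masked patches
    print(logits.shape)                    # [2, 196, patch_size**2 * num_channels] = [2, 196, 768]
    print(mask.shape, ids_restore.shape)   # [2, 196] and [2, 196]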