# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
import torch.nn as nn
from torch.nn import functional as F

from thirdparty.taming.transformer.mingpt import GPT


class GPTWrapper(nn.Module):
    """Autoregressive prior over flattened discrete code indices, built on
    minGPT. Mirrors the PixelCNN interface so the two priors are
    interchangeable (see the output reshaping in ``forward``)."""

    def __init__(
        self,
        device_id,
        shape,
        n_class,
        block_size,
        n_layer=12,
        n_head=8,
        n_embd=256,
        pkeep=1.0,
    ):
        super().__init__()
        self.pkeep = pkeep
        self.transformer = GPT(
            vocab_size=n_class,
            block_size=block_size,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
        )
        self.n_embd = n_embd
        self.device_id = device_id
        self.shape = shape
    
    def forward(self, x, condition=None, cache=None):
        # `cache` is accepted for interface compatibility but unused here.
        # Flatten the spatial grid of code indices into a 1D sequence.
        z_indices = torch.flatten(x, start_dim=1, end_dim=-1)
        if condition is not None:
            c_indices = torch.flatten(condition, start_dim=1, end_dim=-1)

        if self.training and self.pkeep < 1.0:
            # During training, keep each token with probability pkeep and
            # replace the rest with random indices (input corruption).
            mask = torch.bernoulli(self.pkeep * torch.ones(z_indices.shape,
                                                           device=z_indices.device))
            mask = mask.round().to(dtype=torch.int64)
            r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
            a_indices = mask * z_indices + (1 - mask) * r_indices
        else:
            a_indices = z_indices
        
        if condition is not None:
            # Prepend the conditioning indices; the transformer embeds them itself.
            cz_indices = torch.cat((c_indices, a_indices), dim=1)
            embeddings = None
        else:
            # Unconditional case: prepend a single zero embedding as the start token.
            cz_indices = a_indices
            embeddings = torch.zeros(a_indices.size(0), 1, self.n_embd)
            if self.device_id != 'cpu':
                embeddings = embeddings.cuda(self.device_id)

        # target includes all sequence elements (no need to handle first one
        # differently because we are conditioning)
        # make the prediction
        logits, _ = self.transformer(cz_indices[:, :-1], embeddings=embeddings)

        # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
        if condition is not None:
            logits = logits[:, c_indices.shape[1]-1:]

        # Match the output layout of the PixelCNN prior, (B, vocab, H, W),
        # kept only for compatibility.
        outshape = [logits.size(0)] + self.shape + [-1]
        out = torch.reshape(logits, outshape)
        out = out.permute(0, 3, 1, 2)

        loss = self.get_loss(out, x)

        return out, {}, loss

    def get_loss(self, logits, tgt):
        # Flatten (B, vocab, H, W) logits to (B*H*W, vocab) and targets to
        # (B*H*W,) for token-level cross-entropy.
        logits = logits.permute(0, 2, 3, 1)
        flatten_logits = torch.reshape(logits, [-1, logits.size(-1)])
        flatten_target = tgt.view(-1)
        return F.cross_entropy(flatten_logits, flatten_target)
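

# ---------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module):
# a minimal CPU smoke test of the unconditional path. The 8x8 grid
# and 512-entry codebook below are hypothetical values chosen only
# for the example.
# ---------------------------------------------------------------
if __name__ == "__main__":
    shape = [8, 8]        # assumed latent grid size
    n_class = 512         # assumed codebook size
    model = GPTWrapper(
        device_id='cpu',
        shape=shape,
        n_class=n_class,
        block_size=shape[0] * shape[1],
        n_layer=2,        # small model so the smoke test runs quickly
        pkeep=0.9,
    )
    x = torch.randint(n_class, (2, *shape))  # dummy code indices, batch of 2
    out, _, loss = model(x)
    print(out.shape, loss.item())  # expected: torch.Size([2, 512, 8, 8])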