# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import functional as F

from thirdparty.taming.transformer.mingpt import GPT


class GPTWrapper(nn.Module):
    """Wraps a minGPT autoregressive prior so that its inputs and outputs
    match the grid-shaped interface of the PixelCNN prior."""

    def __init__(
        self,
        device_id,
        shape,
        n_class,
        block_size,
        n_layer=12,
        n_head=8,
        n_embd=256,
        pkeep=1.0,
    ):
        super().__init__()
        self.pkeep = pkeep
        self.transformer = GPT(
            vocab_size=n_class,
            block_size=block_size,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
        )
        self.n_embd = n_embd
        self.device_id = device_id
        self.shape = shape

    def forward(self, x, condition=None, cache=None):
        # Flatten the (B, H, W) grid of code indices into a (B, H*W) sequence.
        z_indices = torch.flatten(x, start_dim=1, end_dim=-1)
        if condition is not None:
            c_indices = torch.flatten(condition, start_dim=1, end_dim=-1)

        if self.training and self.pkeep < 1.0:
            # During training, replace a fraction (1 - pkeep) of the input
            # tokens with random tokens as input corruption.
            mask = torch.bernoulli(
                self.pkeep * torch.ones(z_indices.shape, device=z_indices.device)
            )
            mask = mask.round().to(dtype=torch.int64)
            r_indices = torch.randint_like(
                z_indices, self.transformer.config.vocab_size
            )
            a_indices = mask * z_indices + (1 - mask) * r_indices
        else:
            a_indices = z_indices

        if condition is not None:
            # Prepend the conditioning tokens to the input sequence.
            cz_indices = torch.cat((c_indices, a_indices), dim=1)
            embeddings = None
        else:
            # Without conditioning tokens, prepend a single zero embedding so
            # the first code index is still predicted from something.
            cz_indices = a_indices
            embeddings = torch.zeros(a_indices.size(0), 1, self.n_embd)
            if self.device_id != 'cpu':
                embeddings = embeddings.cuda(self.device_id)

        # The target includes all sequence elements (no need to handle the
        # first one differently because we are conditioning).
        # Make the prediction.
        logits, _ = self.transformer(cz_indices[:, :-1], embeddings=embeddings)

        # Cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c).
        if condition is not None:
            logits = logits[:, c_indices.shape[1] - 1:]

        # Match the output shape of the PixelCNN prior (for compatibility only):
        # (B, H*W, n_class) -> (B, n_class, H, W).
        outshape = [logits.size(0)] + self.shape + [-1]
        out = torch.reshape(logits, outshape)
        out = out.permute(0, 3, 1, 2)

        loss = self.get_loss(out, x)
        return out, {}, loss

    def get_loss(self, logits, tgt):
        # (B, n_class, H, W) -> (B, H, W, n_class), then flatten logits and
        # targets so cross-entropy is computed over all grid positions.
        logits = logits.permute(0, 2, 3, 1)
        flatten_logits = torch.reshape(logits, [-1, logits.size(-1)])
        flatten_target = tgt.view(-1)
        return F.cross_entropy(flatten_logits, flatten_target)
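

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). It assumes the
    # thirdparty taming minGPT dependency is importable and uses small,
    # hypothetical hyperparameters chosen only for illustration: a CPU device,
    # an 8x8 grid of discrete codes, and a 512-token codebook. It exercises the
    # unconditional path, where block_size must cover the prepended zero
    # embedding plus the H*W - 1 shifted input tokens.
    batch, height, width, n_class = 2, 8, 8, 512
    model = GPTWrapper(
        device_id="cpu",
        shape=[height, width],       # spatial layout of the token grid
        n_class=n_class,             # codebook / vocabulary size
        block_size=height * width,   # 1 start embedding + (H*W - 1) tokens
        n_layer=2,
        n_head=2,
        n_embd=64,
    )
    # Random integer code indices standing in for a quantized latent grid.
    x = torch.randint(0, n_class, (batch, height, width))
    logits, _, loss = model(x)
    print(logits.shape)  # expected: (batch, n_class, height, width)
    print(loss.item())   # scalar cross-entropy over all grid positions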