# honeyplotnet/models/continuous/disc.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------


import torch
import torch.nn as nn
import torch.nn.functional as F

from .encoder import Encoder


def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Gate a loss weight on the training step.

    Args:
        weight: Weight to use once the warm-up period is over.
        global_step: Current training step.
        threshold: Step before which ``value`` is returned instead of
            ``weight`` (the comparison is strict: ``global_step < threshold``).
        value: Replacement returned while ``global_step`` is below
            ``threshold``.

    Returns:
        ``value`` if ``global_step < threshold``, otherwise ``weight``.
    """
    return value if global_step < threshold else weight


def hinge_d_loss(logits_real, logits_fake):
    """Hinge loss for the discriminator.

    Args:
        logits_real: Discriminator outputs for real samples.
        logits_fake: Discriminator outputs for generated samples.

    Returns:
        Scalar tensor: the average of ``mean(relu(1 - logits_real))`` and
        ``mean(relu(1 + logits_fake))``.
    """
    real_term = F.relu(1. - logits_real).mean()
    fake_term = F.relu(1. + logits_fake).mean()
    return 0.5 * (real_term + fake_term)


def vanilla_d_loss(logits_real, logits_fake):
    """Softplus-based (non-saturating) discriminator loss.

    Args:
        logits_real: Discriminator outputs for real samples.
        logits_fake: Discriminator outputs for generated samples.

    Returns:
        Scalar tensor: the average of ``mean(softplus(-logits_real))`` and
        ``mean(softplus(logits_fake))``.
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)

class Discriminator(nn.Module):
    """GAN discriminator built on ``Encoder`` that produces adversarial losses.

    ``forward`` returns the generator-side loss when ``optimizer_idx == 0``
    and the discriminator-side loss for any other ``optimizer_idx``. In both
    cases the adversarial term is multiplied by ``disc_factor``, which is
    replaced by 0 until ``global_step`` reaches ``disc_start``.
    """

    def __init__(self,
                 enc_conv, enc_tf, disc_loss='hinge', emb_dim1=8, use_pos_embs=True,
                 max_series_blocks=5, max_cont_blocks=32, norm_mode='minmax',
                 disc_factor=1.0, disc_weight=1.0, disc_conditional=False, disc_start=100,
                 device='cpu', fp16=False, debug=False):
        """Build the wrapped ``Encoder`` and select the discriminator loss.

        Args:
            enc_conv: Forwarded to ``Encoder`` as ``enc_conv``.
            enc_tf: Forwarded to ``Encoder`` as ``enc_tf``.
            disc_loss: ``"hinge"`` or ``"vanilla"``; selects ``hinge_d_loss``
                or ``vanilla_d_loss``.
            emb_dim1: Forwarded to ``Encoder`` as ``out_dim``.
            use_pos_embs: Forwarded to ``Encoder``.
            max_series_blocks: Forwarded to ``Encoder``.
            max_cont_blocks: Forwarded to ``Encoder``.
            norm_mode: Stored on the instance and forwarded to ``Encoder``.
            disc_factor: Multiplier for the adversarial loss term.
            disc_weight: Stored as ``discriminator_weight``; not read by any
                method of this class.
            disc_conditional: Forwarded to ``Encoder`` as
                ``chart_type_conditional`` and stored on the instance.
            disc_start: Training step before which the adversarial term is
                weighted by 0 instead of ``disc_factor``.
            device: Stored on the instance and forwarded to ``Encoder``.
            fp16: Stored as ``use_fp16``; not read by any method of this class.
            debug: Stored on the instance and forwarded to ``Encoder``.

        Raises:
            ValueError: If ``disc_loss`` is not ``"hinge"`` or ``"vanilla"``.
        """
        super().__init__()

        # Validate explicitly instead of with ``assert``: asserts are removed
        # under ``python -O`` and previously masked this ValueError.
        if disc_loss not in ("hinge", "vanilla"):
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")

        self.debug = debug
        self.device = device
        self.use_fp16 = fp16
        self.dtype_float = torch.float32
        self.dtype_int = torch.int32

        self.norm_mode = norm_mode

        self.discriminator = Encoder(
            enc_conv=enc_conv,
            enc_tf=enc_tf,
            out_dim=emb_dim1,
            max_series_blocks=max_series_blocks,
            max_cont_blocks=max_cont_blocks,
            use_pos_embs=use_pos_embs,
            norm_mode=norm_mode,
            chart_type_conditional=disc_conditional,
            device=device,
            debug=debug,
        )

        self.discriminator_iter_start = disc_start

        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        else:
            self.disc_loss = vanilla_d_loss

        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def forward(self, loss_dict, inputs, reconstructions, optimizer_idx, global_step):
        """Compute the generator or discriminator adversarial loss.

        Args:
            loss_dict: Accepted for interface compatibility; not used here.
            inputs: Real batch; only used on the discriminator update. Passed
                through ``preprocess_for_disc``, so it must be a dict.
            reconstructions: Generated batch. Fed to the discriminator as-is on
                the generator update, and detached via ``preprocess_for_disc``
                on the discriminator update.
            optimizer_idx: ``0`` selects the generator update; any other value
                selects the discriminator update.
            global_step: Current training step, used to gate ``disc_factor``.

        Returns:
            A pair ``(losses, logs)``. ``losses`` is ``{'gen': loss}`` for the
            generator update or ``{'disc': loss}`` for the discriminator
            update; ``logs`` is always an empty dict.
        """
        # Adversarial weight is 0 until ``discriminator_iter_start`` is reached.
        disc_factor = adopt_weight(
            self.disc_factor, global_step, threshold=self.discriminator_iter_start)

        if optimizer_idx == 0:
            # Generator update: reconstructions are NOT detached, so gradients
            # can flow back through them.
            # The encoder returns (logits, auxiliary loss); the auxiliary loss
            # is added unweighted. NOTE(review): exact meaning of this
            # "na" loss is defined in Encoder - confirm there.
            logits_fake, na_loss = self.discriminator(reconstructions)
            g_loss = -torch.mean(logits_fake)
            return {'gen': disc_factor * g_loss + na_loss}, {}

        # Discriminator update: both batches are detached from the autograd
        # graph before being scored.
        logits_real, na1_loss = self.discriminator(self.preprocess_for_disc(inputs))
        logits_fake, na2_loss = self.discriminator(self.preprocess_for_disc(reconstructions))

        d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + na1_loss + na2_loss
        return {'disc': d_loss}, {}

    def preprocess_for_disc(self, x):
        """Return a copy of ``x`` whose 'scale'/'continuous' tensors are detached.

        For the ``'scale'`` and ``'continuous'`` entries (each expected to be a
        dict), tensor values and lists of tensors are made contiguous and
        detached; values of any other type under those two keys are dropped.
        All other top-level entries are passed through unchanged.

        Args:
            x: Dict of model inputs or reconstructions.

        Returns:
            A new dict that always contains ``'scale'`` and ``'continuous'``
            (possibly empty) plus every other top-level key of ``x``.
        """
        output = {'scale': {}, 'continuous': {}}

        for key, value in x.items():
            if key not in ('scale', 'continuous'):
                output[key] = value
                continue

            for sub_key, sub_value in value.items():
                if isinstance(sub_value, list):
                    output[key][sub_key] = [t.contiguous().detach() for t in sub_value]
                elif isinstance(sub_value, torch.Tensor):
                    output[key][sub_key] = sub_value.contiguous().detach()
                # Anything else is intentionally left out, as in the original.

        return output