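# Generative-Adversarial-Network / gan / losses.py
# Loss functions for GAN training: the standard (BCE-with-logits) and the
# least-squares discriminator/generator losses.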
import torch
from torch.nn.functional import binary_cross_entropy_with_logits as bce_loss
import torch.nn as nn


def discriminator_loss(logits_real, logits_fake):
    """
    Computes the discriminator loss.

    Uses the numerically stable torch.nn.functional.binary_cross_entropy_with_logits
    (imported in this module as ``bce_loss``). That function fuses the sigmoid
    with the binary cross-entropy, so no separate sigmoid is applied here.

    Real logits are scored against a target of 1 and fake logits against a
    target of 0. Each term is averaged over its own batch. The two terms are
    then averaged (not summed) together.

    Inputs:
    - logits_real: PyTorch Tensor of shape (N,) giving scores for the real data.
    - logits_fake: PyTorch Tensor of shape (M,) giving scores for the fake data.
      M may differ from N; each target tensor is built to match its own input.

    Returns:
    - loss: PyTorch Tensor containing (scalar) the loss for the discriminator.
    """
    # *_like copies shape, dtype, layout and device from its argument. The
    # targets therefore always sit on the same device as the logits. Each
    # target matches its own batch, so real and fake batches may differ in size.
    real_targets = torch.ones_like(logits_real)
    fake_targets = torch.zeros_like(logits_fake)

    loss_real = bce_loss(logits_real, real_targets, reduction='mean')
    loss_fake = bce_loss(logits_fake, fake_targets, reduction='mean')

    # Combine the two terms by averaging rather than summing.
    return (loss_real + loss_fake) / 2

def generator_loss(logits_fake):
    """
    Computes the generator loss.

    The fake logits are scored against an all-ones ("real") target. Scoring
    uses the numerically stable torch.nn.functional.binary_cross_entropy_with_logits
    (imported in this module as ``bce_loss``), which applies the sigmoid
    internally.

    Inputs:
    - logits_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.

    Returns:
    - loss: PyTorch Tensor containing the (scalar) loss for the generator.
    """
    # ones_like inherits shape, dtype, layout and device from the logits.
    real_label = torch.ones_like(logits_fake)
    return bce_loss(logits_fake, real_label, reduction='mean')


def ls_discriminator_loss(scores_real, scores_fake):
    """
    Compute the Least-Squares GAN loss for the discriminator.

        loss = mean(0.5 * (scores_real - 1)^2) + mean(0.5 * scores_fake^2)

    Inputs:
    - scores_real: PyTorch Tensor of shape (N,) giving scores for the real data.
    - scores_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.

    Outputs:
    - loss: A PyTorch Tensor containing the loss.
    """
    # Real scores are pulled toward 1. ones_like keeps the target on the same
    # device and dtype as the scores.
    real_residual = scores_real - torch.ones_like(scores_real)
    real_term = torch.mean((1/2) * (real_residual ** 2))

    # Fake scores are pulled toward 0.
    fake_term = torch.mean((1/2) * (scores_fake ** 2))

    return real_term + fake_term

def ls_generator_loss(scores_fake):
    """
    Computes the Least-Squares GAN loss for the generator.

        loss = mean(0.5 * (scores_fake - 1)^2)

    Inputs:
    - scores_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.

    Outputs:
    - loss: A PyTorch Tensor containing the loss.
    """
    # Distance of each fake score from the "real" target value 1. ones_like
    # matches the scores' dtype and device.
    residual = scores_fake - torch.ones_like(scores_fake)
    return torch.mean((1/2) * (residual ** 2))