# ml4co-competition / submissions / dual / training / algorithms.py
from torch_geometric.loader import DataLoader
import torch
import torch.nn.functional as F
from torchrl.modules import MaskedCategorical
import torch.distributions as D
import numpy as np
# NOTE: BipartiteDatasetRL is assumed to live alongside Transition and GraphDataset
# in the training utils module; it is used by GAE() and REINFORCE() below.
from submissions.dual.training.utils import Transition, GraphDataset, BipartiteDatasetRL

def compute_gae(rewards, values, dones, gamma=1.0, lam=0.95):
    # Pad the value estimates with a zero bootstrap value for the state that
    # follows the final transition in the buffer.
    values = torch.cat([values, torch.zeros(1, device=values.device)], dim=0)

    advantages = torch.zeros_like(rewards)
    returns_to_go = torch.zeros_like(rewards)

    last_gae = 0     # running GAE term for the step that follows t
    last_return = 0  # running discounted return

    for t in reversed(range(len(rewards))):
        # td-error: r_t + gamma * V(s_{t+1}) - V(s_t), with the next value masked out at episode ends
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
        # gae advantage: A_t = delta_t + gamma * lam * A_{t+1}
        advantages[t] = last_gae = delta + gamma * lam * (1 - dones[t]) * last_gae
        # discounted return-to-go, masked at episode ends so returns do not bleed
        # across episodes stored back-to-back in the buffer
        returns_to_go[t] = last_return = rewards[t] + gamma * (1 - dones[t]) * last_return

    # returns_to_go = advantages + values[:-1]
    return advantages, returns_to_go
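
# Hand-checked sanity example for compute_gae (not called by the training code;
# the two-step episode below is hypothetical):
#
#   advantages, returns = compute_gae(
#       rewards=torch.tensor([1.0, 1.0]),
#       values=torch.tensor([0.5, 0.5]),
#       dones=torch.tensor([0.0, 1.0]),
#       gamma=1.0, lam=0.95,
#   )
#   # advantages ~= [1.4750, 0.5000], returns ~= [2.0, 1.0]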

def compute_discounted_returns(rewards, gamma=0.99):
    rewards = rewards.float()
    # Scale each reward by gamma^t, take a reversed cumulative sum to obtain
    # sum_{k>=t} gamma^k * r_k, then divide by gamma^t to recover the
    # discounted return-to-go G_t = sum_{k>=t} gamma^(k-t) * r_k.
    discounts = gamma ** torch.arange(len(rewards)).float()
    discounted_rewards = rewards * discounts
    discounted_returns = torch.flip(torch.cumsum(torch.flip(discounted_rewards, [0]), dim=0), [0])
    return discounted_returns / discounts
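
# Hand-checked example (not part of the pipeline): for three unit rewards and
# the default gamma of 0.99,
#   compute_discounted_returns(torch.tensor([1.0, 1.0, 1.0]))
# returns approximately [2.9701, 1.9900, 1.0000].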

def decayed_learning_rate(step):
    # Exponentially decayed coefficient; despite the name, it is used below as
    # the entropy-bonus weight in the GAE actor loss rather than as an optimizer
    # learning rate.
    initial_learning_rate = 0.05
    decay_rate = 0.1
    decay_steps = 100
    return initial_learning_rate * (decay_rate ** (step / decay_steps))
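
# Example values: decayed_learning_rate(0) == 0.05,
# decayed_learning_rate(100) ~= 0.005, decayed_learning_rate(200) ~= 0.0005;
# the coefficient shrinks by a factor of 10 every 100 steps.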


def GAE(actor: torch.nn.Module, critic_main: torch.nn.Module, critic_target: torch.nn.Module, actor_opt: torch.optim.Optimizer,  critic_opt: torch.optim.Optimizer, replay_buffer, step):
    actor.train()
    critic_main.train()

    
    transitions = replay_buffer.memory
    batch_size1 = len(transitions)
    batch_size2 = 16
    batch = Transition(*zip(*transitions))
    dataset = BipartiteDatasetRL(*batch)
    dataloader1 = DataLoader(dataset, batch_size=batch_size1, shuffle=False, follow_batch=['x_c', 'x_v', 'candidates'])
    for SASR in dataloader1:
        graph, action, reward, done, _ = SASR
        # reward = (reward - reward.mean()) / (1e-8 + reward.std())
        with torch.no_grad():
            # get state values of batch
            values = critic_main(
                graph.x_c,
                graph.edge_index,
                graph.edge_features,
                graph.x_v,
                graph.x_c_batch,
                graph.x_v_batch,
                graph.candidates
            )

        # compute_gae pads the value estimates with a terminal zero internally,
        # so the raw critic outputs can be passed in directly
        advantages, returns = compute_gae(reward, values, done)
        
        advantages = (advantages - advantages.mean()) / (1e-8 + advantages.std())

    indices = np.random.permutation(len(dataset))
    shuffled_dataset = torch.utils.data.Subset(dataset, indices)

    dataloader2 = DataLoader(shuffled_dataset, batch_size=batch_size2, shuffle=False, follow_batch=['x_c', 'x_v', 'candidates'])
    batch_idx = 0
    # tau = 0.005
    total_actor_loss, total_critic_loss = 0, 0
    mean_entropy = []
    mean_kl_div = []

    for batch_i, SASR in enumerate(dataloader2):
        graph, action, reward, done, batch_old_log_probs = SASR
        
        batch_size_i = len(action)
        batch_start = batch_idx
        batch_end = batch_start + batch_size_i

        actor_opt.zero_grad()
        critic_opt.zero_grad()
    
        # formatting for indexing in batched data; this assumes every graph in
        # the batch has the same number of variable nodes
        num_nodes_per_graph = graph.num_graphs*[int(graph.x_v.shape[0]/graph.num_graphs)]
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()
        
        # Split candidates per graph and adjust indices
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)

        # Estimate policy gradient

        logits = actor(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v
        )

        logits_per_graph = torch.split(logits, num_nodes_per_graph)
            
        log_probs = []
        entropy_arr = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]
            
            # Create mask for valid actions
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True
            
            # Create action distribution
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)
            entropy_arr.append(action_dist.entropy())

        log_probs = torch.stack(log_probs)

        # Find entropy of policy and store
        ent = torch.stack(entropy_arr).mean()
        mean_entropy.append(ent.detach().item() * len(entropy_arr))

        # Approximate KL divergence between the old policy and the updated one
        approx_kl = (torch.exp(batch_old_log_probs)*(batch_old_log_probs - log_probs.detach())).sum().item()
        mean_kl_div.append(approx_kl)

        batch_advantages = advantages[shuffled_dataset.indices[batch_start:batch_end]]

        eps = 0.3
        # ratios = torch.exp(log_probs - batch_old_log_probs)
        # surr1 = ratios * batch_advantages
        # surr2 = torch.clamp(ratios, 1 - eps, 1 + eps)*batch_advantages
        # actor_loss = -torch.min(surr1, surr2).mean() - decayed_learning_rate(step)*ent
        actor_loss = -torch.mean(log_probs*batch_advantages) - decayed_learning_rate(step)*ent

        # actor_loss = -torch.mean(log_probs * advantages[batch_start:batch_end]) 
        # actor_loss = -torch.sum(log_probs * advantages[shuffled_dataset.indices[batch_start:batch_end]])
        
        # get state values of batch
        values = critic_main(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v,
            graph.x_c_batch,
            graph.x_v_batch,
            graph.candidates        
        )
            
        # critic_loss = F.mse_loss(values, returns[batch_start:batch_end])
        critic_loss = F.mse_loss(values, returns[shuffled_dataset.indices[batch_start:batch_end]])

        actor_loss.backward()
        critic_loss.backward()

        actor_opt.step()
        critic_opt.step()

        total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()

        # for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
        #     target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)

        batch_idx += batch_size_i

    return total_actor_loss/len(dataloader2), total_critic_loss/len(dataloader2), advantages.mean().item(), sum(mean_entropy)/batch_size1, sum(mean_kl_div)/batch_size1
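
# Hedged usage sketch for GAE (the outer loop, `env`, `collect_episode`,
# `num_updates` and `replay_buffer.clear()` are placeholders for whatever the
# surrounding training script actually provides):
#
#   for step in range(num_updates):
#       collect_episode(env, actor, replay_buffer)   # fill the buffer with Transitions
#       a_loss, c_loss, adv, ent, kl = GAE(
#           actor, critic_main, critic_target,
#           actor_opt, critic_opt, replay_buffer, step,
#       )
#       replay_buffer.clear()                        # on-policy: discard stale data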

def REINFORCE(actor: torch.nn.Module, critic_main: torch.nn.Module, actor_opt: torch.optim.Optimizer, critic_opt: torch.optim.Optimizer, replay_buffer):
    # REINFORCE with a learned value baseline: critic_main estimates state values
    # that are subtracted from the discounted returns, and both the actor and the
    # critic are updated from the same buffer of transitions.
    actor.train()
    critic_main.train()

    transitions = replay_buffer.memory
    batch_size = 64
    
    batch = Transition(*zip(*transitions))
    dataset = BipartiteDatasetRL(*batch)

    rewards = torch.tensor(batch.reward)
    returns = compute_discounted_returns(rewards)
    
    indices = np.random.permutation(len(dataset))
    shuffled_dataset = torch.utils.data.Subset(dataset, indices)

    dataloader = DataLoader(shuffled_dataset, batch_size=batch_size, shuffle=False, follow_batch=['x_c', 'x_v', 'candidates'])
    batch_idx = 0
    # tau = 0.005
    total_actor_loss, total_critic_loss = 0, 0
    mean_entropy = []
    mean_kl_div = []
    mean_advantages = []

    for batch_i, SASR in enumerate(dataloader):
        graph, action, reward, done, batch_old_log_probs = SASR
        if len(action) <= 1:
            continue
        
        batch_size_i = len(action)
        batch_start = batch_idx
        batch_end = batch_start + batch_size_i

        actor_opt.zero_grad()
        critic_opt.zero_grad()
    
        # formatting for indexing in batched data
        num_nodes_per_graph = graph.num_graphs*[int(graph.x_v.shape[0]/graph.num_graphs)]
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()
        
        # Split candidates per graph and adjust indices
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)

        # Estimate policy gradient

        logits = actor(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v
        )

        logits_per_graph = torch.split(logits, num_nodes_per_graph)
            
        log_probs = []
        entropy_arr = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]
            
            # Create mask for valid actions
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True
            
            # Create action distribution
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)
            entropy_arr.append(action_dist.entropy())

        log_probs = torch.stack(log_probs)

        # Find entropy of policy and store
        ent = torch.stack(entropy_arr).mean()
        mean_entropy.append(ent.detach().item() * len(entropy_arr))

        # Approximate KL divergence between the old policy and the updated one
        approx_kl = (torch.exp(batch_old_log_probs)*(batch_old_log_probs - log_probs.detach())).sum().item()
        mean_kl_div.append(approx_kl)

        batch_returns = returns[shuffled_dataset.indices[batch_start:batch_end]]

        eps = 0.3
        # ratios = torch.exp(log_probs - batch_old_log_probs)
        # surr1 = ratios * batch_advantages
        # surr2 = torch.clamp(ratios, 1 - eps, 1 + eps)*batch_advantages
        # actor_loss = -torch.min(surr1, surr2).mean() - decayed_learning_rate(step)*ent

        # actor_loss = -torch.mean(log_probs * advantages[batch_start:batch_end]) 
        # actor_loss = -torch.sum(log_probs * advantages[shuffled_dataset.indices[batch_start:batch_end]])
        
        # get state values of batch
        values = critic_main(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v,
            graph.x_c_batch,
            graph.x_v_batch,
            graph.candidates        
        )

        batch_advantages = (batch_returns - values.detach())
        batch_advantages = (batch_advantages - batch_advantages.mean()) / (1e-8 + batch_advantages.std())
        mean_advantages.append(batch_advantages.sum().item())

        actor_loss = -torch.mean(log_probs*batch_advantages)

            
        # critic_loss = F.mse_loss(values, returns[batch_start:batch_end])
        critic_loss = F.mse_loss(values, batch_returns)

        actor_loss.backward()
        critic_loss.backward()

        actor_opt.step()
        critic_opt.step()

        total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()

        # for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
        #     target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)

        batch_idx += batch_size_i

    # normalize the accumulated entropy / KL / advantage sums by the total number of transitions
    return total_actor_loss/len(dataloader), total_critic_loss/len(dataloader), np.sum(mean_advantages)/len(transitions), sum(mean_entropy)/len(transitions), sum(mean_kl_div)/len(transitions)
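
# Hedged usage sketch: like GAE above, REINFORCE is expected to be driven by the
# outer training script, e.g.
#   a_loss, c_loss, adv, ent, kl = REINFORCE(actor, critic_main, actor_opt,
#                                            critic_opt, replay_buffer)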


def off_PAC(actor, critic, target_critic, actor_optimizer, critic_optimizer, buffer, tau=0.01):
    batch = Transition(*zip(*buffer.memory))
    dataset = GraphDataset(*batch)
    dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, follow_batch=['constraint_features', 'variable_features', 'candidates'])
    
    replay = next(iter(dataloader))
    states, log_probs, rewards, dones = replay
    
    with torch.no_grad():
        target_values = target_critic(
            states.constraint_features,
            states.edge_index, 
            states.edge_attr,
            states.variable_features,
            states.variable_features_batch,
        )
        # Compute Advantages
        advantages, returns_to_go = compute_gae(rewards, target_values.detach(), dones)
        advantages = (advantages - advantages.mean()) / (1e-8 + advantages.std())
        # returns_to_go = (returns_to_go - returns_to_go.mean()) / (1e-8 + returns_to_go.std())

        # formatting for indexing in batched data
        num_nodes_per_graph = states.num_graphs*[int(states.variable_features.shape[0]/states.num_graphs)]
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()
        
        # Split candidates per graph and adjust indices
        candidate_counts = states.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = states.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)

    learn_logits = actor(
        states.constraint_features,
        states.edge_index,
        states.edge_attr,
        states.variable_features,
    )

    learn_logits = torch.stack(torch.split(learn_logits, num_nodes_per_graph))

    learn_log_probs = []
    for branch in range(len(candidates_per_graph)):
        learn_logits_i = learn_logits[branch, :]
        candidates_i = candidates_per_graph[branch]
        learn_logits_i = learn_logits_i[candidates_i]        
        learn_dist = D.Categorical(logits=learn_logits_i)
        action_i = states.candidate_choices[branch]
        learn_log_probs.append(learn_dist.log_prob(action_i))


    learn_log_probs = torch.stack(learn_log_probs)

    # Compute importance sampling weights
    rho = torch.exp(learn_log_probs - log_probs)
    eps_clip = 0.25
    surr1 = rho * advantages
    surr2 = torch.clamp(rho, 1 - eps_clip, 1 + eps_clip) * advantages
    actor_loss = -torch.min(surr1, surr2).mean()

    # Compute critic loss
    values = critic(
        states.constraint_features,
        states.edge_index, 
        states.edge_attr,
        states.variable_features,
        states.variable_features_batch,
    )


    # critic_loss = F.mse_loss(values, returns_to_go)
    critic_loss = F.huber_loss(values, returns_to_go)

    # Compute actor loss
    # actor_loss = -torch.mean(rho*learn_log_probs*advantages)

    actor_optimizer.zero_grad()
    critic_optimizer.zero_grad()
    
    # Update actor and critic jointly
    actor_loss.backward()
    critic_loss.backward()
    # torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=1.0)

    actor_optimizer.step()
    critic_optimizer.step()

    for target_param, param in zip(target_critic.parameters(), critic.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)


    return actor_loss.item(), critic_loss.item()
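
# Hedged usage sketch for off_PAC (everything outside its own argument list is a
# placeholder for the surrounding training script):
#
#   for update in range(num_updates):
#       a_loss, c_loss = off_PAC(actor, critic, target_critic,
#                                actor_optimizer, critic_optimizer, buffer, tau=0.01)
#
# The final loop above performs a Polyak (soft) update of the target critic:
#   theta_target <- tau * theta_critic + (1 - tau) * theta_target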