# EECE571F-project/learn2branch_gasse/actor_critic.py
import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.reward
import ecole.typing
import hydra
from omegaconf import DictConfig


from data import BipartiteDatasetRL
from torch_geometric.loader import DataLoader
from model import GNNActor, GNNCritic
import random

import torch
import torch.nn.functional as F
import numpy as np
import os

from replay import Replay, Transition

from torchrl.modules import MaskedCategorical
from utils import create_folder, to_state, to_tensor
import matplotlib.pyplot as plt

def compute_discounted_returns(rewards, gamma):
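    """Compute discounted returns R_t = sum_k gamma^k * r_{t+k} for a single episode.

    Implemented as a reversed cumulative sum of gamma^t * r_t, with the gamma^t factor
    divided back out. Example: rewards = [1, 1, 1], gamma = 0.5 -> [1.75, 1.5, 1.0].
    """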
    rewards = rewards.float()
    discounts = gamma ** torch.arange(len(rewards)).float()
    discounted_rewards = rewards * discounts
    discounted_returns = torch.flip(torch.cumsum(torch.flip(discounted_rewards, [0]), dim=0), [0])
    return discounted_returns / discounts

## DOES NOT TRAIN WELL - needs fixing ##
def train_actor_critic(cfg: DictConfig, actor: torch.nn.Module, critic: torch.nn.Module, actor_opt: torch.optim.Optimizer,  critic_opt: torch.optim.Optimizer, replay_buffer):
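    """REINFORCE-style update over a full-episode replay buffer.

    The critic baseline is currently disabled (commented out below), so the advantage is
    the normalized discounted return. Returns (mean actor loss, 0.0).
    """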
    transitions = replay_buffer.memory
    nt = len(transitions)
    
    batch_size = 128
    # Gather transition information into tensors
    batch = Transition(*zip(*transitions))

    discounted_return_batch = compute_discounted_returns(torch.stack(batch.reward), cfg.training.gamma)
    
    dataset = BipartiteDatasetRL(batch.state, discounted_return_batch, batch.action_set, batch.action)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    actor.train()
    critic.train()

    total_actor_loss = 0
    total_critic_loss = 0
    for sub_batch in dataloader:
        
        # Need to fix rewards 
        graph, discounted_return, action = sub_batch
        
        logits = actor(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v
        )

        # All transitions in the buffer come from a single instance, so every graph in the
        # batch has the same number of variable nodes. For mixed instances, derive the counts
        # from the batch vector instead (e.g. torch.bincount(graph.batch)).
        num_nodes_per_graph = graph.num_graphs * [int(graph.x_v.shape[0] / graph.num_graphs)]
        logits_per_graph = torch.split(logits, num_nodes_per_graph)
        
        # Node offsets to adjust indices
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()
        
        # Split candidates per graph and adjust indices
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)
        
        log_probs = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]
            
            # Create mask for valid actions
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True
            
            # Create action distribution
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)

        log_probs = torch.stack(log_probs)

        # state_values = critic(
        #     graph.x_c,
        #     graph.edge_index,
        #     graph.edge_features,
        #     graph.x_v
        # )
       
        # state_values_per_graph = torch.split(state_values, num_nodes_per_graph)
        # outputs = []
        # for i in range(len(state_values_per_graph)):
        #     state_values_i = state_values_per_graph[i]
        #     candidates_i = candidates_per_graph[i]
        #     candidate_state_values_i = state_values_i[candidates_i]
        #     output_i = candidate_state_values_i.mean()
        #     outputs.append(output_i)
        
        # output = torch.stack(outputs)
        
        # advantage = discounted_return - output.detach()
        advantage = discounted_return
        if advantage.numel() != 1:
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

        # entropy = -torch.sum(torch.exp(log_probs)*log_probs)
        actor_loss = -torch.mean(log_probs * advantage) #- 0.01*entropy
        
        actor_opt.zero_grad()
        actor_loss.backward()
        actor_opt.step()

        # critic_loss = torch.nn.functional.mse_loss(output, discounted_return)

        # critic_opt.zero_grad()
        # critic_loss.backward()
        # critic_opt.step()

        # total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()

    return total_actor_loss/len(dataloader), 0.0
    # return total_actor_loss/len(dataloader), total_critic_loss/len(dataloader)

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
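    """Compute GAE advantages and discounted returns-to-go.

    `values` must hold one extra bootstrap entry beyond `rewards` (the value of the state
    after the final transition). The returns-to-go recursion does not reset on `dones`,
    which is acceptable here because each replay buffer holds a single episode.
    """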
    advantages = torch.zeros_like(rewards)
    returns_to_go = torch.zeros_like(rewards)

    last_gae = 0  # GAE for the final timestep in an episode
    last_return = 0 # return

    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t] # td-error
        advantages[t] = last_gae = delta + gamma * lam * (1 - dones[t]) * last_gae # gae advantage

        returns_to_go[t] = last_return = rewards[t] + gamma*last_return

    return advantages, returns_to_go


def train_GAE(cfg: DictConfig, actor: torch.nn.Module, critic_main: torch.nn.Module, critic_target: torch.nn.Module, actor_opt: torch.optim.Optimizer,  critic_opt: torch.optim.Optimizer, replay_buffer):
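    """Actor-critic update using GAE advantages.

    Pass 1 (full batch, no grad): evaluate the target critic on every state and compute
    advantages and returns-to-go for the whole episode.
    Pass 2 (mini-batches): policy-gradient step for the actor, Huber regression of the
    main critic towards the returns, and a Polyak (soft) update of the target critic.
    """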
    actor.train()
    critic_main.train()

    
    transitions = replay_buffer.memory
    batch_size1 = len(transitions)
    batch_size2 = 16
    batch = Transition(*zip(*transitions))
    dataset = BipartiteDatasetRL(*batch)
    dataloader1 = DataLoader(dataset, batch_size=batch_size1, shuffle=False, follow_batch=['x_c', 'x_v', 'candidates'])
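    # Single full-batch pass: compute advantages and returns once for the whole episode
    # using the target critic.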
    for SASR in dataloader1:
        graph, action, reward, done = SASR
        
        with torch.no_grad():
            # get state values of batch
            values = critic_target(
                graph.x_c,
                graph.edge_index,
                graph.edge_features,
                graph.x_v,
                graph.x_c_batch,
                graph.x_v_batch,
                graph.candidates
            )

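        # Append a zero bootstrap value for the terminal state so compute_gae can index values[t + 1].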
        values = torch.cat((values, torch.tensor([0])), dim=0)
        advantages, returns = compute_gae(reward, values, done)
        if len(advantages) <= 1:
            return torch.nan, torch.nan
        # advantages = (advantages - advantages.mean()) / (1e-8 + advantages.std())

    dataloader2 = DataLoader(dataset, batch_size=batch_size2, shuffle=False, follow_batch=['x_c', 'x_v', 'candidates'])

    batch_idx = 0
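    # Polyak averaging coefficient for the soft target-critic update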
    tau = 0.005
    total_actor_loss, total_critic_loss = 0, 0
    for SASR in dataloader2:
        graph, action, reward, done = SASR

        batch_size_i = len(action)
        batch_start = batch_idx
        batch_end = batch_start + batch_size_i

        actor_opt.zero_grad()
        critic_opt.zero_grad()
    
        # Formatting for indexing into the batched graph. All transitions come from a single
        # instance, so every graph has the same number of variable nodes; for mixed instances,
        # derive the counts from graph.x_v_batch instead.
        num_nodes_per_graph = graph.num_graphs * [int(graph.x_v.shape[0] / graph.num_graphs)]
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()
        
        # Split candidates per graph and adjust indices
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)

        # Estimate policy gradient

        logits = actor(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v
        )

        logits_per_graph = torch.split(logits, num_nodes_per_graph)
            
        log_probs = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]
            
            # Create mask for valid actions
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True
            
            # Create action distribution
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)

        log_probs = torch.stack(log_probs)
        actor_loss = -torch.sum(log_probs * advantages[batch_start:batch_end]) 
        
        # get state values of batch
        values = critic_main(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v,
            graph.x_c_batch,
            graph.x_v_batch,
            graph.candidates        
        )
            
        critic_loss = torch.nn.functional.huber_loss(values, returns[batch_start:batch_end])


        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=1.0)

        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(critic_main.parameters(), max_norm=1.0)

        actor_opt.step()
        critic_opt.step()

        total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()

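        # Soft (Polyak) update: the target critic slowly tracks the main critic.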
        for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
            target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)

        batch_idx += batch_size_i


    return total_actor_loss/len(dataloader2), total_critic_loss/len(dataloader2)


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg : DictConfig) -> None:
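    """Train the GNN branching policy with actor-critic (GAE) on generated instances."""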
    # print(OmegaConf.to_yaml(cfg))
    torch.manual_seed(cfg.training.seed)
    random.seed(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    generator = torch.Generator()
    generator.manual_seed(cfg.training.seed)

    # define problem instance generator
    if cfg.problem == "set_cover":
        gen = ecole.instance.SetCoverGenerator(
            n_rows=cfg.n_rows,
            n_cols=cfg.n_cols,
            density=cfg.density,
        )
    else:
        raise ValueError(f"Unsupported problem type: {cfg.problem}")
    gen.seed(cfg.training.seed)
    
    # define the observation function
    observation_function = ecole.observation.NodeBipartite()

    # scip parameters used in paper
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 900, # 15 minute time limit to run,
        "limits/nodes": 1000,
        "lp/threads": 4
    }
    
    # define the environment
    env = ecole.environment.Branching(
        observation_function=observation_function,
        reward_function=-1*ecole.reward.DualIntegral(),
        scip_params=scip_parameters,
        information_function=None
    )

    env.seed(cfg.training.seed) # seed environment


    # define actor network
    actor = GNNActor()
    critic_main = GNNCritic()
    critic_target = GNNCritic()


    def initialize_weights(m):
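        """Xavier-initialize Linear layer weights and set biases to a small constant."""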
        if type(m) == torch.nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)
    
    actor.apply(initialize_weights)
    critic_main.apply(initialize_weights)
    critic_target.apply(initialize_weights)

    # define an optimizer
    actor_opt = torch.optim.Adam(actor.parameters(), lr=cfg.training.lr)
    critic_opt = torch.optim.Adam(critic_main.parameters(), lr=cfg.training.lr)

    fig, (ax1, ax2) = plt.subplots(1, 2)
    counts = []
    actor_loss_arr = []
    critic_loss_arr = []
    for episode_i in range(cfg.training.n_episodes):
        instance = next(gen)

        actor_loss_epoch, critic_loss_epoch = 0, 0
        valid_epochs = 5

        for n in range(5):
            observation, action_set, _, done, info = env.reset(instance)
            if done:
                break

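            # Fresh replay buffer per rollout: each training call below sees exactly one episode.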
            replay_buffer = Replay(cfg.max_buffer_size)
            expected_return = 0

            while not done:
                # Convert np ecole observation to tensor
                state_tensor = to_state(observation, cfg.device)
                action_set_tensor = torch.tensor(action_set, dtype=torch.int32)
                # problem size (constraint and variable counts), used to scale the reward
                n_cons = state_tensor[0].shape[0]
                n_vars = state_tensor[-1].shape[0]
                complexity = (n_cons + n_vars) / max(n_cons, n_vars)

                # pass state to actor and get distribution and create valid action mask
                logits = actor(*state_tensor)
                mask = torch.zeros_like(logits, dtype=torch.bool)
                mask[action_set] = True
                
                # sample from action distribution
                action_dist = MaskedCategorical(logits=logits, mask=mask)
                action = action_dist.sample()

                # take action, and go to the next state
                next_observation, next_action_set, reward, done, _ = env.step(action.item())
                reward = reward / complexity
                expected_return += reward
                reward_tensor = to_tensor(reward)
                done_tensor = torch.tensor(done, dtype=torch.int32)

                # record in replay buffer
                replay_buffer.push(
                    state_tensor,       # current state
                    action,             # action taken (store as tensor)
                    reward_tensor,      # reward received
                    action_set_tensor,  # action set,
                    done_tensor         # mark when episode is finished
                )

                if done and len(replay_buffer) == 1:
                    valid_epochs -= 1

                # Update current observation
                observation = next_observation
                action_set = next_action_set
                if done and len(replay_buffer) > 1:

                    # Train actor and critic networks using the replay buffer
                    actor_loss, critic_loss = train_GAE(cfg, actor, critic_main, critic_target, actor_opt, critic_opt, replay_buffer)
                    print(f"episode: {episode_i}, actor loss: {actor_loss:>.4f}, critic loss: {critic_loss:>.4f}, return: {expected_return}")
        
                    actor_loss_arr.append(actor_loss)
                    critic_loss_arr.append(critic_loss)
        ax1.clear()
        ax2.clear()
        ax1.plot(actor_loss_arr, label=f'Actor Loss')
        ax1.legend()
        ax1.grid()
        ax2.plot(critic_loss_arr, label=f'Critic Loss')
        ax2.legend()
        ax2.grid()
        plt.pause(0.5)
    plt.show()
    path = os.path.join(os.getcwd(), "models")
    name = "learn2branch-set_cover-actor_critic-gae.pt"
    create_folder(path)
    torch.save(actor.state_dict(), os.path.join(path, name))



if __name__ == "__main__":
    main()