import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.reward
import ecole.typing
import hydra
from omegaconf import DictConfig
from data import BipartiteDatasetRL
from torch_geometric.loader import DataLoader
from model import GNNActor, GNNCritic
import random
import torch
import torch.nn.functional as F
import numpy as np
import os
from replay import Replay, Transition
from torchrl.modules import MaskedCategorical
from utils import create_folder, to_state, to_tensor
import matplotlib.pyplot as plt
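
# Overview: this script trains a GNN branching policy on Ecole's Branching
# environment for set-cover instances. NodeBipartite observations feed a GNN
# actor (a MaskedCategorical policy over the branching candidates) and a GNN
# critic; the critic is trained on GAE-style return targets, and the reward is
# the negative dual integral scaled by a simple instance-size factor. Hydra
# supplies the configuration (conf/config.yaml).
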
def compute_discounted_returns(rewards, gamma):
    rewards = rewards.float()
    discounts = gamma ** torch.arange(len(rewards)).float()
    discounted_rewards = rewards * discounts
    discounted_returns = torch.flip(torch.cumsum(torch.flip(discounted_rewards, [0]), dim=0), [0])
    return discounted_returns / discounts
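
# Sanity check (for reference):
#   compute_discounted_returns(torch.tensor([1., 1., 1.]), gamma=0.5)
#   -> tensor([1.7500, 1.5000, 1.0000]), i.e. G_t = sum_{k >= t} gamma**(k - t) * r_k
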
## DOES NOT TRAIN WELL - needs fixing ##
def train_actor_critic(cfg: DictConfig, actor: torch.nn.Module, critic: torch.nn.Module, actor_opt: torch.optim.Optimizer, critic_opt: torch.optim.Optimizer, replay_buffer):
    transitions = replay_buffer.memory
    nt = len(transitions)
    batch_size = 128
    # Gather transition information into tensors
    batch = Transition(*zip(*transitions))
    discounted_return_batch = compute_discounted_returns(torch.stack(batch.reward), cfg.training.gamma)
    dataset = BipartiteDatasetRL(batch.state, discounted_return_batch, batch.action_set, batch.action)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    actor.train()
    critic.train()
    total_actor_loss = 0
    total_critic_loss = 0
    for sub_batch in dataloader:
        # Need to fix rewards
        graph, discounted_return, action = sub_batch
        logits = actor(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v
        )
        # batch_indices = graph.batch  # shape [num_nodes_in_batch], indicates graph indices
        # num_nodes_per_graph = torch.bincount(batch_indices).tolist()
        # NOTE: assumes every graph in the batch has the same number of variable nodes
        num_nodes_per_graph = graph.num_graphs * [int(graph.x_v.shape[0] / graph.num_graphs)]
        logits_per_graph = torch.split(logits, num_nodes_per_graph)
        # Node offsets to adjust indices
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()
        # Split candidates per graph and adjust indices
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)
        log_probs = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]
            # Create mask for valid actions
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True
            # Create action distribution
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)
        log_probs = torch.stack(log_probs)
        # state_values = critic(
        #     graph.x_c,
        #     graph.edge_index,
        #     graph.edge_features,
        #     graph.x_v
        # )
        # state_values_per_graph = torch.split(state_values, num_nodes_per_graph)
        # outputs = []
        # for i in range(len(state_values_per_graph)):
        #     state_values_i = state_values_per_graph[i]
        #     candidates_i = candidates_per_graph[i]
        #     candidate_state_values_i = state_values_i[candidates_i]
        #     output_i = candidate_state_values_i.mean()
        #     outputs.append(output_i)
        # output = torch.stack(outputs)
        # advantage = discounted_return - output.detach()
        advantage = discounted_return
        if advantage.numel() != 1:
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
        # entropy = -torch.sum(torch.exp(log_probs) * log_probs)
        actor_loss = -torch.mean(log_probs * advantage)  # - 0.01 * entropy
        actor_opt.zero_grad()
        actor_loss.backward()
        actor_opt.step()
        # critic_loss = torch.nn.functional.mse_loss(output, discounted_return)
        # critic_opt.zero_grad()
        # critic_loss.backward()
        # critic_opt.step()
        # total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()
    return total_actor_loss / len(dataloader), 0.0
    # return total_actor_loss / len(dataloader), total_critic_loss / len(dataloader)
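
# Note: with the critic baseline commented out above, train_actor_critic reduces
# to a REINFORCE-style update where the "advantage" is the batch-normalized
# discounted return, i.e. L = -mean_t[log pi(a_t | s_t) * G_hat_t], which is
# unbiased but high-variance.
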
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    # NOTE: `values` must contain one extra bootstrap entry,
    # i.e. len(values) == len(rewards) + 1 (see the caller below).
    advantages = torch.zeros_like(rewards)
    returns_to_go = torch.zeros_like(rewards)
    last_gae = 0  # GAE of the following timestep (0 at the end of the episode)
    last_return = 0  # discounted return of the following timestep
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]  # TD error
        advantages[t] = last_gae = delta + gamma * lam * (1 - dones[t]) * last_gae  # GAE advantage
        returns_to_go[t] = last_return = rewards[t] + gamma * last_return
    return advantages, returns_to_go
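
# Sanity check (for reference): with rewards = [1.0, 1.0],
# values = [0.5, 0.5, 0.0] (the last entry being the appended bootstrap value),
# dones = [0, 1], gamma = 0.99 and lam = 0.95, compute_gae returns
# advantages ~= [1.46525, 0.5] and returns_to_go = [1.99, 1.0].
# returns_to_go ignores `dones`, which is fine here because the replay buffer
# only ever holds a single episode.
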
def train_GAE(cfg: DictConfig, actor: torch.nn.Module, critic_main: torch.nn.Module, critic_target: torch.nn.Module, actor_opt: torch.optim.Optimizer, critic_opt: torch.optim.Optimizer, replay_buffer):
    actor.train()
    critic_main.train()
    transitions = replay_buffer.memory
    batch_size1 = len(transitions)  # the whole episode in a single batch for the GAE pass
    batch_size2 = 16                # mini-batch size for the gradient updates
    batch = Transition(*zip(*transitions))
    dataset = BipartiteDatasetRL(*batch)
    dataloader1 = DataLoader(dataset, batch_size=batch_size1, shuffle=False, follow_batch=['x_c', 'x_v', 'candidates'])
    for SASR in dataloader1:
        graph, action, reward, done = SASR
        with torch.no_grad():
            # get state values for the whole episode from the target critic
            values = critic_target(
                graph.x_c,
                graph.edge_index,
                graph.edge_features,
                graph.x_v,
                graph.x_c_batch,
                graph.x_v_batch,
                graph.candidates
            )
        # append a zero bootstrap value for the terminal state
        values = torch.cat((values, values.new_zeros(1)), dim=0)
        advantages, returns = compute_gae(reward, values, done)
    if len(advantages) <= 1:
        return torch.nan, torch.nan
    # advantages = (advantages - advantages.mean()) / (1e-8 + advantages.std())
    dataloader2 = DataLoader(dataset, batch_size=batch_size2, shuffle=False, follow_batch=['x_c', 'x_v', 'candidates'])
    batch_idx = 0
    tau = 0.005
    total_actor_loss, total_critic_loss = 0, 0
    for SASR in dataloader2:
        graph, action, reward, done = SASR
        batch_size_i = len(action)
        batch_start = batch_idx
        batch_end = batch_start + batch_size_i
        actor_opt.zero_grad()
        critic_opt.zero_grad()
        # formatting for indexing into the batched data
        # NOTE: assumes every graph in the batch has the same number of variable nodes
        num_nodes_per_graph = graph.num_graphs * [int(graph.x_v.shape[0] / graph.num_graphs)]
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()
        # Split candidates per graph and adjust indices
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)
        # Estimate the policy gradient
        logits = actor(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v
        )
        logits_per_graph = torch.split(logits, num_nodes_per_graph)
        log_probs = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]
            # Create mask for valid actions
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True
            # Create action distribution
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)
        log_probs = torch.stack(log_probs)
        actor_loss = -torch.sum(log_probs * advantages[batch_start:batch_end])
        # get state values of the mini-batch from the main critic
        values = critic_main(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v,
            graph.x_c_batch,
            graph.x_v_batch,
            graph.candidates
        )
        critic_loss = torch.nn.functional.huber_loss(values, returns[batch_start:batch_end])
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=1.0)
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(critic_main.parameters(), max_norm=1.0)
        actor_opt.step()
        critic_opt.step()
        total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()
        # soft (Polyak) update of the target critic
        for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
            target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)
        batch_idx += batch_size_i
    return total_actor_loss / len(dataloader2), total_critic_loss / len(dataloader2)
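
# For reference, the updates applied above per mini-batch are:
#   actor:  L_pi = -sum_t log pi(a_t | s_t) * A_t^GAE
#   critic: L_V  = Huber(V_main(s_t), return_to_go_t)
#   target: theta_target <- tau * theta_main + (1 - tau) * theta_target, tau = 0.005
# The GAE advantages are computed once per episode with the *target* critic
# under no_grad and kept fixed while the main critic and the actor are updated.
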
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    # print(OmegaConf.to_yaml(cfg))
    torch.manual_seed(cfg.training.seed)
    random.seed(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    generator = torch.Generator()
    generator.manual_seed(cfg.training.seed)
    # define the problem instance generator
    if cfg.problem == "set_cover":
        gen = ecole.instance.SetCoverGenerator(
            n_rows=cfg.n_rows,
            n_cols=cfg.n_cols,
            density=cfg.density,
        )
    gen.seed(cfg.training.seed)
    # define the observation function
    observation_functions = ecole.observation.NodeBipartite()
    # SCIP parameters used in the paper
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 900,  # 15-minute time limit per instance
        "limits/nodes": 1000,
        "lp/threads": 4,
    }
    # define the environment
    env = ecole.environment.Branching(
        observation_function=observation_functions,
        reward_function=-1 * ecole.reward.DualIntegral(),
        scip_params=scip_parameters,
        information_function=None
    )
    env.seed(cfg.training.seed)  # seed environment
    # define the actor and critic networks
    actor = GNNActor()
    critic_main = GNNCritic()
    critic_target = GNNCritic()

    def initialize_weights(m):
        if type(m) == torch.nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    actor.apply(initialize_weights)
    critic_main.apply(initialize_weights)
    critic_target.apply(initialize_weights)
    # define the optimizers
    actor_opt = torch.optim.Adam(actor.parameters(), lr=cfg.training.lr)
    critic_opt = torch.optim.Adam(critic_main.parameters(), lr=cfg.training.lr)
    fig, (ax1, ax2) = plt.subplots(1, 2)
    counts = []
    actor_loss_arr = []
    critic_loss_arr = []
    for episode_i in range(cfg.training.n_episodes):
        instance = next(gen)
        actor_loss_epoch, critic_loss_epoch = 0, 0
        valid_epochs = 5
        for restart_i in range(5):
            observation, action_set, _, done, info = env.reset(instance)
            if done:
                break
            replay_buffer = Replay(cfg.max_buffer_size)
            expected_return = 0
            while not done:
                # Convert the numpy Ecole observation to tensors
                state_tensor = to_state(observation, cfg.device)
                action_set_tensor = torch.tensor(action_set, dtype=torch.int32)
                m = state_tensor[0].shape[0]   # number of constraint nodes
                n = state_tensor[-1].shape[0]  # number of variable nodes
                complexity = (n + m) / max(n, m)  # simple size-based reward scaling
                # pass the state to the actor and build a valid-action mask
                logits = actor(*state_tensor)
                mask = torch.zeros_like(logits, dtype=torch.bool)
                mask[action_set] = True
                # sample from the action distribution
                action_dist = MaskedCategorical(logits=logits, mask=mask)
                action = action_dist.sample()
                # take the action and move to the next state
                next_observation, next_action_set, reward, done, _ = env.step(action.item())
                reward = reward / complexity
                expected_return += reward
                reward_tensor = to_tensor(reward)
                done_tensor = torch.tensor(done, dtype=torch.int32)
                # record the transition in the replay buffer
                replay_buffer.push(
                    state_tensor,       # current state
                    action,             # action taken (stored as a tensor)
                    reward_tensor,      # reward received
                    action_set_tensor,  # action set
                    done_tensor         # marks when the episode is finished
                )
                if done and len(replay_buffer) == 1:
                    valid_epochs -= 1
                # Update the current observation
                observation = next_observation
                action_set = next_action_set
            if done and len(replay_buffer) > 1:
                # Train the actor and critic networks on the collected episode
                actor_loss, critic_loss = train_GAE(cfg, actor, critic_main, critic_target, actor_opt, critic_opt, replay_buffer)
                print(f"episode: {episode_i}, actor loss: {actor_loss:>.4f}, critic loss: {critic_loss:>.4f}, return: {expected_return}")
                actor_loss_arr.append(actor_loss)
                critic_loss_arr.append(critic_loss)
                ax1.clear()
                ax2.clear()
                ax1.plot(actor_loss_arr, label='Actor Loss')
                ax1.legend()
                ax1.grid()
                ax2.plot(critic_loss_arr, label='Critic Loss')
                ax2.legend()
                ax2.grid()
                plt.pause(0.5)
    plt.show()
    path = os.path.join(os.getcwd(), "models")
    name = "learn2branch-set_cover-actor_critic-gae.pt"
    create_folder(path)
    torch.save(actor.state_dict(), os.path.join(path, name))


if __name__ == "__main__":
    main()