# The term "brancher" refers to the agent that selects the branching variable (here an RL agent, or any agent in general).
# The node selection policy is handled internally by branch and bound; we have no control over it.
from random_agent import RandomAgent
import ecole
from ecole.environment import Branching
from ecole.instance import SetCoverGenerator
import numpy as np
import torch
from replay import Replay
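
# Note: random_agent.py and replay.py are local modules whose implementations are not shown here.
# From the way they are used below, the assumed interfaces are roughly as follows
# (a sketch of the expected API, not the actual implementations):
#
#   class RandomAgent:
#       def action_select(self, action_set, observation):
#           ...  # returns (action, auxiliary_info, log_prob_of_action)
#
#   class Replay:
#       def __init__(self, capacity): ...
#       def push(self, obs, action, next_obs, reward, log_prob, gamma): ...
#       def pop(self): ...       # remove and return the most recently pushed transition
#       def __len__(self): ...
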
def get_features(observation, action_set):
    """Convert a bipartite observation and the branching candidate set into PyTorch tensors."""
    constraint_features = torch.FloatTensor(observation.row_features)
    edge_indices = torch.LongTensor(observation.edge_features.indices.astype(np.int32))
    edge_features = torch.FloatTensor(np.expand_dims(observation.edge_features.values, axis=-1))
    variable_features = torch.FloatTensor(observation.variable_features)
    candidates = torch.LongTensor(np.array(action_set, dtype=np.int32))
    return constraint_features, edge_indices, edge_features, variable_features, candidates
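
# With the Branching environment's default NodeBipartite observation, these tensors have shapes
# constraint_features: (n_constraints, n_row_features), edge_indices: (2, nnz), edge_features: (nnz, 1),
# variable_features: (n_variables, n_col_features), and candidates: (n_candidates,).
# A learned brancher (e.g. a bipartite graph network policy) would consume them; the RandomAgent below does not.
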
if __name__ == "__main__":
    seed = 42  # random seed
    n_episodes = 100
    gamma = 0.95  # discount factor
    agent = RandomAgent()  # random agent making the branching decisions

    # SCIP solver parameters for branch and bound
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 3600,  # time limit in seconds
    }
    env = Branching(scip_params=scip_parameters)  # branch-and-bound branching environment
    env.seed(seed)
    gen = SetCoverGenerator(n_rows=10, n_cols=20, density=0.2)  # problem instance generator
    gen.seed(seed)
    for _ in range(n_episodes):
        replay = Replay(3600)  # replay buffer for the RL agent
        problem_instance = next(gen)  # sample a problem instance
        observation, action_set, reward_offset, done, info = env.reset(problem_instance)  # reset the environment
        while not done:
            # convert the current observation and branching candidates to tensors
            constraint_features, edge_indices, edge_features, variable_features, candidates = get_features(observation, action_set)
            action, _, old_log_probs = agent.action_select(action_set, observation)  # agent picks an action for the current state
            next_observation, action_set, reward, done, info = env.step(action)  # transition to the next state
            # save the transition to the replay buffer
            if replay is not None:
                replay.push(observation, action, next_observation, reward, old_log_probs, gamma)
            observation = next_observation

        # the last transition of the episode is terminal, so overwrite its gamma with 0
        # (no value is bootstrapped past the end of the episode)
        if replay is not None and len(replay) > 0:
            last_transition = replay.pop()
            replay.push(*(list(last_transition[:-1]) + [0.0]))
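
        # For reference, a sketch of how the stored per-transition gammas would typically be used
        # when computing discounted returns from the buffer (assuming transitions are
        # (obs, action, next_obs, reward, log_prob, gamma) tuples; the Replay internals are not shown here):
        #
        #   G = 0.0
        #   for obs, action, next_obs, reward, log_prob, g in reversed(transitions):
        #       G = reward + g * G  # g = 0 on the terminal transition stops the recursion at episode end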