# EECE571F-project / RL_ecole / ecole_tutorial.py
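# The term 'brancher' refers to the RL agent (or, more generally, any agent that makes branching decisions).
# Node selection is handled internally by branch and bound; we have no control over it — the agent only chooses which variable to branch on.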
from random_agent import RandomAgent
import ecole

from ecole.environment import Branching
from ecole.instance import SetCoverGenerator
import numpy as np
import torch
from replay import Replay

def get_features(observation, action_set):
    """Convert an observation and its action set into torch tensors.

    Args:
        observation: object exposing ``row_features``, ``variable_features``
            and ``edge_features`` (which itself has ``indices`` and
            ``values``); presumably an ecole node-bipartite observation —
            TODO confirm against the environment's observation function.
        action_set: sequence of candidate variable indices.

    Returns:
        Tuple ``(constraint_features, edge_indices, edge_features,
        variable_features, candidates)``. Row, edge and variable features are
        float32 tensors (edge values gain a trailing axis of size 1); edge
        indices and candidates are int64 tensors.
    """
    edges = observation.edge_features

    # Indices and candidates are narrowed to int32 numpy arrays first and then
    # converted to int64 tensors.
    index_array = edges.indices.astype(np.int32)
    value_column = np.expand_dims(edges.values, axis=-1)  # append axis of size 1
    candidate_array = np.array(action_set, dtype=np.int32)

    return (
        torch.as_tensor(observation.row_features, dtype=torch.float32),
        torch.as_tensor(index_array, dtype=torch.int64),
        torch.as_tensor(value_column, dtype=torch.float32),
        torch.as_tensor(observation.variable_features, dtype=torch.float32),
        torch.as_tensor(candidate_array, dtype=torch.int64),
    )


if __name__ == "__main__":
    seed = 42  # random seed shared by the environment and the instance generator
    n_episodes = 100
    gamma = 0.95  # discount stored with every non-terminal transition

    agent = RandomAgent()  # random agent making branching decisions

    # SCIP parameters for branch and bound: no separation rounds per node,
    # no presolve restarts, and a 3600-second time limit.
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 3600,
    }

    env = Branching(scip_params=scip_parameters)  # branch-and-bound environment
    env.seed(seed)

    gen = SetCoverGenerator(n_rows=10, n_cols=20, density=0.2)  # problem instance generator
    gen.seed(seed)

    for _ in range(n_episodes):
        # A fresh buffer is created for every episode.
        # NOTE(review): it is never consumed before being replaced on the next
        # episode — confirm whether it should live outside the loop or be used
        # for an update here.
        replay = Replay(3600)
        problem_instance = next(gen)  # get problem instance
        observation, action_set, reward_offset, done, info = env.reset(problem_instance)

        # The environment may already report `done` right after reset
        # (presumably when the instance is finished without any branching); the
        # action set is then presumably unavailable, so only convert otherwise.
        # The tensors are currently unused: RandomAgent consumes the raw
        # observation below.
        if not done:
            (
                constraint_features,
                edge_indices,
                edge_features,
                variable_features,
                candidates,
            ) = get_features(observation, action_set)  # convert to tensors

        while not done:
            # agent picks an action to take given the current state
            action, _, old_log_probs = agent.action_select(action_set, observation)
            # transition to next state
            next_observation, action_set, reward, done, info = env.step(action)

            # The loop exits right after the step that sets `done`, so that
            # transition is the episode's last one; store it with a discount of
            # 0 so nothing is bootstrapped past the end of the episode.
            replay.push(
                observation,
                action,
                next_observation,
                reward,
                old_log_probs,
                0.0 if done else gamma,
            )

            observation = next_observation