# The term 'brancher' is used to describe the RL agent (or agent in general).
# Node selection policy is handled inside B&B; we have no control over it.
from random_agent import RandomAgent
import ecole
from ecole.environment import Branching
from ecole.instance import SetCoverGenerator
import numpy as np
import torch

from replay import Replay


def get_features(observation, action_set):
    # convert the bipartite graph observation into torch tensors
    constraint_features = torch.FloatTensor(observation.row_features)
    edge_indices = torch.LongTensor(observation.edge_features.indices.astype(np.int32))
    edge_features = torch.FloatTensor(np.expand_dims(observation.edge_features.values, axis=-1))
    variable_features = torch.FloatTensor(observation.variable_features)
    candidates = torch.LongTensor(np.array(action_set, dtype=np.int32))
    return constraint_features, edge_indices, edge_features, variable_features, candidates


if __name__ == "__main__":
    seed = 42  # random seed
    n_episodes = 100
    gamma = 0.95
    agent = RandomAgent()  # random agent making decisions

    # parameters for branch and bound
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 3600,
    }
    env = Branching(scip_params=scip_parameters)  # branch-and-bound environment
    env.seed(seed)
    gen = SetCoverGenerator(n_rows=10, n_cols=20, density=0.2)  # problem instance generator
    gen.seed(seed)

    for _ in range(n_episodes):
        replay = Replay(3600)  # replay buffer for the RL agent
        problem_instance = next(gen)  # get a problem instance
        observation, action_set, reward_offset, done, info = env.reset(problem_instance)  # reset environment
        # convert the initial observation to tensors
        constraint_features, edge_indices, edge_features, variable_features, candidates = get_features(observation, action_set)

        while not done:
            # agent picks an action to take given the current state
            action, _, old_log_probs = agent.action_select(action_set, observation)
            next_observation, action_set, reward, done, info = env.step(action)  # transition to next state

            # save the transition to the replay buffer
            if replay is not None:
                replay.push(observation, action, next_observation, reward, old_log_probs, gamma)
            observation = next_observation

        # for the last transition of the episode, set gamma = 0
        # (no bootstrapping past the terminal state)
        if replay is not None and len(replay) > 0:
            last_transition = replay.pop()
            replay.push(*(list(last_transition[:-1]) + [0.0]))
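

# ---------------------------------------------------------------------------
# The `replay` and `random_agent` modules are project-local and not shown in
# this file. The classes below are a minimal, illustrative sketch of the
# interface the script above assumes from them (push/pop/__len__ on the
# buffer, action_select on the agent); they are assumptions for readability,
# not the real implementations.
# ---------------------------------------------------------------------------
import math
import random
from collections import deque


class MinimalReplay:
    """Sketch of the replay-buffer interface used above (assumed, not the real Replay)."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, observation, action, next_observation, reward, old_log_probs, gamma):
        # store one transition; gamma is kept per transition so the terminal
        # transition can be popped and re-pushed with gamma = 0
        self.buffer.append((observation, action, next_observation, reward, old_log_probs, gamma))

    def pop(self):
        # remove and return the most recently pushed transition
        return self.buffer.pop()

    def __len__(self):
        return len(self.buffer)


class MinimalRandomAgent:
    """Sketch of the agent interface used above (assumed, not the real RandomAgent)."""

    def action_select(self, action_set, observation):
        # pick a branching candidate uniformly at random; the log-probability
        # of a uniform choice over the candidate set is -log(|action_set|)
        action = random.choice(list(action_set))
        log_prob = -math.log(len(action_set))
        return action, None, log_prob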