import gymnasium as gym
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from gymnasium import spaces
from networkx.drawing.nx_pydot import graphviz_layout
from scipy.optimize import linprog


class BranchAndBoundEnv(gym.Env):
    """Branch-and-bound search over an LP relaxation, exposed as an environment.

    Each node of the search is an LP (solved with ``scipy.optimize.linprog``)
    that differs from the root only in its variable bounds. An action picks
    one of the currently fractional integer-constrained variables; the
    environment creates the "floor" and "ceil" children, moves to one of them
    at random and stores the other in ``node_queue``.

    Args:
        c, A_ub, b_ub, A_eq, b_eq: Passed unchanged to ``linprog``.
        bounds: Per-variable ``(min, max)`` pairs. The object must support
            ``.copy()`` and item assignment (e.g. a list of tuples), because
            every child node copies it and replaces one entry.
        integer_indices: Indices of the variables that must be integral.
        render: When true, a search tree is recorded and drawn after each
            successful branching step.

    Notes:
        * ``reset`` returns only the observation and ``step`` returns the
          4-tuple ``(obs, reward, done, info)``; these return shapes are kept
          as they are so existing callers continue to work.
        * The ``render`` flag is stored as ``self.render`` and therefore
          shadows ``gym.Env.render()`` on instances. It is kept under that
          name for backward compatibility.
    """

    # Absolute tolerance shared by every integrality test in this class, so
    # the root node, child nodes and ``_is_integer_solution`` always agree.
    INTEGRALITY_ATOL = 1e-5

    def __init__(self, c, A_ub, b_ub, A_eq, b_eq, bounds, integer_indices,
                 render=False):
        super().__init__()
        self.c = c
        self.A_ub = A_ub
        self.b_ub = b_ub
        self.A_eq = A_eq
        self.b_eq = b_eq
        self.bounds = bounds
        self.dim_x = len(c)
        self.integer_indices = integer_indices
        # NOTE: shadows gym.Env.render(); see the class docstring.
        self.render = render

        # Max number of candidate variables to branch on
        self.max_candidates = len(integer_indices)

        # Define action space (variables to branch on)
        self.action_space = spaces.Discrete(self.dim_x)

        # Define observation space (LP solution and action mask)
        self.observation_space = spaces.Dict({
            'observation': spaces.Box(low=-np.inf, high=np.inf,
                                      shape=(self.dim_x,)),
            'action_mask': spaces.Box(low=0, high=1, shape=(self.dim_x,),
                                      dtype=np.int32),
        })

        self.state = None
        self.action_mask = None

        # Search bookkeeping; re-initialised by every reset().
        self.best_obj_value = np.inf
        self.best_solution = None
        self.node_queue = []

        if self.render:
            self._reset_tree()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def reset(self):
        """Start a new search at the root LP relaxation.

        Returns:
            The observation dict for the root node.

        Raises:
            ValueError: If the root LP relaxation cannot be solved.
        """
        if self.render:
            # Start from an empty tree so episodes do not pile up in one graph.
            self._reset_tree()
        self.best_obj_value = np.inf
        self.best_solution = None
        self.node_queue = []
        self.state, self.action_mask = self._get_initial_state()
        return self._get_observation()

    def step(self, action):
        """Branch on one fractional variable of the current node.

        Args:
            action: Position in ``self.state['fractional_vars']`` of the
                variable to branch on.
                NOTE(review): the action mask is indexed by variable, not by
                position in that list - confirm which convention agents use.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` holds
            ``'best_obj_value'``.

        Raises:
            ValueError: If ``action`` does not refer to a branchable variable.
        """
        fractional_vars = self.state['fractional_vars']
        if not 0 <= action < len(fractional_vars):
            raise ValueError("Selected action is invalid.")
        var_idx = fractional_vars[action]
        if not self.action_mask[var_idx]:
            raise ValueError("Selected action is invalid.")

        # Value of the branching variable in the current relaxation
        x_i = self.state['x_lp'][var_idx]
        floor_x_i = np.floor(x_i).item()
        ceil_x_i = np.ceil(x_i).item()

        old_bounds = self.state['bounds']
        lower, upper = old_bounds[var_idx][0], old_bounds[var_idx][1]

        # Left child: x_i <= floor(x_i); right child: x_i >= ceil(x_i).
        candidates = (
            self._make_child(old_bounds, var_idx, lower, floor_x_i),
            self._make_child(old_bounds, var_idx, ceil_x_i, upper),
        )
        valid_children = [child for child in candidates if child is not None]

        if not valid_children:
            return self._handle_dead_end(var_idx)

        if self.render:
            self._record_children(valid_children)

        # (for now) randomly select a child node to branch. Picking by
        # position avoids comparing dicts that contain NumPy arrays.
        selected_pos = int(np.random.randint(len(valid_children)))
        self.node_queue.extend(
            child for pos, child in enumerate(valid_children)
            if pos != selected_pos
        )
        self._load_node(valid_children[selected_pos])

        # Update the best known objective value if an integer solution is found
        if self._is_integer_solution(self.state['x_lp']):
            if self.state['obj_val'] < self.best_obj_value:
                self.best_obj_value = self.state['obj_val']
                self.best_solution = self.state['x_lp'].copy()
            # TODO: edge case when the agent reaches a leaf node while
            # node_queue is not empty: terminate, or pop from the queue?
            # For now the episode ends either way.
            done = True
        else:
            done = False

        reward = self._get_reward()

        if self.render:
            self._visualize_search_tree()

        return self._get_observation(), reward, done, self._info()

    # ------------------------------------------------------------------ #
    # Node handling
    # ------------------------------------------------------------------ #
    def _get_initial_state(self):
        """Solve the root relaxation and build the initial state and mask."""
        res = self._solve_lp(self.bounds)
        if res.status != 0 or res.x is None:
            # Without a root solution there is nothing to branch on.
            raise ValueError(
                f"Root LP relaxation could not be solved: {res.message}"
            )
        x_lp = res.x
        obj_val = res.fun

        fractional_vars = self._fractional_indices(x_lp)
        action_mask = self._build_action_mask(fractional_vars)

        self.prev_obj_val = obj_val

        state = {
            'x_lp': x_lp.copy(),
            'obj_val': obj_val,
            'fractional_vars': fractional_vars,
            'bounds': self.bounds,
        }

        if self.render:
            # Add the root node to the tree
            self.tree.add_node(self.node_counter, label="Root",
                               bounds=self.bounds, obj_val=obj_val)
            self.current_node_id = self.node_counter
            self.node_counter += 1

        return state, action_mask

    def _solve_lp(self, bounds):
        """Solve the problem's LP with the given variable bounds."""
        return linprog(c=self.c, A_ub=self.A_ub, b_ub=self.b_ub,
                       A_eq=self.A_eq, b_eq=self.b_eq, bounds=bounds)

    def _make_child(self, parent_bounds, var_idx, lower, upper):
        """Return a child node dict, or None if it is infeasible or pruned.

        The child copies ``parent_bounds`` and replaces the entry of
        ``var_idx`` with ``(lower, upper)``. It is kept only when the LP
        solves successfully (status 0) and its objective is below the best
        known objective value.
        """
        child_bounds = parent_bounds.copy()
        child_bounds[var_idx] = (lower, upper)
        res = self._solve_lp(child_bounds)
        if res.status == 0 and res.fun < self.best_obj_value:
            return {'bounds': child_bounds, 'x_lp': res.x, 'obj_val': res.fun}
        return None

    def _handle_dead_end(self, var_idx):
        """Handle a branching choice that produced no usable child."""
        # Remove the variable from the action mask
        self.action_mask[var_idx] = 0
        self.state['fractional_vars'].remove(var_idx)

        if np.sum(self.action_mask) > 0:
            # Other fractional variables remain in this node.
            # Small penalty for choosing a dead-end variable.
            return self._get_observation(), -1, False, self._info()

        if self.node_queue:
            # Nothing left here: backtrack to a stored node.
            # NOTE(review): the stored node may itself have no fractional
            # variables, in which case the next step() raises ValueError.
            self._pop_node_from_queue()
            return self._get_observation(), -1, False, self._info()

        # Node queue is empty, search finishes
        return self._get_observation(), 0, True, self._info()

    def _pop_node_from_queue(self):
        """Make the most recently queued node the current node."""
        self._load_node(self.node_queue.pop())

    def _load_node(self, node):
        """Copy ``node`` into ``self.state`` and refresh the action mask."""
        self.state['bounds'] = node['bounds']
        self.state['x_lp'] = node['x_lp']
        self.state['obj_val'] = node['obj_val']
        self.state['fractional_vars'] = self._fractional_indices(node['x_lp'])
        self.action_mask = self._build_action_mask(
            self.state['fractional_vars'])
        if self.render and 'node_id' in node:
            # Keep the drawn tree's "current" pointer on the node in use, so
            # later children are attached to the right parent.
            self.current_node_id = node['node_id']

    def _fractional_indices(self, x_lp):
        """List the integer-constrained indices whose value is fractional."""
        return [
            int(i) for i in self.integer_indices
            if not np.isclose(x_lp[i], np.round(x_lp[i]),
                              atol=self.INTEGRALITY_ATOL)
        ]

    def _build_action_mask(self, fractional_vars):
        """Return a 0/1 int32 vector with ones at ``fractional_vars``."""
        mask = np.zeros(self.dim_x, dtype=np.int32)
        mask[fractional_vars] = 1
        return mask

    def _info(self):
        """Build the info dict returned by step()."""
        return {'best_obj_value': self.best_obj_value}

    def _get_reward(self):
        """Return the drop in LP objective since the previous reward call."""
        reward = self.prev_obj_val - self.state['obj_val']
        self.prev_obj_val = self.state['obj_val']
        return reward

    def _is_integer_solution(self, x):
        """Check if all integer-constrained variables are integral in ``x``."""
        return all(
            np.isclose(x[i], np.round(x[i]), atol=self.INTEGRALITY_ATOL)
            for i in self.integer_indices
        )

    def _check_terminal(self):
        """Report whether the search has nothing left to do.

        Prints the reason when returning True.
        """
        # Check if the current solution is integer feasible
        if (self._is_integer_solution(self.state['x_lp'])
                and len(self.node_queue) == 0):
            print("Reason for completion: all required variables are integers and no nodes left in the queue")
            return True

        # Check if there are no valid actions left
        if np.sum(self.action_mask) == 0:
            print("Reason for completion: no valid actions left to take")
            return True

        if len(self.node_queue) == 0:
            print("Reason for completion: no nodes left in the queue")
            return True

        return False

    def _get_observation(self):
        """Return the current LP solution together with the action mask."""
        return {
            'observation': self.state['x_lp'],
            'action_mask': self.action_mask,
        }

    # ------------------------------------------------------------------ #
    # Visualization
    # ------------------------------------------------------------------ #
    def _reset_tree(self):
        """Create an empty search tree and reset the node id counter."""
        self.tree = nx.DiGraph()  # Directed graph for the tree
        self.node_counter = 0  # Unique counter for node IDs
        self.parent_stack = []  # Keep track of parent nodes for branching

    def _record_children(self, children):
        """Add ``children`` below the current tree node and tag each child.

        Each child dict receives a ``'node_id'`` entry so the tree position
        can be restored when the child is loaded later.
        """
        for child in children:
            new_node_id = self.node_counter
            self.tree.add_node(new_node_id,
                               label=f"Node {child['obj_val']}",
                               bounds=child['bounds'],
                               obj_val=child['obj_val'])
            self.tree.add_edge(self.current_node_id, new_node_id,
                               label=child['bounds'])
            self.node_counter += 1
            child['node_id'] = new_node_id

    def _visualize_search_tree(self):
        """Draw the recorded search tree with matplotlib."""
        plt.figure(figsize=(10, 8))
        pos = graphviz_layout(self.tree, prog="dot")

        # Draw the graph with node labels
        node_labels = nx.get_node_attributes(self.tree, 'label')
        edge_labels = nx.get_edge_attributes(self.tree, 'label')

        nx.draw(self.tree, pos, with_labels=True, labels=node_labels,
                node_color='skyblue', node_size=1200, font_size=10,
                font_weight='bold', edge_color='gray', arrows=True)
        nx.draw_networkx_edge_labels(self.tree, pos, edge_labels=edge_labels,
                                     font_color='red', font_size=8)

        plt.title("Branch and Bound Search Tree")
        plt.show()