import numpy as np
import gymnasium as gym
from gymnasium import spaces
from scipy.optimize import linprog
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout
class BranchAndBoundEnv(gym.Env):
    """Environment in which an agent chooses branching variables for branch and bound.

    The problem is the mixed-integer LP ``min c @ x`` subject to
    ``A_ub @ x <= b_ub``, ``A_eq @ x == b_eq``, the variable ``bounds`` and
    integrality of ``x[i]`` for every ``i`` in ``integer_indices``. Every node
    of the search is an LP relaxation solved with ``scipy.optimize.linprog``.

    * Action: index of the variable to branch on. It is valid only when
      ``action_mask[action] == 1``, i.e. the variable is integer-constrained
      and fractional in the current LP solution.
    * Observation: ``{'observation': current LP solution,
      'action_mask': int32 mask of branchable variables}``.

    After a successful branch one surviving child is selected at random
    (global ``np.random``) and the others are pushed on a LIFO node queue that
    is used for backtracking.

    NOTE(review): this class follows the legacy Gym API -- ``reset`` returns
    only the observation and ``step`` returns ``(obs, reward, done, info)`` --
    and the boolean ``render`` attribute shadows ``gym.Env.render``. Both are
    kept for backward compatibility.
    """

    # Absolute tolerance used to decide whether an LP value is integral.
    _INT_TOL = 1e-5

    def __init__(self, c, A_ub, b_ub, A_eq, b_eq, bounds, integer_indices, render=False):
        """Store the MILP data and define the action/observation spaces.

        Args:
            c, A_ub, b_ub, A_eq, b_eq, bounds: passed to ``linprog`` (any form it
                accepts for ``bounds``, including ``None`` or one shared pair).
            integer_indices: indices of the integer-constrained variables.
            render: when True, build and draw the search tree.
        """
        super().__init__()
        self.c = c
        self.A_ub = A_ub
        self.b_ub = b_ub
        self.A_eq = A_eq
        self.b_eq = b_eq
        self.bounds = bounds
        self.dim_x = len(c)
        self.integer_indices = integer_indices
        self.render = render
        self.max_candidates = len(integer_indices)  # Max number of candidate variables to branch on
        # Action: index of the variable to branch on.
        self.action_space = spaces.Discrete(self.dim_x)
        # Observation: LP solution and action mask. linprog returns float64
        # solutions, so the Box is float64 for observation_space.contains() to accept them.
        self.observation_space = spaces.Dict({
            'observation': spaces.Box(low=-np.inf, high=np.inf, shape=(self.dim_x,), dtype=np.float64),
            'action_mask': spaces.Box(low=0, high=1, shape=(self.dim_x,), dtype=np.int32),
        })
        self.state = None
        self.action_mask = None
        self.best_obj_value = np.inf
        self.best_solution = None
        self.node_queue = []
        if self.render:
            # Search-tree visualization state.
            self.tree = nx.DiGraph()  # Directed graph for the tree
            self.node_counter = 0  # Unique counter for node IDs
            self.parent_stack = []  # Keep track of parent nodes for branching

    def reset(self):
        """Start a new search from the root LP relaxation and return its observation.

        Raises:
            ValueError: if the root LP relaxation is not solved to optimality.
        """
        if self.render:
            # Each episode draws its own tree instead of piling onto the previous one.
            self.tree.clear()
            self.node_counter = 0
            self.parent_stack = []
        self.state, self.action_mask = self._get_initial_state()
        self.best_obj_value = np.inf
        self.best_solution = None
        self.node_queue = []
        return self._get_observation()

    def _normalize_bounds(self):
        """Return ``self.bounds`` as a list with one ``(low, high)`` tuple per variable.

        Mirrors the forms ``linprog`` accepts: ``None``/empty means ``(0, None)``
        for every variable, and a single ``(low, high)`` pair applies to all.
        """
        bounds = self.bounds
        if bounds is None or len(bounds) == 0:
            return [(0, None)] * self.dim_x
        if len(bounds) == 2 and all(b is None or np.isscalar(b) for b in bounds):
            return [tuple(bounds)] * self.dim_x
        pairs = [tuple(b) for b in bounds]
        if len(pairs) == 1:
            return pairs * self.dim_x
        return pairs

    def _fractional_vars(self, x):
        """Return the integer-constrained indices (as Python ints) whose value in ``x`` is fractional."""
        return [int(i) for i in self.integer_indices
                if not np.isclose(x[i], np.round(x[i]), atol=self._INT_TOL)]

    def _build_action_mask(self, fractional_vars):
        """Return an int32 mask of length ``dim_x`` with 1 at each index in ``fractional_vars``."""
        mask = np.zeros(self.dim_x, dtype=np.int32)
        mask[fractional_vars] = 1
        return mask

    def _get_initial_state(self):
        """Solve the root LP relaxation and return ``(state, action_mask)``.

        Also resets ``prev_obj_val`` to the root objective and, when rendering,
        adds the root node to the tree.

        Raises:
            ValueError: if ``linprog`` does not report an optimal root solution
                (e.g. infeasible or unbounded), in which case there is no ``x``.
        """
        bounds = self._normalize_bounds()
        res = linprog(c=self.c, A_ub=self.A_ub, b_ub=self.b_ub, A_eq=self.A_eq, b_eq=self.b_eq, bounds=bounds)
        if res.status != 0 or res.x is None:
            raise ValueError(f"Root LP relaxation could not be solved: {res.message}")
        x_lp = res.x
        obj_val = res.fun
        fractional_vars = self._fractional_vars(x_lp)
        action_mask = self._build_action_mask(fractional_vars)
        self.prev_obj_val = obj_val
        state = {
            'x_lp': x_lp.copy(),
            'obj_val': obj_val,
            'fractional_vars': fractional_vars,
            'bounds': bounds,
        }
        if self.render:
            self.tree.add_node(self.node_counter, label="Root", bounds=bounds, obj_val=obj_val)
            self.current_node_id = self.node_counter
            self.node_counter += 1
        return state, action_mask

    def step(self, action):
        """Branch on variable ``action`` at the current node.

        Both children (``x[action] <= floor`` and ``x[action] >= ceil``) are
        solved. A child is discarded if linprog does not solve it to optimality
        or its bound does not beat the incumbent. If at least one child survives,
        one is selected at random, the rest are queued, and the episode ends when
        the selected node's LP solution is integral.

        Args:
            action: index of the variable to branch on.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` holds
            ``'best_obj_value'``.

        Raises:
            ValueError: if ``action`` is out of range or masked out.
        """
        # The action is a variable index, matching action_space and action_mask.
        var_idx = int(action)
        if not 0 <= var_idx < self.dim_x or not self.action_mask[var_idx]:
            raise ValueError("Selected action is invalid.")

        x_i = self.state['x_lp'][var_idx]
        floor_x_i = np.floor(x_i).item()
        ceil_x_i = np.ceil(x_i).item()
        old_bounds = self.state['bounds']
        low, high = old_bounds[var_idx]

        left_bounds = list(old_bounds)
        left_bounds[var_idx] = (low, floor_x_i)  # x_i <= floor(x_i)
        right_bounds = list(old_bounds)
        right_bounds[var_idx] = (ceil_x_i, high)  # x_i >= ceil(x_i)

        # Left child is solved before the right child.
        valid_children = [
            child
            for child in (self._solve_child(left_bounds), self._solve_child(right_bounds))
            if child is not None
        ]
        if not valid_children:
            return self._handle_dead_end(var_idx)

        # (for now) randomly select a child node to branch.
        selected_pos = np.random.randint(len(valid_children))
        selected_child = valid_children[selected_pos]

        if self.render:
            # Attach every surviving child to the current node in the tree.
            for child in valid_children:
                child['node_id'] = self.node_counter
                self.tree.add_node(child['node_id'], label=f"Node {child['obj_val']}",
                                   bounds=child['bounds'], obj_val=child['obj_val'])
                self.tree.add_edge(self.current_node_id, child['node_id'], label=child['bounds'])
                self.node_counter += 1
            self.current_node_id = selected_child['node_id']

        # Keep the unselected children for backtracking.
        self.node_queue.extend(child for pos, child in enumerate(valid_children) if pos != selected_pos)

        self._load_node(selected_child)
        # TODO: when an integer (leaf) node is reached while nodes are still
        # queued, should the search backtrack instead? It currently terminates.
        done = self._record_incumbent_if_integer()
        reward = self._get_reward()
        if self.render:
            self._visualize_search_tree()
        return self._get_observation(), reward, done, {'best_obj_value': self.best_obj_value}

    def _solve_child(self, bounds):
        """Solve the LP relaxation under ``bounds``.

        Returns the child node dict, or None when linprog's status is not 0
        (optimal) or the objective is not strictly better than the incumbent.
        """
        res = linprog(c=self.c, A_ub=self.A_ub, b_ub=self.b_ub, A_eq=self.A_eq, b_eq=self.b_eq, bounds=bounds)
        if res.status != 0 or not res.fun < self.best_obj_value:
            return None
        return {'bounds': bounds, 'x_lp': res.x, 'obj_val': res.fun}

    def _handle_dead_end(self, var_idx):
        """Handle branching on ``var_idx`` when both children were discarded.

        The variable is masked out at the current node. If other variables are
        still branchable the agent receives -1 and stays here; otherwise the
        search backtracks to the last queued node (-1), or ends (0) when the
        queue is empty.
        """
        self.action_mask[var_idx] = 0
        self.state['fractional_vars'].remove(var_idx)
        if np.sum(self.action_mask) > 0:
            # Case: other variables can still be branched on, but this choice was poor.
            reward, done = -1, False
        elif self.node_queue:
            # Case: nothing left to branch on here; backtrack to a queued node.
            self._pop_node_from_queue()
            reward = -1
            # A queued node can already be integral (empty mask); record it and
            # end the episode, as the branching path does, instead of leaving
            # the agent with no valid action.
            done = self._record_incumbent_if_integer()
        else:
            # Node queue is empty: the search is finished.
            reward, done = 0, True
        return self._get_observation(), reward, done, {'best_obj_value': self.best_obj_value}

    def _load_node(self, node):
        """Make ``node`` the current state and recompute its fractional vars and action mask."""
        self.state['bounds'] = node['bounds']
        self.state['x_lp'] = node['x_lp']
        self.state['obj_val'] = node['obj_val']
        self.state['fractional_vars'] = self._fractional_vars(node['x_lp'])
        self.action_mask = self._build_action_mask(self.state['fractional_vars'])

    def _pop_node_from_queue(self):
        """Transition to the most recently queued node (LIFO)."""
        next_node = self.node_queue.pop()
        self._load_node(next_node)
        # Subsequent rewards compare against the node actually being branched on,
        # not the abandoned node from another branch.
        self.prev_obj_val = next_node['obj_val']
        if self.render and 'node_id' in next_node:
            self.current_node_id = next_node['node_id']

    def _record_incumbent_if_integer(self):
        """Return True if the current LP solution is integral, updating the incumbent if it improves."""
        if not self._is_integer_solution(self.state['x_lp']):
            return False
        if self.state['obj_val'] < self.best_obj_value:
            self.best_obj_value = self.state['obj_val']
            self.best_solution = self.state['x_lp'].copy()
        return True

    def _get_reward(self):
        """Return ``prev_obj_val - current obj_val`` and advance ``prev_obj_val``."""
        reward = self.prev_obj_val - self.state['obj_val']
        self.prev_obj_val = self.state['obj_val']
        return reward

    def _is_integer_solution(self, x):
        """Return True if every integer-constrained entry of ``x`` is integral within ``_INT_TOL``."""
        return all(np.isclose(x[i], np.round(x[i]), atol=self._INT_TOL) for i in self.integer_indices)

    def _check_terminal(self):
        """Return True (and print the reason) when the search should stop."""
        if self._is_integer_solution(self.state['x_lp']) and not self.node_queue:
            print("Reason for completion: all required variables are integers and no nodes left in the queue")
            return True
        if np.sum(self.action_mask) == 0:
            print("Reason for completion: no valid actions left to take")
            return True
        if not self.node_queue:
            print("Reason for completion: no nodes left in the queue")
            return True
        return False

    def _get_observation(self):
        """Return the observation dict, copied so later in-place updates do not alter it."""
        return {
            'observation': self.state['x_lp'].copy(),
            'action_mask': self.action_mask.copy(),
        }

    def _visualize_search_tree(self):
        """Draw the current search tree with graphviz 'dot' layout and show it."""
        fig = plt.figure(figsize=(10, 8))
        pos = graphviz_layout(self.tree, prog="dot")
        node_labels = nx.get_node_attributes(self.tree, 'label')
        edge_labels = nx.get_edge_attributes(self.tree, 'label')
        nx.draw(self.tree, pos, with_labels=True, labels=node_labels, node_color='skyblue', node_size=1200,
                font_size=10, font_weight='bold', edge_color='gray', arrows=True)
        nx.draw_networkx_edge_labels(self.tree, pos, edge_labels=edge_labels, font_color='red', font_size=8)
        plt.title("Branch and Bound Search Tree")
        plt.show()
        # Release the figure so repeated calls (one per step) do not accumulate open figures.
        plt.close(fig)