# EECE571F-project / RL_self_implemented / environment.py
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from scipy.optimize import linprog
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout


class BranchAndBoundEnv(gym.Env):
    """Environment in which an agent picks branching variables for a MILP.

    The problem is ``min c @ x`` subject to ``A_ub @ x <= b_ub``,
    ``A_eq @ x == b_eq``, the variable ``bounds`` and integrality of the
    variables in ``integer_indices``.  Each node's LP relaxation is solved
    with ``scipy.optimize.linprog``.

    A step branches the current node on one fractional integer variable
    ``x_i`` into ``x_i <= floor(x_i)`` and ``x_i >= ceil(x_i)``.  Children
    whose LP is not solved successfully (``status != 0``), or whose LP value
    cannot beat the best integer solution found so far, are dropped; one
    surviving child is chosen at random to continue from and the others are
    queued for backtracking.

    Observations are dicts with ``'observation'`` (current LP solution) and
    ``'action_mask'`` (1 for integer variables that are fractional at the
    current node).  ``reset`` returns only the observation and ``step``
    returns ``(obs, reward, done, info)``, i.e. not the 5-tuple of the
    current Gymnasium API.
    """

    # Absolute tolerance for deciding that an LP value is integral.
    INT_TOL = 1e-5

    def __init__(self, c, A_ub, b_ub, A_eq, b_eq, bounds, integer_indices, render=False):
        """Store the MILP data and build the action/observation spaces.

        Args:
            c: Objective coefficients; ``len(c)`` is the number of variables.
            A_ub, b_ub: Inequality constraints ``A_ub @ x <= b_ub``.
            A_eq, b_eq: Equality constraints ``A_eq @ x == b_eq``.
            bounds: Variable bounds in a form accepted by ``linprog``
                (None, one ``(lb, ub)`` pair for every variable, or one pair
                per variable).
            integer_indices: Indices of the variables that must be integral.
            render: If True, record the search tree and plot it after each
                successful branching step.
        """
        super().__init__()
        self.c = c
        self.A_ub = A_ub
        self.b_ub = b_ub
        self.A_eq = A_eq
        self.b_eq = b_eq
        self.bounds = bounds
        self.dim_x = len(c)
        self.integer_indices = integer_indices
        # NOTE(review): this boolean attribute shadows the ``gym.Env.render``
        # method; kept under this name because callers may read ``env.render``.
        self.render = render

        self.max_candidates = len(integer_indices)  # Max number of candidate variables to branch on

        # Action space: one action per decision variable.
        self.action_space = spaces.Discrete(self.dim_x)

        # Observation space: LP solution and action mask.
        self.observation_space = spaces.Dict({
            'observation': spaces.Box(low=-np.inf, high=np.inf, shape=(self.dim_x,)),
            'action_mask': spaces.Box(low=0, high=1, shape=(self.dim_x,), dtype=np.int32)
        })

        self.state = None
        self.action_mask = None
        self.best_obj_value = np.inf
        self.best_solution = None
        self.node_queue = []

        if self.render:
            # Search-tree bookkeeping for visualisation.
            self.tree = nx.DiGraph()  # Directed graph for the tree
            self.node_counter = 0     # Unique counter for node IDs
            self.parent_stack = []    # Reserved for tracking parent nodes

    def reset(self):
        """Start a new search from the root LP relaxation.

        Returns:
            The observation dict of the root node.

        Raises:
            ValueError: if the root LP relaxation is not solved successfully.
        """
        if self.render:
            # Start a fresh tree each episode; otherwise the roots of earlier
            # episodes pile up as disconnected components.
            self.tree.clear()
            self.node_counter = 0
            self.parent_stack = []

        self.best_obj_value = np.inf
        self.best_solution = None
        self.node_queue = []
        self.state, self.action_mask = self._get_initial_state()

        # An integral optimal root LP solution is already optimal for the MILP.
        if not self.state['fractional_vars']:
            self._record_incumbent()
        return self._get_observation()

    def _get_initial_state(self):
        """Solve the root LP relaxation and build the initial state and mask.

        Returns:
            ``(state, action_mask)`` where ``state`` holds ``'x_lp'``,
            ``'obj_val'``, ``'fractional_vars'`` and ``'bounds'``.

        Raises:
            ValueError: if ``linprog`` does not report success (status 0).
        """
        res = linprog(c=self.c, A_ub=self.A_ub, b_ub=self.b_ub, A_eq=self.A_eq, b_eq=self.b_eq, bounds=self.bounds)
        if res.status != 0:
            raise ValueError(
                f"Root LP relaxation could not be solved (status {res.status}): {res.message}")
        x_lp = res.x
        obj_val = res.fun

        fractional_vars = self._fractional_vars(x_lp)
        action_mask = self._make_mask(fractional_vars)

        self.prev_obj_val = obj_val

        state = {
            'x_lp': x_lp.copy(),
            'obj_val': obj_val,
            'fractional_vars': fractional_vars,
            # Per-variable list so branching can copy it and edit one entry.
            'bounds': self._normalize_bounds(self.bounds),
        }

        if self.render:
            self.tree.add_node(self.node_counter, label="Root", bounds=state['bounds'], obj_val=obj_val)
            self.current_node_id = self.node_counter
            self.node_counter += 1

        return state, action_mask

    def step(self, action):
        """Branch on the ``action``-th fractional variable of the current node.

        Args:
            action: Position in ``self.state['fractional_vars']`` of the
                variable to branch on.

        Returns:
            ``(observation, reward, done, info)``; ``info`` holds
            ``'best_obj_value'``.

        Raises:
            ValueError: if ``action`` does not select a branchable variable.
        """
        # NOTE(review): ``action`` indexes the list of fractional variables,
        # whereas ``action_space`` (Discrete(dim_x)) and ``action_mask`` are
        # indexed by variable.  An agent that samples a masked variable index
        # will branch on a different variable or be rejected; confirm which
        # convention the training code uses.
        fractional_vars = self.state['fractional_vars']
        if not 0 <= action < len(fractional_vars):
            raise ValueError("Selected action is invalid.")
        var_idx = fractional_vars[action]
        if not self.action_mask[var_idx]:
            raise ValueError("Selected action is invalid.")

        x_i = self.state['x_lp'][var_idx]
        floor_x_i = np.floor(x_i).item()
        ceil_x_i = np.ceil(x_i).item()
        old_bounds = self.state['bounds']

        # Left child: x_i <= floor(x_i).  Right child: x_i >= ceil(x_i).
        left_bounds = list(old_bounds)
        left_bounds[var_idx] = (left_bounds[var_idx][0], floor_x_i)
        right_bounds = list(old_bounds)
        right_bounds[var_idx] = (ceil_x_i, right_bounds[var_idx][1])

        valid_children = []
        for child_bounds in (left_bounds, right_bounds):
            child = self._solve_child(child_bounds)
            if child is not None:
                valid_children.append(child)

        if not valid_children:
            # Neither child survived: this variable is a dead end at this node.
            self.action_mask[var_idx] = 0
            fractional_vars.remove(var_idx)

            if np.sum(self.action_mask) > 0:
                # Other fractional variables remain here; the agent's first
                # choice was poor.
                reward = -1  # Small penalty for choosing a dead-end variable
                return self._get_observation(), reward, False, self._info()

            # Nothing left to branch on at this node: backtrack.
            if self._advance_to_next_open_node():
                reward = -1  # Small penalty for backtracking
                return self._get_observation(), reward, False, self._info()

            # No open node remains, so the search is finished.
            return self._get_observation(), 0, True, self._info()

        if self.render:
            for child in valid_children:
                new_node_id = self.node_counter
                self.tree.add_node(new_node_id, label=f"Node {child['obj_val']}", bounds=child['bounds'], obj_val=child['obj_val'])
                self.tree.add_edge(self.current_node_id, new_node_id, label=child['bounds'])
                self.node_counter += 1
                # Remember the tree node so a later backtrack attaches its
                # children to the right parent.
                child['node_id'] = new_node_id

        # (for now) pick the child to continue from uniformly at random;
        # selecting by position avoids comparing dicts that hold numpy arrays.
        selected = np.random.randint(len(valid_children))
        for pos, child in enumerate(valid_children):
            if pos != selected:
                self.node_queue.append(child)
        self._set_node(valid_children[selected])

        if self._is_integer_solution(self.state['x_lp']):
            self._record_incumbent()
            # TODO: when the node queue is non-empty the search could resume
            # from a queued node instead of terminating here.
            done = True
        else:
            done = False

        reward = self._get_reward()

        if self.render:
            self._visualize_search_tree()

        return self._get_observation(), reward, done, self._info()

    def _solve_child(self, bounds):
        """Solve the LP relaxation under ``bounds``.

        Returns:
            A node dict with ``'bounds'``, ``'x_lp'`` and ``'obj_val'``, or
            None if ``linprog`` did not succeed or the LP value is not below
            the incumbent ``best_obj_value``.
        """
        res = linprog(c=self.c, A_ub=self.A_ub, b_ub=self.b_ub, A_eq=self.A_eq, b_eq=self.b_eq, bounds=bounds)
        # The LP value bounds every integer solution in this subtree from
        # below, so a child that cannot beat the incumbent is pruned.
        if res.status != 0 or res.fun >= self.best_obj_value:
            return None
        return {'bounds': bounds, 'x_lp': res.x, 'obj_val': res.fun}

    def _advance_to_next_open_node(self):
        """Pop queued nodes until the current one still needs branching.

        Popped nodes whose LP value is not below the incumbent are skipped.
        Popped nodes whose LP solution is already integral have nothing to
        branch on, so they are recorded as incumbents instead.

        Returns:
            True if a node with fractional variables became current, False
            if the queue ran out first.
        """
        while self.node_queue:
            self._pop_node_from_queue()
            if self.state['obj_val'] >= self.best_obj_value:
                continue
            if self.state['fractional_vars']:
                return True
            self._record_incumbent()
        return False

    def _pop_node_from_queue(self):
        """Make the most recently queued node (LIFO order) the current node."""
        self._set_node(self.node_queue.pop())

    def _set_node(self, node):
        """Make ``node`` current and rebuild the fractional-variable list and mask."""
        self.state['bounds'] = node['bounds']
        self.state['x_lp'] = node['x_lp']
        self.state['obj_val'] = node['obj_val']
        self.state['fractional_vars'] = self._fractional_vars(node['x_lp'])
        self.action_mask = self._make_mask(self.state['fractional_vars'])
        if self.render and 'node_id' in node:
            self.current_node_id = node['node_id']

    def _record_incumbent(self):
        """Store the current LP solution as the incumbent if it improves on it."""
        if self.state['obj_val'] < self.best_obj_value:
            self.best_obj_value = self.state['obj_val']
            self.best_solution = self.state['x_lp'].copy()

    def _normalize_bounds(self, bounds):
        """Return ``bounds`` as a new list with one ``(lb, ub)`` tuple per variable.

        ``linprog`` also accepts None (meaning ``(0, None)`` for every
        variable) and a single ``(lb, ub)`` pair applied to every variable;
        branching needs a per-variable list it can copy and edit.
        """
        if bounds is None:
            return [(0, None)] * self.dim_x
        bounds = list(bounds)
        if len(bounds) == 2 and all(b is None or np.isscalar(b) for b in bounds):
            return [tuple(bounds)] * self.dim_x
        return [tuple(b) for b in bounds]

    def _fractional_vars(self, x):
        """Return, as Python ints, the integer indices whose value in ``x`` is not integral."""
        return [int(i) for i in self.integer_indices
                if not np.isclose(x[i], np.round(x[i]), atol=self.INT_TOL)]

    def _make_mask(self, fractional_vars):
        """Return an int32 mask of length ``dim_x`` with 1 at ``fractional_vars``."""
        mask = np.zeros(self.dim_x, dtype=np.int32)
        mask[fractional_vars] = 1
        return mask

    def _info(self):
        """Return the ``info`` dict passed back from ``step``."""
        return {'best_obj_value': self.best_obj_value}

    def _get_reward(self):
        """Return ``prev_obj_val - obj_val`` and remember the current objective."""
        reward = self.prev_obj_val - self.state['obj_val']
        self.prev_obj_val = self.state['obj_val']
        return reward

    def _is_integer_solution(self, x):
        """Return True if every integer-constrained entry of ``x`` is integral within ``INT_TOL``."""
        return all(np.isclose(x[i], np.round(x[i]), atol=self.INT_TOL) for i in self.integer_indices)

    def _check_terminal(self):
        """Return True (printing the reason) if the search should stop."""
        # Current solution is integer feasible and nothing is left to explore.
        if self._is_integer_solution(self.state['x_lp']) and len(self.node_queue) == 0:
            print("Reason for completion: all required variables are integers and no nodes left in the queue")
            return True
        # No valid actions left at the current node.
        if np.sum(self.action_mask) == 0:
            print("Reason for completion: no valid actions left to take")
            return True
        if len(self.node_queue) == 0:
            print("Reason for completion: no nodes left in the queue")
            return True
        return False

    def _get_observation(self):
        """Return the observation dict.

        The arrays are copies, so later in-place updates of the internal
        state (e.g. clearing a dead-end variable in the mask) do not alter
        observations the agent already holds.
        """
        return {
            'observation': self.state['x_lp'].copy(),
            'action_mask': self.action_mask.copy(),
        }

    def _visualize_search_tree(self):
        """Draw the current search tree with graphviz's ``dot`` layout and show it."""
        fig = plt.figure(figsize=(10, 8))
        pos = graphviz_layout(self.tree, prog="dot")
        node_labels = nx.get_node_attributes(self.tree, 'label')
        edge_labels = nx.get_edge_attributes(self.tree, 'label')

        nx.draw(self.tree, pos, with_labels=True, labels=node_labels, node_color='skyblue', node_size=1200, font_size=10, font_weight='bold', edge_color='gray', arrows=True)
        nx.draw_networkx_edge_labels(self.tree, pos, edge_labels=edge_labels, font_color='red', font_size=8)

        plt.title("Branch and Bound Search Tree")
        plt.show()
        # pyplot keeps every figure open until it is closed; without this a
        # new figure would accumulate on every step.
        plt.close(fig)