# EECE571F-project / learn2branch_self / agents / observations.py
import torch
from torch_geometric.data import Data
import networkx as nx
import seaborn as sns
from networkx.drawing.nx_pydot import graphviz_layout
import matplotlib.pyplot as plt
from collections import defaultdict
from ordered_set import OrderedSet

class SearchTree:
    '''
    Tracks SCIP search tree. Call SearchTree.update_tree(ecole.Model) each
    time the ecole environment (and therefore the ecole.Model) is updated.
    N.B. SCIP does not store nodes which were pruned, infeasible, outside
    the search tree's optimality bounds, or optimal; therefore these nodes
    will not be stored in the SearchTree. This is why m.getNTotalNodes() (the
    total number of nodes processed by SCIP) will likely be more than the
    number of nodes in the search tree when an instance is solved.
    '''
    def __init__(self, model):
        '''
        Create an empty search tree tracker.

        Args:
            model: Object exposing a ``primal_bound`` attribute (e.g. an
                ecole.Model). Its value is stored as the initial primal bound
                and as the starting incumbent primal bound.
        '''
        self.tree = nx.DiGraph()

        # Graph-level bookkeeping. Each '*_nodes' list holds {node_id: node}
        # dicts and is paired with an ordered set of the corresponding ids.
        self.tree.graph['root_node'] = None
        self.tree.graph['visited_nodes'] = []
        self.tree.graph['visited_node_ids'] = OrderedSet()
        self.tree.graph['optimum_nodes'] = []
        self.tree.graph['optimum_node_ids'] = OrderedSet()
        self.init_primal_bound = model.primal_bound
        self.tree.graph['incumbent_primal_bound'] = self.init_primal_bound

        self.tree.graph['fathomed_nodes'] = []
        self.tree.graph['fathomed_node_ids'] = OrderedSet()

        self.prev_primal_bound = None
        self.prev_node_id = None
        self.prev_node = None

        # These attributes are (re)assigned on every update_tree() call. They
        # are defined here as well so that _get_node_groups() and render(),
        # which read curr_node_id, do not raise AttributeError when called
        # before the first update.
        self.curr_node_id = None
        self.curr_node = {None: None}
        self.parent_node = {None: None}
        self.open_leaves = {}
        self.open_children = {}
        self.open_siblings = {}

        # Number of completed update_tree() calls.
        self.step_idx = 0
            
    def update_tree(self, model):
        '''
        Synchronise the tracked B&B tree with the solver's current state.

        Call this method after each update to the ecole environment. Pass
        the updated ecole.Model, and the B&B tree tracker will be updated accordingly.

        Each call:
          1. records the node SCIP is about to branch at (None once branching
             has finished),
          2. marks the previously visited node as an optimum node if the
             primal bound dropped below the stored incumbent bound,
          3. adds the current node and all open nodes to the tree,
          4. marks the previously visited node as fathomed if it has no
             successors in the tree.

        Args:
            model: Object whose ``as_pyscipopt()`` returns the pyscipopt model
                that is queried via getCurrentNode(), getPrimalbound() and
                getOpenNodes().
        '''
        m = model.as_pyscipopt()
        graph = self.tree.graph
        visited_nodes = graph['visited_nodes']
        visited_node_ids = graph['visited_node_ids']

        # get current node (i.e. next node to be branched at); there is no
        # current node once branching has finished
        _curr_node = m.getCurrentNode()
        self.curr_node_id = _curr_node.getNumber() if _curr_node is not None else None

        if len(visited_node_ids) >= 1:
            # the current node has not been recorded yet, so the last entry is
            # the node visited on the previous call
            self.prev_node_id, self.prev_node = visited_node_ids[-1], visited_nodes[-1]

            # check if branching at previous node changed global primal bound.
            # If so, set previous node as optimum.
            # NOTE: the strict '<' treats only a lower primal bound as an
            # improvement (assumes minimisation -- confirm before using with
            # maximisation problems).
            primal_bound = m.getPrimalbound()
            if primal_bound < graph['incumbent_primal_bound']:
                # keep the node list aligned with the (deduplicating) id set
                if self.prev_node_id not in graph['optimum_node_ids']:
                    graph['optimum_nodes'].append(self.prev_node)
                    graph['optimum_node_ids'].add(self.prev_node_id)
                graph['incumbent_primal_bound'] = primal_bound

        self.curr_node = {self.curr_node_id: _curr_node}
        if self.curr_node_id is not None:
            if self.curr_node_id not in visited_node_ids:
                self._add_nodes(self.curr_node)
                visited_nodes.append(self.curr_node)
                visited_node_ids.add(self.curr_node_id)
                self.tree.nodes[self.curr_node_id]['step_visited'] = self.step_idx

            _parent_node = _curr_node.getParent()
            # the root node has no parent
            parent_node_id = _parent_node.getNumber() if _parent_node is not None else None
            self.parent_node = {parent_node_id: _parent_node}
        else:
            self.parent_node = {None: None}

        # add open nodes to tree
        open_leaves, open_children, open_siblings = m.getOpenNodes()
        self.open_leaves = {node.getNumber(): node for node in open_leaves}
        self.open_children = {node.getNumber(): node for node in open_children}
        self.open_siblings = {node.getNumber(): node for node in open_siblings}

        self._add_nodes(self.open_leaves)
        self._add_nodes(self.open_children)
        self._add_nodes(self.open_siblings)

        # check if branching at previous node led to fathoming
        if self.curr_node_id is None:
            # branching finished: the last visited node (if there is one) is
            # the previous node. Guarding on emptiness avoids an IndexError
            # when no node was ever visited.
            prev_idx = -1 if len(visited_node_ids) >= 1 else None
        elif len(visited_node_ids) > 2:
            # the current node sits at the end of the visited ids, therefore
            # the previous node is at idx=-2
            prev_idx = -2
        else:
            prev_idx = None

        if prev_idx is not None:
            self.prev_node_id, self.prev_node = visited_node_ids[prev_idx], visited_nodes[prev_idx]
            has_no_children = len(list(self.tree.successors(self.prev_node_id))) == 0
            if (has_no_children
                    and self.prev_node_id != self.curr_node_id
                    and self.prev_node_id not in graph['fathomed_node_ids']):
                # branching at previous node led to fathoming; the membership
                # check keeps repeated calls from appending the node again
                graph['fathomed_nodes'].append(self.prev_node)
                graph['fathomed_node_ids'].add(self.prev_node_id)

        self.step_idx += 1

    def _add_nodes(self, nodes, parent_node_id=None):
        '''Adds nodes if not already in tree.

        Every node that is new to the tree is added with its id and lower
        bound as attributes and connected to its parent by an edge. A node
        without a parent is recorded as the root node instead.

        Args:
            nodes: Mapping of node id to node object. Each node object must
                provide getLowerbound() and getParent().
            parent_node_id: Optional id used as the edge source for every new
                node that has a parent. If None, each node's own parent id
                (``node.getParent().getNumber()``) is used.
        '''
        for node_id, node in nodes.items():
            if node_id in self.tree:
                # already tracked, nothing to add
                continue

            # add node
            self.tree.add_node(node_id,
                               _id=node_id,
                               lower_bound=node.getLowerbound())

            _parent_node = node.getParent()
            if _parent_node is None:
                # is root node, has no parent
                self.tree.graph['root_node'] = {node_id: node}
                continue

            # add edge. The source is resolved per node into a local name;
            # overwriting parent_node_id here would make every later node in
            # this call reuse the first node's parent.
            if parent_node_id is None:
                edge_source_id = _parent_node.getNumber()
            else:
                edge_source_id = parent_node_id
            self.tree.add_edge(edge_source_id, node_id)
                    
    def _get_node_groups(self):
        '''Sorts the tree's node ids into named display groups.

        Returns a defaultdict mapping group label to a list of node ids:
          * 'Visited' / 'Unvisited': every node lands in exactly one of these;
            the current node counts as 'Unvisited' even if it is recorded in
            the visited ids.
          * 'Fathomed': nodes recorded in the fathomed ids (in addition to the
            group above).
          * 'Incumbent': the most recently recorded optimum node, if any.
        Groups appear in the order in which their first member is encountered.
        '''
        graph = self.tree.graph
        visited_ids = graph['visited_node_ids']
        fathomed_ids = graph['fathomed_node_ids']
        optimum_ids = graph['optimum_node_ids']
        has_incumbent = len(optimum_ids) > 0

        groups = defaultdict(list)
        for node_id in self.tree.nodes:
            # the node about to be branched at is still shown as unvisited
            branched_at = node_id in visited_ids and node_id != self.curr_node_id
            groups['Visited' if branched_at else 'Unvisited'].append(node_id)

            if node_id in fathomed_ids:
                groups['Fathomed'].append(node_id)

            # only the latest optimum node is reported as the incumbent
            if has_incumbent and node_id == optimum_ids[-1]:
                groups['Incumbent'].append(node_id)
        return groups
                                    
    def render(self,
               unvisited_node_colour='#FFFFFF',
               visited_node_colour='#A7C7E7',
               fathomed_node_colour='#FF6961',
               incumbent_node_colour='#C1E1C1',
               next_node_colour='#FFD700',
               node_edge_colour='#000000',
               use_latex_font=True,
               font_scale=0.75,
               context='paper',
               style='ticks'
              ):
        '''Renders B&B search tree.

        Draws the tree with a graphviz 'dot' layout, colouring nodes by the
        groups returned from _get_node_groups() and outlining the current
        node (if any), then shows the figure.

        Args:
            unvisited_node_colour: Fill colour of 'Unvisited' nodes (also used
                as the fill of the highlighted current node).
            visited_node_colour: Fill colour of 'Visited' nodes.
            fathomed_node_colour: Fill colour of 'Fathomed' nodes.
            incumbent_node_colour: Fill colour of the 'Incumbent' node.
            next_node_colour: Outline colour of the current ('Next') node.
            node_edge_colour: Outline colour of all other nodes.
            use_latex_font: If True, enable matplotlib's 'text.usetex' and
                request the 'times' font.
            font_scale: Passed to seaborn's theme as font_scale.
            context: Passed to seaborn's theme as context.
            style: Passed to seaborn's theme as style.
        '''
        # Configure the theme in a single call. Applying the latex settings
        # first and then calling set_theme() again without font= resets the
        # font family to seaborn's default, discarding the requested font.
        sns.set_theme(context=context,
                      style=style,
                      font='times' if use_latex_font else 'sans-serif',
                      font_scale=font_scale,
                      rc={'text.usetex': True} if use_latex_font else None)

        group_to_colour = {'Unvisited': unvisited_node_colour,
                           'Visited': visited_node_colour,
                           'Fathomed': fathomed_node_colour,
                           'Incumbent': incumbent_node_colour}

        # draw explicitly on the axes created here
        _, ax = plt.subplots()

        pos = graphviz_layout(self.tree, prog='dot')

        node_groups = self._get_node_groups()
        for group_label, nodes in node_groups.items():
            nx.draw_networkx_nodes(self.tree,
                                   pos,
                                   nodelist=nodes,
                                   node_color=group_to_colour[group_label],
                                   edgecolors=node_edge_colour,
                                   label=group_label,
                                   ax=ax)

        # one legend column per drawn group (plus one for the 'Next' node)
        num_groups = len(node_groups)
        if self.curr_node_id is not None:
            # drawn last so the highlighted outline sits on top of the
            # group-coloured marker of the same node
            nx.draw_networkx_nodes(self.tree,
                                   pos,
                                   nodelist=[self.curr_node_id],
                                   node_color=unvisited_node_colour,
                                   edgecolors=next_node_colour,
                                   linewidths=3,
                                   label='Next',
                                   ax=ax)
            num_groups += 1

        nx.draw_networkx_edges(self.tree,
                               pos,
                               ax=ax)

        nx.draw_networkx_labels(self.tree,
                                pos,
                                labels={node: node for node in self.tree.nodes},
                                ax=ax)
        # keep the column count at least 1 when nothing was drawn
        plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=max(num_groups, 1))

        plt.show()