import torch
from torch_geometric.data import Data
import networkx as nx
import seaborn as sns
from networkx.drawing.nx_pydot import graphviz_layout
import matplotlib.pyplot as plt
from collections import defaultdict
from ordered_set import OrderedSet


class SearchTree:
    '''
    Tracks the SCIP branch-and-bound search tree.

    Call SearchTree.update_tree(ecole.Model) each time the ecole environment
    (and therefore the ecole.Model) is updated.

    N.B. SCIP does not store nodes that were pruned, infeasible or outside the
    search tree's optimality bounds, nor does it record which node was optimal,
    so such nodes will not appear in the SearchTree. This is why
    m.getNTotalNodes() (the total number of nodes processed by SCIP) will likely
    be larger than the number of nodes in the search tree once an instance is
    solved.

    The tracked tree is ``self.tree`` (a ``networkx.DiGraph`` keyed by SCIP node
    number). Its ``graph`` dict holds:
        root_node: {id: node} for the root, or None until it is seen.
        visited_nodes / visited_node_ids: nodes in the order they were first
            the current (focus) node; the list holds {id: node} dicts parallel
            to the OrderedSet of ids.
        optimum_nodes / optimum_node_ids: nodes whose branching was followed by
            an improvement of the primal bound.
        incumbent_primal_bound: best primal bound seen so far.
        fathomed_nodes / fathomed_node_ids: visited nodes that ended up with no
            children in the tree.
    '''

    def __init__(self, model):
        '''
        Args:
            model: ecole.Model at the start of tracking; only
                ``model.primal_bound`` is read here.
        '''
        self.tree = nx.DiGraph()
        self.tree.graph['root_node'] = None
        self.tree.graph['visited_nodes'] = []
        self.tree.graph['visited_node_ids'] = OrderedSet()
        self.tree.graph['optimum_nodes'] = []
        self.tree.graph['optimum_node_ids'] = OrderedSet()
        self.init_primal_bound = model.primal_bound
        self.tree.graph['incumbent_primal_bound'] = self.init_primal_bound
        self.tree.graph['fathomed_nodes'] = []
        self.tree.graph['fathomed_node_ids'] = OrderedSet()

        self.prev_primal_bound = None
        self.prev_node_id = None
        self.prev_node = None
        # Initialised here so render()/_get_node_groups() work before the
        # first update_tree() call.
        self.curr_node_id = None
        self.step_idx = 0

    def update_tree(self, model):
        '''
        Update the tracked tree from the current state of ``model``.

        Call this after each update to the ecole environment, passing the
        updated ecole.Model. Each call:
          1. credits the previously visited node as an "optimum" node if the
             primal bound decreased since the last call;
          2. records the current (focus) node as visited, plus its parent;
          3. adds SCIP's open leaves, children and siblings to the tree;
          4. marks the previously visited node as fathomed if it has no
             children in the tree.

        NOTE(review): the primal-bound check uses ``<``, i.e. it assumes the
        bound decreases on improvement (minimisation) -- confirm for
        maximisation instances.

        Args:
            model: ecole.Model; converted with ``model.as_pyscipopt()``.
        '''
        m = model.as_pyscipopt()
        visited_node_ids = self.tree.graph['visited_node_ids']
        visited_nodes = self.tree.graph['visited_nodes']

        # Current node, i.e. the next node to be branched at. None once
        # branching has finished.
        _curr_node = m.getCurrentNode()
        self.curr_node_id = _curr_node.getNumber() if _curr_node is not None else None

        # Did branching at the previously visited node find a new incumbent?
        if len(visited_node_ids) >= 1:
            self.prev_node_id, self.prev_node = visited_node_ids[-1], visited_nodes[-1]
        primal_bound = m.getPrimalbound()
        if primal_bound < self.tree.graph['incumbent_primal_bound']:
            if len(visited_node_ids) >= 1:
                self.tree.graph['optimum_nodes'].append(self.prev_node)
                self.tree.graph['optimum_node_ids'].add(self.prev_node_id)
            # Move the reference bound even when there is no visited node to
            # credit, so a later node is not credited for this improvement.
            self.tree.graph['incumbent_primal_bound'] = primal_bound

        # Record the current node (and its parent).
        self.curr_node = {self.curr_node_id: _curr_node}
        if self.curr_node_id is not None:
            if self.curr_node_id not in visited_node_ids:
                self._add_nodes(self.curr_node)
                visited_nodes.append(self.curr_node)
                visited_node_ids.add(self.curr_node_id)
                self.tree.nodes[self.curr_node_id]['step_visited'] = self.step_idx
            _parent_node = _curr_node.getParent()
            # The root node has no parent.
            parent_node_id = _parent_node.getNumber() if _parent_node is not None else None
            self.parent_node = {parent_node_id: _parent_node}
        else:
            self.parent_node = {None: None}

        # Add open nodes to the tree.
        open_leaves, open_children, open_siblings = m.getOpenNodes()
        self.open_leaves = {node.getNumber(): node for node in open_leaves}
        self.open_children = {node.getNumber(): node for node in open_children}
        self.open_siblings = {node.getNumber(): node for node in open_siblings}
        for open_nodes in (self.open_leaves, self.open_children, self.open_siblings):
            self._add_nodes(open_nodes)

        # Did branching at the previously visited node lead to fathoming?
        if self.curr_node_id is not None:
            # When the current node was newly visited it was appended above,
            # so the previous node sits at index -2.
            prev_idx = -2 if len(visited_node_ids) > 2 else None
        else:
            # Branching finished: the last visited node was the previous one.
            # Guard against nothing having been visited at all.
            prev_idx = -1 if len(visited_node_ids) >= 1 else None
        if prev_idx is not None:
            self.prev_node_id, self.prev_node = visited_node_ids[prev_idx], visited_nodes[prev_idx]
            fathomed_node_ids = self.tree.graph['fathomed_node_ids']
            if (self.tree.out_degree(self.prev_node_id) == 0
                    and self.prev_node_id != self.curr_node_id
                    and self.prev_node_id not in fathomed_node_ids):
                # Skip already-fathomed nodes so fathomed_nodes stays parallel
                # to fathomed_node_ids across repeated calls.
                self.tree.graph['fathomed_nodes'].append(self.prev_node)
                fathomed_node_ids.add(self.prev_node_id)

        self.step_idx += 1

    def _add_nodes(self, nodes, parent_node_id=None):
        '''
        Add nodes (and the edge from their parent) if not already in the tree.

        Args:
            nodes: dict mapping SCIP node number -> pyscipopt node.
            parent_node_id: if given, used as the parent id for every node in
                ``nodes`` that has a parent; otherwise each node's own parent
                (``node.getParent()``) is used.
        '''
        for node_id, node in nodes.items():
            if node_id in self.tree:
                continue
            self.tree.add_node(node_id, _id=node_id, lower_bound=node.getLowerbound())
            _parent_node = node.getParent()
            if _parent_node is None:
                # Root node: has no parent, so no edge.
                self.tree.graph['root_node'] = {node_id: node}
                continue
            # Resolve the parent per node so that nodes with different parents
            # (e.g. open leaves) each get the correct edge.
            if parent_node_id is not None:
                edge_parent_id = parent_node_id
            else:
                edge_parent_id = _parent_node.getNumber()
            self.tree.add_edge(edge_parent_id, node_id)

    def _get_node_groups(self):
        '''
        Group tree node ids by rendering category.

        Returns:
            defaultdict(list) mapping a label to a list of node ids:
            'Unvisited' (never visited, or the current node), 'Visited'
            (all other nodes), and, as subsets of 'Visited', 'Fathomed' and
            'Incumbent' (the most recently recorded optimum node). Only
            non-empty groups are present.
        '''
        visited_node_ids = self.tree.graph['visited_node_ids']
        fathomed_node_ids = self.tree.graph['fathomed_node_ids']
        optimum_node_ids = self.tree.graph['optimum_node_ids']
        # Loop-invariant: the latest incumbent-finding node, if any.
        incumbent_ids = {optimum_node_ids[-1]} if len(optimum_node_ids) > 0 else set()

        node_groups = defaultdict(list)
        for node in self.tree.nodes:
            if node not in visited_node_ids or self.curr_node_id == node:
                node_groups['Unvisited'].append(node)
                continue
            node_groups['Visited'].append(node)
            if node in fathomed_node_ids:
                node_groups['Fathomed'].append(node)
            if node in incumbent_ids:
                node_groups['Incumbent'].append(node)
        return node_groups

    def render(self,
               unvisited_node_colour='#FFFFFF',
               visited_node_colour='#A7C7E7',
               fathomed_node_colour='#FF6961',
               incumbent_node_colour='#C1E1C1',
               next_node_colour='#FFD700',
               node_edge_colour='#000000',
               use_latex_font=True,
               font_scale=0.75,
               context='paper',
               style='ticks'):
        '''
        Render the B&B search tree with matplotlib and show it.

        Nodes are coloured by the groups from _get_node_groups(). Groups drawn
        later overlay earlier ones, so fathomed/incumbent colours sit on top of
        the visited colour. The current node, if any, is outlined in
        ``next_node_colour``. Requires pydot/graphviz for the 'dot' layout.

        Args:
            *_colour: matplotlib colour specs for each node category and for
                node outlines.
            use_latex_font: render text with LaTeX in a Times font.
            font_scale, context, style: forwarded to seaborn.set_theme.
        '''
        theme_kwargs = {'font_scale': font_scale, 'context': context, 'style': style}
        if use_latex_font:
            # Use a single set_theme call: a second call would reset
            # font.family to seaborn's default and drop the Times font.
            theme_kwargs.update(font='times', rc={'text.usetex': True})
        sns.set_theme(**theme_kwargs)

        group_to_colour = {'Unvisited': unvisited_node_colour,
                           'Visited': visited_node_colour,
                           'Fathomed': fathomed_node_colour,
                           'Incumbent': incumbent_node_colour}

        _, ax = plt.subplots()
        pos = graphviz_layout(self.tree, prog='dot')

        node_groups = self._get_node_groups()
        for group_label, nodes in node_groups.items():
            nx.draw_networkx_nodes(self.tree,
                                   pos,
                                   nodelist=nodes,
                                   node_color=group_to_colour[group_label],
                                   edgecolors=node_edge_colour,
                                   label=group_label,
                                   ax=ax)
        num_groups = len(node_groups)
        if self.curr_node_id is not None:
            # Highlight the next node to be branched at with a thick outline.
            nx.draw_networkx_nodes(self.tree,
                                   pos,
                                   nodelist=[self.curr_node_id],
                                   node_color=unvisited_node_colour,
                                   edgecolors=next_node_colour,
                                   linewidths=3,
                                   label='Next',
                                   ax=ax)
            num_groups += 1

        nx.draw_networkx_edges(self.tree, pos, ax=ax)
        nx.draw_networkx_labels(self.tree, pos, labels={node: node for node in self.tree.nodes}, ax=ax)
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=num_groups)
        plt.show()