import copy
import time

import numpy as np
from scipy.optimize import linprog
import torch
import torch.nn as nn
import torch.nn.functional as F


class Node:
    """One solved LP relaxation in the branch-and-bound tree.

    Attributes are plain data:
      obj_val -- objective value of the relaxation (np.inf for the initial
                 "no incumbent yet" placeholder)
      x       -- solution vector of the relaxation
      bounds  -- variable bounds the relaxation was solved under
    """

    def __init__(self, obj_val, x, bounds) -> None:
        self.obj_val = obj_val
        self.x = x
        self.bounds = bounds


def _normalize_bounds(bounds, n_vars):
    """Expand ``bounds`` into a list with one ``(lb, ub)`` tuple per variable.

    ``linprog`` accepts either a single ``(min, max)`` pair applied to every
    variable or a sequence of pairs.  Branching needs to overwrite the pair of
    one variable, so the single-pair (and ``None``/empty) forms are expanded.
    """
    if bounds is None or len(bounds) == 0:
        return [(0, None)] * n_vars
    entries = list(bounds)
    is_single_pair = len(entries) == 2 and all(
        b is None or np.isscalar(b) for b in entries
    )
    if is_single_pair:
        return [tuple(entries)] * n_vars
    return [tuple(pair) for pair in entries]


def _first_fractional_index(x, int_indices, tolerance):
    """Return the first index in ``int_indices`` whose value in ``x`` is not
    within ``tolerance`` of an integer, or ``None`` if all are integral."""
    for i in int_indices:
        if np.abs(x[i] - np.round(x[i])) > tolerance:
            return i
    return None


def _branch_and_bound(c, A_ub, b_ub, A_eq, b_eq, bounds, int_indices,
                      tolerance=1e-5):
    """Depth-first branch and bound shared by ``BnB`` and ``BnB_RL``.

    Returns ``None`` (after printing a message) when the root relaxation is
    not solved successfully, otherwise ``(best_node, elapsed_seconds)``.
    """
    start = time.perf_counter()
    var_bounds = _normalize_bounds(bounds, np.size(c))

    def solve(node_bounds):
        return linprog(c=c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq,
                       bounds=node_bounds)

    root = solve(var_bounds)
    if not root.success:
        print("No solution found")
        return None

    # Placeholder incumbent: infinite objective so any integral node beats it.
    incumbent = Node(np.inf, np.zeros_like(c), var_bounds)
    stack = [Node(root.fun, root.x, var_bounds)]

    while stack:
        # LIFO order -> depth-first search; the most recently pushed child
        # (the "ceil" branch) is explored first.
        node = stack.pop()

        # Prune: the relaxation cannot improve on the incumbent.
        if node.obj_val >= incumbent.obj_val:
            continue

        branch_idx = _first_fractional_index(node.x, int_indices, tolerance)
        if branch_idx is None:
            # All integer-constrained variables are integral and the node
            # survived pruning, so it is strictly better than the incumbent.
            incumbent = node
            continue

        lower, upper = node.bounds[branch_idx]
        value = node.x[branch_idx]
        # Left child: x_i <= floor(value); right child: x_i >= ceil(value).
        for child_pair in ((lower, np.floor(value)), (np.ceil(value), upper)):
            child_bounds = list(node.bounds)
            child_bounds[branch_idx] = child_pair
            child = solve(child_bounds)
            if child.success and child.fun < incumbent.obj_val:
                stack.append(Node(child.fun, child.x, child_bounds))

    return incumbent, time.perf_counter() - start


def BnB(c, A_ub=None, b_ub=None, A_eq=None, b_eq=None, bounds=(0, None),
        int_indices=()):
    """Minimize ``c @ x`` with some variables restricted to integers.

    The LP arguments (``c``, ``A_ub``, ``b_ub``, ``A_eq``, ``b_eq``,
    ``bounds``) are forwarded to ``scipy.optimize.linprog``.  ``bounds`` may
    be a single ``(min, max)`` pair or one pair per variable.

    Args:
        int_indices: indices of the variables that must take integer values.

    Returns:
        ``None`` if the root LP relaxation is not solved successfully;
        otherwise ``(best, seconds)`` where ``best`` is the best ``Node``
        found and ``seconds`` the elapsed solve time.  If no integral node is
        found, ``best.obj_val`` is ``np.inf`` and ``best.x`` is all zeros.
    """
    return _branch_and_bound(c, A_ub, b_ub, A_eq, b_eq, bounds, int_indices)


class NodeSelectionPolicy(nn.Module):
    """Two-layer MLP producing logits over candidate actions."""

    def __init__(self, in_features, hidden_size, out_features):
        super(NodeSelectionPolicy, self).__init__()
        self.linear1 = nn.Linear(in_features, hidden_size)
        self.linear2 = nn.Linear(hidden_size, out_features)

    def forward(self, x):
        """Map input features to unnormalized action logits."""
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

    def sample(self, logits):
        """Sample one action per row of ``logits`` and return it one-hot.

        The one-hot width is taken from ``logits`` so the result always has
        the same trailing dimension as the logits, regardless of which index
        happened to be drawn.
        """
        probs = torch.softmax(logits, dim=-1)
        smpl = torch.multinomial(probs, 1)[..., 0]
        action = F.one_hot(smpl, num_classes=logits.shape[-1])
        return action


def BnB_RL(c, A_ub=None, b_ub=None, A_eq=None, b_eq=None, bounds=(0, None),
           int_indices=()):
    """Branch and bound intended to be driven by a learned policy.

    Same arguments and return value as ``BnB``.

    TODO: compute node features, run them through ``actor.forward`` and use
    ``actor.sample`` to pick the branching variable.  No feature encoding
    exists yet, so the policy is only instantiated and branching falls back
    to the first fractional integer variable, exactly as in ``BnB``.
    """
    n_vars_to_branch_on = 2
    actor = NodeSelectionPolicy(n_vars_to_branch_on, hidden_size=16,
                                out_features=n_vars_to_branch_on)  # noqa: F841 (placeholder, see TODO)
    return _branch_and_bound(c, A_ub, b_ub, A_eq, b_eq, bounds, int_indices)