import numpy as np
from scipy.optimize import linprog
import time
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
class Node:
    """One subproblem (or the incumbent) in the branch-and-bound search.

    Attributes are plain public fields:
        obj_val: objective value of the subproblem's LP relaxation; the
            "no incumbent yet" sentinel created by the solvers uses np.inf.
        x: LP solution vector returned for this subproblem.
        bounds: the variable bounds passed to linprog for this subproblem.
    """

    def __init__(self, obj_val, x, bounds) -> None:
        self.obj_val = obj_val
        self.x = x
        self.bounds = bounds

    def __repr__(self) -> str:
        # Readable form for inspecting the search queue / returned incumbent.
        return (
            f"{type(self).__name__}(obj_val={self.obj_val!r}, "
            f"x={self.x!r}, bounds={self.bounds!r})"
        )
def BnB(c, A_ub=None, b_ub=None, A_eq=None, b_eq=None, bounds=(0, None), int_indices=None):
    """Solve a mixed-integer linear program with depth-first branch and bound.

    Minimises ``c @ x`` subject to ``A_ub @ x <= b_ub``, ``A_eq @ x == b_eq``
    and the variable bounds, additionally requiring ``x[i]`` to be integral
    (to within 1e-5) for every ``i`` in ``int_indices``. Each LP relaxation
    is solved with ``scipy.optimize.linprog``.

    Args:
        c, A_ub, b_ub, A_eq, b_eq: LP data, passed unchanged to ``linprog``.
        bounds: either a single ``(min, max)`` pair applied to every variable
            or a sequence with one pair per variable (``None`` inside a pair
            means unbounded). ``bounds=None`` means ``(0, None)`` for all.
        int_indices: indices of the variables that must take integer values.

    Returns:
        ``None`` (after printing a message) if the root LP relaxation is not
        solved successfully. Otherwise a tuple ``(best, seconds)`` where
        ``best`` is the Node holding the best integral solution found -- or,
        if none was found, a sentinel Node with ``obj_val == np.inf`` and an
        all-zero ``x`` -- and ``seconds`` is the elapsed solve time.
    """
    start = time.perf_counter()
    int_indices = [] if int_indices is None else list(int_indices)
    tolerance = 1e-5

    # Normalise to one (min, max) pair per variable so a single variable's
    # bound can be tightened when branching. The default (0, None) is a
    # tuple, which supports neither .copy() nor per-variable assignment.
    n_vars = len(c)
    if bounds is None:
        bounds = [(0, None)] * n_vars
    elif len(bounds) == 2 and np.ndim(bounds[0]) == 0 and np.ndim(bounds[1]) == 0:
        bounds = [tuple(bounds)] * n_vars
    else:
        bounds = [tuple(pair) for pair in bounds]

    def is_fractional(value):
        return np.abs(value - np.round(value)) > tolerance

    def solve(node_bounds):
        return linprog(c=c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=node_bounds)

    sol_init = solve(bounds)
    if not sol_init.success:
        print("No solution found")
        return None

    best_so_far = Node(np.inf, np.zeros_like(c), bounds)
    # Open subproblems; popping from the end makes the search depth-first.
    queue = [Node(sol_init.fun, sol_init.x, bounds)]

    while queue:
        node = queue.pop()
        # Prune: this relaxation cannot beat the incumbent.
        # TODO: pruning rule may be replaced later.
        if node.obj_val >= best_so_far.obj_val:
            continue

        branch_var = next((i for i in int_indices if is_fractional(node.x[i])), None)
        if branch_var is None:
            # Integral and, by the prune check above, strictly better.
            best_so_far = node
            continue

        # Branch on the first fractional integer variable.
        # TODO: replace with a policy that chooses the branching variable.
        lower, upper = node.bounds[branch_var]
        left_bounds = list(node.bounds)
        left_bounds[branch_var] = (lower, np.floor(node.x[branch_var]))  # x_i <= floor(x_i*)
        right_bounds = list(node.bounds)
        right_bounds[branch_var] = (np.ceil(node.x[branch_var]), upper)  # x_i >= ceil(x_i*)

        # Left child is pushed first, so the right child is explored first.
        for child_bounds in (left_bounds, right_bounds):
            child = solve(child_bounds)
            if child.success and child.fun < best_so_far.obj_val:
                queue.append(Node(child.fun, child.x, child_bounds))

    return best_so_far, time.perf_counter() - start
class NodeSelectionPolicy(nn.Module):
    """Two-layer MLP that scores branching candidates and samples one.

    ``forward`` maps inputs whose last dimension is ``in_features`` to
    unnormalised logits whose last dimension is ``out_features``;
    ``sample`` draws one category per row from the softmax of such logits
    and returns it one-hot encoded.
    """

    def __init__(self, in_features, hidden_size, out_features):
        super().__init__()
        self.linear1 = nn.Linear(in_features, hidden_size)
        self.linear2 = nn.Linear(hidden_size, out_features)

    def forward(self, x):
        """Return logits computed as linear1 -> ReLU -> linear2."""
        return self.linear2(F.relu(self.linear1(x)))

    def sample(self, logits):
        """Sample one category per row of ``logits`` and return it one-hot.

        Args:
            logits: tensor whose last dimension indexes the categories;
                entries equal to -inf receive zero probability.

        Returns:
            An int64 tensor with the same shape as ``logits`` containing a
            single 1 per row at the sampled position.
        """
        probs = torch.softmax(logits, dim=-1)
        # multinomial returns a trailing dimension of size 1; drop it.
        idx = torch.multinomial(probs, 1)[..., 0]
        # num_classes must be explicit: -1 would infer the length from the
        # sampled index itself, so the vector length would vary per call.
        return F.one_hot(idx, num_classes=logits.shape[-1])
def BnB_RL(c, A_ub=None, b_ub=None, A_eq=None, b_eq=None, bounds=(0, None), int_indices=None):
    """Depth-first branch and bound whose branching variable is picked by a policy.

    Same problem, arguments and return contract as ``BnB``: minimise
    ``c @ x`` under the ``linprog`` constraints and bounds, with ``x[i]``
    integral (to within 1e-5) for every ``i`` in ``int_indices``.

    The branching variable is sampled from a NodeSelectionPolicy over the
    first ``n_vars_to_branch_on`` fractional integer variables. Fractional
    variables beyond that count are not considered. The policy is created
    fresh on every call and is not trained here, so its choices come from
    its random initialisation.

    Returns:
        ``None`` (after printing a message) if the root LP relaxation is not
        solved successfully; otherwise ``(best, seconds)``. If no integral
        solution is found, ``best`` is a sentinel Node with
        ``obj_val == np.inf`` and an all-zero ``x``.
    """
    start = time.perf_counter()
    int_indices = [] if int_indices is None else list(int_indices)
    tolerance = 1e-5

    # One (min, max) pair per variable so individual bounds can be tightened;
    # the default (0, None) tuple cannot be copied/indexed per variable.
    n_vars = len(c)
    if bounds is None:
        bounds = [(0, None)] * n_vars
    elif len(bounds) == 2 and np.ndim(bounds[0]) == 0 and np.ndim(bounds[1]) == 0:
        bounds = [tuple(bounds)] * n_vars
    else:
        bounds = [tuple(pair) for pair in bounds]

    def is_fractional(value):
        return np.abs(value - np.round(value)) > tolerance

    def solve(node_bounds):
        return linprog(c=c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=node_bounds)

    sol_init = solve(bounds)
    if not sol_init.success:
        print("No solution found")
        return None

    best_so_far = Node(np.inf, np.zeros_like(c), bounds)
    # Open subproblems; popping from the end makes the search depth-first.
    queue = [Node(sol_init.fun, sol_init.x, bounds)]

    n_vars_to_branch_on = 2
    actor = NodeSelectionPolicy(n_vars_to_branch_on, hidden_size=16, out_features=n_vars_to_branch_on)

    while queue:
        node = queue.pop()
        # Prune: this relaxation cannot beat the incumbent.
        # TODO: pruning rule may be replaced later.
        if node.obj_val >= best_so_far.obj_val:
            continue

        fractional = [i for i in int_indices if is_fractional(node.x[i])]
        if not fractional:
            # Integral and, by the prune check above, strictly better.
            best_so_far = node
            continue

        # Policy input: fractional parts of up to n_vars_to_branch_on
        # candidates, zero-padded; padded slots are masked out of sampling.
        candidates = fractional[:n_vars_to_branch_on]
        frac_parts = [float(node.x[v] - np.floor(node.x[v])) for v in candidates]
        padding = [0.0] * (n_vars_to_branch_on - len(candidates))
        features = torch.tensor(frac_parts + padding, dtype=torch.float32)
        valid = torch.arange(n_vars_to_branch_on) < len(candidates)
        with torch.no_grad():
            logits = actor(features).masked_fill(~valid, float("-inf"))
            action = actor.sample(logits)
        # argmax of the one-hot action is the sampled slot.
        branch_var = candidates[int(torch.argmax(action))]

        lower, upper = node.bounds[branch_var]
        left_bounds = list(node.bounds)
        left_bounds[branch_var] = (lower, np.floor(node.x[branch_var]))  # x_i <= floor(x_i*)
        right_bounds = list(node.bounds)
        right_bounds[branch_var] = (np.ceil(node.x[branch_var]), upper)  # x_i >= ceil(x_i*)

        # Left child is pushed first, so the right child is explored first.
        for child_bounds in (left_bounds, right_bounds):
            child = solve(child_bounds)
            if child.success and child.fun < best_so_far.obj_val:
                queue.append(Node(child.fun, child.x, child_bounds))

    return best_so_far, time.perf_counter() - start