# EECE571F-project / learn2branch_gasse / utils.py
import torch
import torch.nn.functional as F
import os
import numpy as np

# source: https://github.com/ds4dm/ecole/blob/master/examples/branching-imitation/example.ipynb
def pad_tensor(input_, pad_sizes, pad_value=-1e8):
    """
    Split a batched tensor into per-example chunks, right-pad each chunk along
    dim 0 to a common length, and stack the chunks.

    Parameters
    ----------
    input_ : torch.Tensor
        Concatenation of all examples along dim 0 (e.g. the candidate logits of
        every instance in a batch). Any trailing dimensions are preserved.
    pad_sizes : torch.Tensor or sequence of int
        Number of leading rows of ``input_`` that belong to each example, in
        order; must sum to ``input_.size(0)``. A tensor may live on any device.
    pad_value : float
        Fill value for padded positions. The large negative default presumably
        exists so padded logits get ~zero probability under a softmax.

    Returns
    -------
    torch.Tensor
        Shape ``(len(pad_sizes), max(pad_sizes), *input_.shape[1:])``. An empty
        ``pad_sizes`` yields shape ``(0, 0, *input_.shape[1:])``.
    """
    # Convert the sizes to plain Python ints once: split() needs a list anyway,
    # and doing the per-chunk arithmetic on ints avoids creating a 0-d tensor
    # (and a host/device sync when pad_sizes is on a GPU) for every chunk.
    sizes = torch.as_tensor(pad_sizes).tolist()
    if not sizes:
        # Empty batch: max()/torch.stack would raise on empty inputs.
        return input_.new_empty((0, 0) + tuple(input_.shape[1:]))

    max_pad_size = max(sizes)
    chunks = input_.split(sizes)

    # F.pad's pad tuple is ordered from the LAST dimension backwards, so leave
    # every trailing dimension untouched with (0, 0) and right-pad only dim 0.
    trailing = (0, 0) * (input_.dim() - 1)
    return torch.stack(
        [
            F.pad(chunk, trailing + (0, max_pad_size - chunk.size(0)), "constant", pad_value)
            for chunk in chunks
        ],
        dim=0,
    )

def create_folder(dirname):
    """
    Create ``dirname`` (including any missing parent directories) if needed.

    A message is printed only when this call actually created the directory.
    If the path already exists, nothing happens.
    """
    # EAFP instead of exists() followed by makedirs(): with check-then-create,
    # another process (e.g. a parallel worker) can create the directory in
    # between and makedirs would raise FileExistsError.
    try:
        os.makedirs(dirname)
    except FileExistsError:
        return
    print(f"Created: {dirname}")

def to_tensor(np_arr, dtype=torch.float32):
    """Return a freshly copied tensor built from ``np_arr``, cast to ``dtype`` (float32 by default)."""
    converted = torch.tensor(np_arr, dtype=dtype)
    return converted


def to_state(obs, device):
    """
    Convert an observation into the 4-tuple of tensors fed to the model.

    Reads ``obs.row_features``, ``obs.edge_features.indices``,
    ``obs.edge_features.values`` and ``obs.variable_features`` (NumPy arrays;
    presumably an Ecole bipartite observation, per the example this file is
    adapted from) and returns them in that order as tensors on ``device``.
    Features and edge values become float32 and edge indices become int64.
    Edge values are reshaped into a single column.
    """
    def _on_device(array, np_dtype):
        # astype() always returns a new array, so the tensor never aliases obs.
        return torch.from_numpy(array.astype(np_dtype)).to(device)

    row_feats = _on_device(obs.row_features, np.float32)
    edge_index = _on_device(obs.edge_features.indices, np.int64)
    edge_vals = torch.from_numpy(
        obs.edge_features.values.astype(np.float32)
    ).view(-1, 1).to(device)
    var_feats = _on_device(obs.variable_features, np.float32)
    return row_feats, edge_index, edge_vals, var_feats