# Source: EECE571F-project / learn2branch_gasse / data.py
from typing import Optional

import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.data import Dataset

class BipartiteData(Data):
    """Bipartite constraint/variable graph for one imitation-learning sample.

    Stores constraint-node features ``x_c``, variable-node features ``x_v``,
    the bipartite ``edge_index`` with its ``edge_features``, the candidate
    variables, the expert's score for each candidate, the position (inside
    ``candidates``) of the expert's choice, and the number of candidates.

    Every constructor argument defaults to ``None`` so the class can be
    instantiated without arguments. PyTorch Geometric does this when it
    splits a collated ``Batch`` back into individual examples
    (``Batch.get_example`` / ``Batch.to_data_list``). ``None`` attributes are
    simply not stored, so omitting them is harmless.
    """

    def __init__(
        self,
        x_c: Optional[torch.Tensor] = None,  # constraint nodes
        x_v: Optional[torch.Tensor] = None,  # variable nodes
        edge_index: Optional[torch.Tensor] = None,  # edge index
        edge_features: Optional[torch.Tensor] = None,  # edge features
        candidates: Optional[torch.Tensor] = None,  # set of candidate variables to branch on
        candidates_scores: Optional[torch.Tensor] = None,  # branch scores for each candidate
        candidate_choice: Optional[torch.Tensor] = None,  # candidate picked by expert
        nb_candidates: Optional[int] = None,
    ):
        super().__init__()
        self.x_c = x_c
        self.x_v = x_v
        self.edge_index = edge_index
        self.edge_features = edge_features
        self.candidates = candidates
        self.candidates_scores = candidates_scores
        self.candidate_choice = candidate_choice
        self.nb_candidates = nb_candidates

    def __inc__(self, key, value, *args, **kwargs):
        """Return the offset added to ``key`` when graphs are batched together.

        ``edge_index`` gets a two-row increment: row 0 is shifted by this
        graph's number of constraint nodes and row 1 by its number of
        variable nodes, because the two node types are indexed separately.
        ``candidates`` holds variable indices, so it is shifted by the number
        of variable nodes. Every other key falls back to ``Data.__inc__``.
        """
        if key == 'edge_index':  # change increment method for bipartite graphs
            return torch.tensor([[self.x_c.size(0)], [self.x_v.size(0)]])
        if key == "candidates":  # candidates change by the number of variables
            return self.x_v.size(0)
        return super().__inc__(key, value, *args, **kwargs)

class BipartiteDataset(Dataset):
    """PyG dataset that turns stored branching samples into ``BipartiteData``.

    Each element of ``samples`` is unpacked as
    ``(observation, action, action_set, scores)``. The observation must
    expose ``row_features``, ``variable_features`` and an ``edge_features``
    object with ``indices`` and ``values``.
    """

    def __init__(self, samples: list):
        super().__init__()
        self.samples = samples

    def len(self):
        """Return how many samples are stored."""
        return len(self.samples)

    def get(self, index):
        """Build and return the bipartite graph for sample ``index``."""
        observation, action, action_set, scores = self.samples[index]

        row_feats = observation.row_features
        edge_idx = observation.edge_features.indices.astype(np.int32)
        # Add a trailing axis of size 1 to the edge values.
        edge_vals = np.expand_dims(observation.edge_features.values, axis=-1)
        var_feats = observation.variable_features

        cands = np.array(action_set, dtype=np.int32)
        cand_scores = np.array([scores[var] for var in cands])
        # Position of the expert's action within ``cands``; this raises
        # IndexError if the action is not one of the candidates.
        choice_pos = np.where(cands == action)[0][0]

        graph = BipartiteData(
            x_c=torch.FloatTensor(row_feats),
            x_v=torch.FloatTensor(var_feats),
            edge_index=torch.LongTensor(edge_idx),
            edge_features=torch.FloatTensor(edge_vals),
            candidates=torch.LongTensor(cands),
            candidates_scores=torch.FloatTensor(cand_scores),
            candidate_choice=torch.LongTensor([choice_pos]),
            nb_candidates=len(cands),
        )

        # PyTorch Geometric needs the total node count (constraints plus
        # variables) set explicitly for indexing.
        graph.num_nodes = row_feats.shape[0] + var_feats.shape[0]
        return graph


class BipartiteDataRL(Data):
    """Bipartite constraint/variable graph for one reinforcement-learning state.

    Stores constraint-node features ``x_c``, variable-node features ``x_v``,
    the bipartite ``edge_index`` with its ``edge_features``, the candidate
    variables and their count. Unlike ``BipartiteData`` it carries no expert
    scores or expert choice.

    Every constructor argument defaults to ``None`` so the class can be
    instantiated without arguments. PyTorch Geometric does this when it
    splits a collated ``Batch`` back into individual examples
    (``Batch.get_example`` / ``Batch.to_data_list``).
    """

    def __init__(
        self,
        x_c: Optional[torch.Tensor] = None,  # constraint nodes
        x_v: Optional[torch.Tensor] = None,  # variable nodes
        edge_index: Optional[torch.Tensor] = None,  # edge index
        edge_features: Optional[torch.Tensor] = None,  # edge features
        candidates=None,  # candidate variables (type supplied by caller; assumed tensor-like -- TODO confirm)
        nb_candidates: Optional[int] = None,
    ):
        super().__init__()
        self.x_c = x_c
        self.x_v = x_v
        self.edge_index = edge_index
        self.edge_features = edge_features
        self.candidates = candidates
        self.nb_candidates = nb_candidates

    def __inc__(self, key, value, *args, **kwargs):
        """Return the offset added to ``key`` when graphs are batched together.

        ``edge_index`` row 0 is shifted by this graph's number of constraint
        nodes and row 1 by its number of variable nodes. ``candidates`` is
        shifted by the number of variable nodes. Every other key falls back
        to ``Data.__inc__``.
        """
        if key == 'edge_index':  # change increment method for bipartite graphs
            return torch.tensor([[self.x_c.size(0)], [self.x_v.size(0)]])
        if key == "candidates":  # candidates change by the number of variables
            return self.x_v.size(0)
        return super().__inc__(key, value, *args, **kwargs)
    
class BipartiteDatasetRL(Dataset):
    """PyG dataset over parallel sequences of RL transitions.

    ``state``, ``action``, ``reward``, ``action_set`` and ``done`` are indexed
    in lockstep: position ``i`` of each sequence belongs to the same
    transition. Next-state information is not stored.
    """

    def __init__(self, state, action, reward, action_set, done):
        super().__init__()
        self.state = state
        self.action = action
        self.reward = reward
        self.action_set = action_set
        self.done = done

    def len(self):
        """Return the number of transitions, taken as the length of ``state``."""
        return len(self.state)

    def get(self, index):
        """Return ``(graph, action, reward, done)`` for transition ``index``."""
        observation, act, rew, cand_set, finished = (
            seq[index]
            for seq in (self.state, self.action, self.reward, self.action_set, self.done)
        )

        # A state is a 4-tuple (constraint feats, edge indices, edge feats,
        # variable feats). The entries are passed through without conversion.
        cons_feats, edge_idx, edge_feats, var_feats = observation
        graph = BipartiteDataRL(
            x_c=cons_feats,
            x_v=var_feats,
            edge_index=edge_idx,
            edge_features=edge_feats,
            candidates=cand_set,
            nb_candidates=len(cand_set),
        )
        # PyTorch Geometric needs the total node count (constraints plus
        # variables) set explicitly for indexing.
        graph.num_nodes = cons_feats.shape[0] + var_feats.shape[0]

        return graph, act, rew, finished