# EECE571F-project / learn2branch_self / data.py
from typing import Optional

import torch
from torch_geometric.data import Data
from torch_geometric.data import Dataset

class BipartiteDataRL(Data):
    """Bipartite constraint/variable graph for a single branching state.

    Holds the constraint-node features, the variable-node features, the edges
    between the two node sets with their features, and the candidate variables
    (indices into the variable nodes) together with their count.

    ``__inc__`` is overridden so that PyTorch Geometric offsets indices
    correctly when several graphs are collated into one ``Batch``.

    All constructor arguments default to ``None`` because PyTorch Geometric
    re-instantiates ``Data`` subclasses with no arguments internally (e.g. when
    splitting a ``Batch`` back into graphs via ``get_example`` /
    ``to_data_list``). Required positional arguments make that fail.
    """

    def __init__(
            self,
            x_c: Optional[torch.Tensor] = None,  # constraint node features
            x_v: Optional[torch.Tensor] = None,  # variable node features
            edge_index: Optional[torch.Tensor] = None,  # edge index
            edge_features: Optional[torch.Tensor] = None,  # edge features
            candidates=None,  # indices into x_v of the branching candidates
            nb_candidates=None,  # number of branching candidates
    ):
        super().__init__()
        self.x_c = x_c
        self.x_v = x_v
        self.edge_index = edge_index
        self.edge_features = edge_features
        self.candidates = candidates
        self.nb_candidates = nb_candidates

    def __inc__(self, key, value, *args, **kwargs):
        """Return the per-graph offset PyG applies to ``key`` when batching.

        ``edge_index`` rows are offset independently. Row 0 is offset by the
        number of constraint nodes and row 1 by the number of variable nodes,
        so row 0 is expected to index constraints and row 1 to index
        variables. ``candidates`` index variable nodes, so they are offset by
        the number of variable nodes. Every other key uses the default.
        """
        if key == 'edge_index':  # bipartite graph: offset each row separately
            return torch.tensor([[self.x_c.size(0)], [self.x_v.size(0)]])
        if key == "candidates":  # candidates shift by the number of variables
            return self.x_v.size(0)
        return super().__inc__(key, value, *args, **kwargs)
    
class BipartiteDatasetRL(Dataset):
    """Rollout buffer exposed as a PyTorch Geometric dataset.

    Each item pairs the bipartite graph of one stored state with the
    transition data recorded for it: action, reward, done flag and the
    log-probability stored in ``old_log_probs``.

    Args:
        state: sequence of observations. Each one is unpacked as
            ``(constraint_features, edge_indices, edge_features,
            variable_features)``.
        action: per-transition actions, aligned with ``state``.
        reward: per-transition rewards, aligned with ``state``.
        action_set: per-transition candidate variable indices, aligned
            with ``state``.
        done: per-transition done flags, aligned with ``state``.
        old_log_probs: per-transition stored log-probabilities, aligned
            with ``state``.
    """

    def __init__(self, state, action, reward, action_set, done, old_log_probs):
        super().__init__()
        self.state = state
        self.action = action
        self.reward = reward
        self.action_set = action_set
        self.done = done
        self.old_log_probs = old_log_probs

    def len(self):
        """Return the number of stored transitions (the length of ``state``)."""
        return len(self.state)

    @staticmethod
    def _as_candidate_tensor(action_set):
        """Convert an action set to a ``torch.long`` tensor of variable indices."""
        # PyG only concatenates and applies ``__inc__`` offsets to tensor
        # attributes. Lists/arrays are passed through un-offset, so after
        # batching the candidate indices would refer to the wrong graph.
        if isinstance(action_set, torch.Tensor):
            return action_set.to(torch.long)
        return torch.tensor([int(a) for a in action_set], dtype=torch.long)

    def get(self, index):
        """Return ``(graph, action, reward, done, old_log_probs)`` for ``index``."""
        sample_observation = self.state[index]
        sample_action = self.action[index]
        sample_reward = self.reward[index]
        sample_action_set = self.action_set[index]
        sample_done = self.done[index]
        sample_old_log_probs = self.old_log_probs[index]

        # current state
        constraint_features, edge_indices, edge_features, variable_features = sample_observation
        graph = BipartiteDataRL(
            constraint_features,
            variable_features,
            edge_indices,
            edge_features,
            self._as_candidate_tensor(sample_action_set),
            len(sample_action_set),
        )
        # We must tell pytorch geometric how many nodes there are, for indexing purposes
        graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]

        return graph, sample_action, sample_reward, sample_done, sample_old_log_probs