# CeFGC / lib / data_oneDS_ours.py
import random
from random import choices
import numpy as np
import pandas as pd

import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import OneHotDegree

from models_ import GIN, serverGIN, GIN_dc, serverGIN_dc
# from server import Server
from server_ import Server
from client_ours import Client_GC
from utils import get_maxDegree, get_stats, split_data, get_numGraphLabels, init_structure_encoding

from scipy.special import rel_entr
import scipy
from torch_geometric.utils import erdos_renyi_graph, degree
import itertools
from torch_geometric.utils.convert import from_networkx
import os
import pickle
import networkx as nx
from torch_geometric.data import Data
import copy as cp

def _generated_sample_dirs(alg, data_group):
    """Return the generated-graph sample directories for one experiment setup.

    The result is a list with one entry per class label; entry ``c`` is the
    list of directories holding the pickled generated graphs labelled ``c``
    (presumably one directory per client, in client order -- confirm against
    the generator's output layout).

    Raises:
        ValueError: if no directory layout is configured for the given
            ``alg`` / ``data_group`` pair.
    """
    # (base path, [suffixes for class 0, suffixes for class 1]); each final
    # path is the plain concatenation ``base + suffix``.
    layouts = {
        ('CeFGC*', 'MUTAG'): ('./data/mutag_', [
            ["/62-28-0-0-3/sample/sample_data_0_0/",
             "/62-28-1-0-3/sample/sample_data_1_0/",
             "/62-26-2-0-3/sample/sample_data_2_0/"],
            ["/62-28-0-0-3/sample/sample_data_0_1/",
             "/62-28-1-0-3/sample/sample_data_1_1/",
             "/62-26-2-0-3/sample/sample_data_2_1/"],
        ]),
        ('CeFGC*', 'IMDB-BINARY'): ('./data/imdb-binary_', [
            ["/266-72-0-0-3/sample/sample_data_0_0/",
             "/266-65-1-0-3/sample/sample_data_1_0/",
             "/266-56-2-0-3/sample/sample_data_2_0/"],
            ["/266-72-0-0-3/sample/sample_data_0_1/",
             "/266-65-1-0-3/sample/sample_data_1_1/",
             "/266-56-2-0-3/sample/sample_data_2_1/"],
        ]),
        ('CeFGC', 'MUTAG'): ('./data/mutag', [
            ["/62-28-0-0/sample/sample_data/",
             "/62-28-1-0/sample/sample_data/",
             "/62-26-2-0/sample/sample_data/"],
            ["/62-21-0-1/sample/sample_data/",
             "/62-24-1-1/sample/sample_data/",
             "/62-22-2-1/sample/sample_data/"],
        ]),
        # NOTE(review): unlike the other layouts, these suffixes have no
        # leading "/", so the paths come out as
        # "./data/imdb-binary266-49-0-0-3/...". Kept as-is; confirm whether a
        # separator is missing.
        ('CeFGC', 'IMDB-BINARY'): ('./data/imdb-binary', [
            ["266-49-0-0-3/sample/sample_data/",
             "266-59-1-0-3/sample/sample_data/",
             "266-55-2-0-3/sample/sample_data/"],
            ["266-72-0-1-3/sample/sample_data/",
             "266-65-1-1-3/sample/sample_data/",
             "266-56-2-1-3/sample/sample_data/"],
        ]),
    }
    try:
        res, suffixes_per_class = layouts[(alg, data_group)]
    except KeyError:
        raise ValueError(
            f"No generated-sample directories configured for "
            f"alg={alg!r}, data_group={data_group!r}"
        ) from None
    return [[res + suffix for suffix in suffixes] for suffixes in suffixes_per_class]


def _load_generated_graphs(sample_dirs, limit, ft_dim):
    """Load the pickled generated graphs found under ``sample_dirs``.

    Every file in every directory (in ``os.listdir`` order) contributes one
    entry to each returned list.

    Args:
        sample_dirs: directories whose files are pickled iterables of graphs
            exposing ``edges()`` and ``number_of_nodes()``.
        limit: stop reading a file once this many graphs have been kept (at
            least one graph is kept before the check).
        ft_dim: number of random binary feature columns drawn per node.

    Returns:
        ``(edges_per_file, feats_per_file)``: parallel lists; entry ``k`` holds
        the ``(2, num_edges)`` edge arrays and the ``(num_nodes, ft_dim)``
        random 0/1 feature arrays of the graphs kept from file ``k``.
    """
    edges_per_file = []
    feats_per_file = []
    for sample_dir in sample_dirs:
        for file_name in os.listdir(sample_dir):
            # SECURITY: pickle.load can execute arbitrary code; only point this
            # at sample files produced by a trusted pipeline.
            with open(os.path.join(sample_dir, file_name), 'rb') as fh:
                graphs = pickle.load(fh, encoding='latin1')

            file_edges = []
            file_feats = []
            for graph in graphs:
                num_nodes = graph.number_of_nodes()
                if num_nodes == 1:
                    # Single-node graphs are skipped.
                    continue

                edge_index = np.array([[src, dst] for (src, dst) in graph.edges()]).T
                # NOTE(review): np.max raises ValueError for a graph without
                # edges; behaviour kept, confirm such graphs cannot occur.
                max_node_id = np.max(edge_index)
                if max_node_id >= num_nodes:
                    # Node ids exceed the node count: size the feature matrix
                    # so that every referenced id has a row.
                    num_nodes = max_node_id + 1

                file_edges.append(edge_index)
                file_feats.append(np.random.randint(0, 2, (num_nodes, ft_dim)))

                if len(file_edges) >= limit:
                    break

            edges_per_file.append(file_edges)
            feats_per_file.append(file_feats)
    return edges_per_file, feats_per_file


def _load_feats_label_edges(file_path, error_tag):
    """Unpickle a ``(feats, labels, edges)`` triple and sanity-check node ids.

    Prints ``error_tag`` once for every graph whose largest edge endpoint is
    not a valid row index into its feature matrix. Nothing is raised for such
    graphs.
    """
    # SECURITY: pickle.load can execute arbitrary code; trusted files only.
    with open(file_path, 'rb') as f:
        feats_list, label_list, edge_list = pickle.load(f)

    for i in range(len(label_list)):
        if np.max(edge_list[i]) >= np.shape(feats_list[i])[0]:
            print(error_tag)
    return feats_list, label_list, edge_list


def _as_pyg_graph(feats, edges, label):
    """Build a ``Data`` graph with float features and long edges / label."""
    return Data(
        x=torch.tensor(feats, dtype=torch.float),
        edge_index=torch.tensor(edges, dtype=torch.long),
        y=torch.tensor(label, dtype=torch.long),
    )


def data_oneDS_ours(args, dataset, ft_dim):
    """Build per-client data loaders that mix real and generated graphs.

    Each client gets a contiguous slice of the pickled training graphs (split
    70% / 20% / 10% into train / val / test); its train part is augmented with
    generated graphs loaded from the sample directories configured for
    ``args.alg`` / ``args.data_group``, using every entry except the one at the
    client's own index.

    Args:
        args: namespace providing ``split``, ``clients``, ``alg``,
            ``data_group``, ``seed``, ``dataset``, ``type_init`` and
            ``batch_size``.
        dataset: sized object with a ``shuffle()`` method; only its length is
            used (to size the generated-graph cap and the MUTAG test count).
        ft_dim: feature dimension of the random binary node features assigned
            to generated graphs.

    Returns:
        ``(splitedData, df, dataloader_global_test)`` where ``splitedData`` maps
        ``'<client>-<data_group>'`` to ``({'train', 'val', 'test'} loaders,
        num_node_features, num_graph_labels, train_size)``, ``df`` is the
        frame accumulated by ``get_stats`` and ``dataloader_global_test`` is
        ``{'test': DataLoader}`` over the global test graphs.

    Raises:
        ValueError: if ``args.alg`` / ``args.data_group`` has no configured
            sample-directory layout.
    """
    # Cap on the graphs kept from each generated sample file: a quarter of
    # the per-client share of the dataset.
    division = int(len(dataset) * args.split / args.clients)
    division = int(division / 4)

    sample_dirs_per_class = _generated_sample_dirs(args.alg, args.data_group)
    num_data_classes = len(sample_dirs_per_class)

    edge_list_gen_dic = {}
    x_list_dic = {}
    for iclass in range(num_data_classes):
        edge_list_gen_dic[iclass], x_list_dic[iclass] = _load_generated_graphs(
            sample_dirs_per_class[iclass], division, ft_dim)

    # Seed the stdlib RNG that drives the per-client shuffles below (the node
    # features above come from numpy's global RNG, which is not seeded here).
    random.seed(args.seed)

    # Generated graphs handed to each client: everything produced from the
    # *other* entries (indexed by client id), all classes, in shuffled order.
    g_gen_pyg = {}
    for clt1 in range(args.clients):
        gen_edges = []
        gen_feats = []
        gen_labels = []
        for clt2 in range(args.clients):
            if clt2 == clt1:
                continue
            for jclass in range(num_data_classes):
                class_edges = edge_list_gen_dic[jclass][clt2]
                gen_edges.extend(class_edges)
                gen_feats.extend(x_list_dic[jclass][clt2])
                gen_labels.extend([jclass] * len(class_edges))

        idxs = list(range(len(gen_labels)))
        random.shuffle(idxs)
        g_gen_pyg[clt1] = [
            _as_pyg_graph(gen_feats[idx], gen_edges[idx], gen_labels[idx])
            for idx in idxs
        ]

    torch.manual_seed(12345)
    dataset = dataset.shuffle()
    print('!!!', len(dataset))

    # NOTE(review): absolute '/data/' here versus relative './data/' for the
    # generated samples -- confirm this is intended.
    data_dir = '/data/'
    file_name = args.dataset.lower() + '-train_feats_label_edge_list'
    feats_list_train0, label_list_train0, edge_list_train0 = _load_feats_label_edges(
        os.path.join(data_dir, file_name), 'error1')
    num_train_graphs = len(label_list_train0)
    print('***', num_train_graphs)

    file_name = args.dataset.lower() + '-test_feats_label_edge_list'
    feats_list_test0, label_list_test0, edge_list_test0 = _load_feats_label_edges(
        os.path.join(data_dir, file_name), 'error2')
    num_test_graphs = len(label_list_test0)
    print(num_test_graphs)

    graph_global_test = []
    if args.data_group == 'MUTAG':
        # NOTE(review): the count is the number of test entries past index
        # int(len(dataset) * split), but graphs are taken from index 0 --
        # behaviour kept as found, confirm intent.
        num_global = len(range(num_test_graphs)[int(len(dataset) * args.split):])
        for i in range(num_global):
            graph_global_test.append(
                _as_pyg_graph(feats_list_test0[i], edge_list_test0[i], label_list_test0[i]))
            print(len(graph_global_test))
    else:
        for i in range(num_test_graphs):
            graph_global_test.append(
                _as_pyg_graph(feats_list_test0[i], edge_list_test0[i], label_list_test0[i]))

    num_node_features = graph_global_test[0].num_node_features

    # Train dataset, later split across clients.
    all_train_graphs = [
        _as_pyg_graph(feats_list_train0[jj], edge_list_train0[jj], label_list_train0[jj])
        for jj in range(num_train_graphs)
    ]

    # Per-client sizes of the train / val / test slices.
    division = int(num_train_graphs * args.split*0.7 / args.clients)
    division_val = int(num_train_graphs * args.split*0.2 / args.clients)
    division_test = int(num_train_graphs * args.split*0.1 / args.clients)

    splitedData = {}
    df = pd.DataFrame()
    startup = 0

    for i in range(args.clients):
        ds = f'{i}-{args.data_group}'

        val_start = startup + division
        test_start = val_start + division_val
        next_start = test_start + division_test

        # Slices are fresh lists, so extending the train part does not touch
        # the shared list of training graphs.
        ds_train = all_train_graphs[startup:val_start]
        ds_val = all_train_graphs[val_start:test_start]
        ds_test = all_train_graphs[test_start:next_start]
        startup = next_start

        ds_train.extend(g_gen_pyg[i])

        ds_train = init_structure_encoding(args, gs=ds_train, type_init=args.type_init)
        ds_val = init_structure_encoding(args, gs=ds_val, type_init=args.type_init)
        ds_test = init_structure_encoding(args, gs=ds_test, type_init=args.type_init)
        # NOTE(review): the global test set is re-encoded once per client;
        # kept inside the loop because the helper's semantics are not visible
        # here -- hoist it if the encoding is idempotent.
        graph_global_test = init_structure_encoding(args, gs=graph_global_test, type_init=args.type_init)

        dataloader_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True)
        dataloader_val = DataLoader(ds_val, batch_size=args.batch_size, shuffle=True)
        dataloader_test = DataLoader(ds_test, batch_size=len(ds_test), shuffle=True)

        num_graph_labels = get_numGraphLabels(ds_train)
        splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                           num_node_features, num_graph_labels, len(ds_train))
        df = get_stats(df, ds, ds_train, graphs_val=ds_val, graphs_test=ds_test)

    dataloader_global_test = ({'test': DataLoader(graph_global_test, batch_size=len(graph_global_test), shuffle=True)})

    return splitedData, df, dataloader_global_test