import random
from random import choices
import itertools
import os
import pickle
import copy as cp

import numpy as np
import pandas as pd
import scipy
from scipy.special import rel_entr
import torch
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import OneHotDegree
from torch_geometric.utils import erdos_renyi_graph, degree
from torch_geometric.utils.convert import from_networkx

from models_ import GIN, serverGIN, GIN_dc, serverGIN_dc
# from server import Server
from server_ import Server
from client_ours import Client_GC
from utils import get_maxDegree, get_stats, split_data, get_numGraphLabels, init_structure_encoding


def data_oneDS_ours(args, dataset, ft_dim):
    """Build per-client federated splits of one dataset, where each client's
    training set is augmented with generated (pre-sampled) graphs taken from
    the *other* clients' sample directories.

    Parameters
    ----------
    args : namespace expected to carry at least ``alg``, ``data_group``,
        ``dataset``, ``split``, ``clients``, ``seed``, ``batch_size`` and
        ``type_init``.
    dataset : torch_geometric dataset; only ``len()`` and ``shuffle()`` are
        used on it here.
    ft_dim : int — feature dimension of the random 0/1 node features assigned
        to generated graphs.

    Returns
    -------
    (splitedData, df, dataloader_global_test)
        ``splitedData`` maps the key ``'{i}-{data_group}'`` to a tuple
        ``({'train','val','test'} dataloaders, num_node_features,
        num_graph_labels, len(train set))``; ``df`` accumulates per-client
        stats via ``get_stats``; ``dataloader_global_test`` wraps the shared
        global test set in a single full-batch loader.
    """
    # Cap on generated graphs taken per sample file: a quarter of one
    # client's share of the dataset.
    division = int(len(dataset) * args.split / args.clients)
    division = int(division / 4)

    # Resolve the per-(client, class) directories holding pickled generated
    # graphs. Layout: graph_s_gen_dirs_list[class][client] -> directory.
    if args.alg == 'CeFGC*':
        if args.data_group == 'MUTAG':
            res = './data/mutag_'
            graph_s_gen_dir0_0 = res + "/62-28-0-0-3/sample/sample_data_0_0/"
            graph_s_gen_dir0_1 = res + "/62-28-0-0-3/sample/sample_data_0_1/"
            graph_s_gen_dir1_0 = res + "/62-28-1-0-3/sample/sample_data_1_0/"
            graph_s_gen_dir1_1 = res + "/62-28-1-0-3/sample/sample_data_1_1/"
            graph_s_gen_dir2_0 = res + "/62-26-2-0-3/sample/sample_data_2_0/"
            graph_s_gen_dir2_1 = res + "/62-26-2-0-3/sample/sample_data_2_1/"
            graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
            graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
            num_data_classes = 2
            graph_s_gen_dirs_list = [graph_s_gen_dirs0, graph_s_gen_dirs1]
        elif args.data_group == 'IMDB-BINARY':
            res = './data/imdb-binary_'
            graph_s_gen_dir0_0 = res + "/266-72-0-0-3/sample/sample_data_0_0/"
            graph_s_gen_dir0_1 = res + "/266-72-0-0-3/sample/sample_data_0_1/"
            graph_s_gen_dir1_0 = res + "/266-65-1-0-3/sample/sample_data_1_0/"
            graph_s_gen_dir1_1 = res + "/266-65-1-0-3/sample/sample_data_1_1/"
            graph_s_gen_dir2_0 = res + "/266-56-2-0-3/sample/sample_data_2_0/"
            graph_s_gen_dir2_1 = res + "/266-56-2-0-3/sample/sample_data_2_1/"
            graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
            graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
            num_data_classes = 2
            graph_s_gen_dirs_list = [graph_s_gen_dirs0, graph_s_gen_dirs1]
    elif args.alg == 'CeFGC':
        if args.data_group == 'MUTAG':
            res = './data/mutag'
            graph_s_gen_dir0_0 = res + "/62-28-0-0/sample/sample_data/"
            graph_s_gen_dir0_1 = res + "/62-21-0-1/sample/sample_data/"
            graph_s_gen_dir1_0 = res + "/62-28-1-0/sample/sample_data/"
            graph_s_gen_dir1_1 = res + "/62-24-1-1/sample/sample_data/"
            graph_s_gen_dir2_0 = res + "/62-26-2-0/sample/sample_data/"
            graph_s_gen_dir2_1 = res + "/62-22-2-1/sample/sample_data/"
            graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
            graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
            num_data_classes = 2
            graph_s_gen_dirs_list = [graph_s_gen_dirs0, graph_s_gen_dirs1]
        elif args.data_group == 'IMDB-BINARY':
            res = './data/imdb-binary'
            # NOTE(review): unlike every other branch, these paths have no '/'
            # between `res` and the sample directory — confirm that this is
            # intentional before relying on this branch.
            graph_s_gen_dir0_0 = res + "266-49-0-0-3/sample/sample_data/"
            graph_s_gen_dir0_1 = res + "266-72-0-1-3/sample/sample_data/"
            graph_s_gen_dir1_0 = res + "266-59-1-0-3/sample/sample_data/"
            graph_s_gen_dir1_1 = res + "266-65-1-1-3/sample/sample_data/"
            graph_s_gen_dir2_0 = res + "266-55-2-0-3/sample/sample_data/"
            graph_s_gen_dir2_1 = res + "266-56-2-1-3/sample/sample_data/"
            graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
            graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
            num_data_classes = 2
            graph_s_gen_dirs_list = [graph_s_gen_dirs0, graph_s_gen_dirs1]
    # NOTE(review): the original file carried large commented-out path tables
    # for 'CeFGC*+' and 'CeFGC+' (duplicates of the two branches above); that
    # dead code has been removed.

    # Load the generated graphs and extract, per class and per source client:
    # the networkx graphs, their edge arrays, random node features, node
    # counts, and how many graphs were actually taken.
    graph_s_gen_list_dic = {}
    edge_list_gen_dic = {}
    num_gen_dic = {}
    x_list_dic = {}
    nd_gen_dic = {}
    for iclass in range(num_data_classes):
        graph_s_gen_list_dic[iclass] = []
        edge_list_gen_dic[iclass] = []
        num_gen_dic[iclass] = []
        nd_gen_dic[iclass] = []
        x_list_dic[iclass] = []
        random.seed(args.seed)
        for graph_s_gen_dir in graph_s_gen_dirs_list[iclass]:
            # NOTE(review): os.listdir order is platform-dependent; later code
            # indexes these per-directory lists by client id, which assumes
            # exactly one pickle file per directory — confirm.
            files = os.listdir(graph_s_gen_dir)
            for file in files:
                graph_s_gen_dir_f = os.path.join(graph_s_gen_dir, file)
                # FIX: original opened the file without ever closing it.
                with open(graph_s_gen_dir_f, 'rb') as f2:
                    graph_s_gen = pickle.load(f2, encoding='latin1')
                random.seed(args.seed)
                # graph_s_gen = random.sample(graph_s_gen, division)
                graph_s_gen_list_dic[iclass].append(graph_s_gen)
                g_edges_list = []
                nd_gen_num = []
                x_list = []
                gen_num = 0
                for g_ in graph_s_gen:
                    g_edges = np.array([[nd1, nd2] for (nd1, nd2) in g_.edges()]).T
                    if g_.number_of_nodes() == 1:
                        continue
                    # Node count: at least one past the largest referenced
                    # edge endpoint, otherwise the graph's own node count.
                    max_nd = np.max(g_edges)
                    if max_nd >= g_.number_of_nodes():
                        num_tmp = max_nd + 1
                    else:
                        num_tmp = g_.number_of_nodes()
                    g_edges_list.append(g_edges)
                    nd_gen_num.append(num_tmp)
                    # Random binary node features for the generated graph.
                    x_list.append(np.random.randint(0, 2, (num_tmp, ft_dim)))
                    gen_num += 1
                    if gen_num >= division:
                        break
                edge_list_gen_dic[iclass].append(g_edges_list)
                num_gen_dic[iclass].append(gen_num)
                nd_gen_dic[iclass].append(nd_gen_num)
                x_list_dic[iclass].append(x_list)

    # For each client, gather the generated graphs of all *other* clients
    # (both classes), shuffle them, and materialize PyG Data objects.
    y_gen = {}
    x_gen = {}
    edge_list_gen = {}
    g_gen_pyg = {}
    for clt1 in range(args.clients):
        g_gen_pyg[clt1] = []
        edge_list_gen[clt1] = []
        x_gen[clt1] = []
        y_gen[clt1] = []
        num_gen = 0
        for clt2 in range(args.clients):
            if clt2 != clt1:
                for jclass in range(num_data_classes):
                    edge_list_gen[clt1].append(list(edge_list_gen_dic[jclass][clt2]))
                    y_gen[clt1].append([jclass] * num_gen_dic[jclass][clt2])
                    x_gen[clt1].append(list(x_list_dic[jclass][clt2]))
                    num_gen += num_gen_dic[jclass][clt2]
        edge_list_gen[clt1] = list(itertools.chain.from_iterable(edge_list_gen[clt1]))
        y_gen[clt1] = list(itertools.chain.from_iterable(y_gen[clt1]))
        x_gen[clt1] = list(itertools.chain.from_iterable(x_gen[clt1]))
        idxs = list(range(num_gen))
        random.shuffle(idxs)
        tmp_edge = []
        tmp_y = []
        tmp_x = []
        tmp_g = []
        for idx in idxs:
            tmp_edge.append(torch.tensor(edge_list_gen[clt1][idx]))
            tmp_y.append(torch.tensor(y_gen[clt1][idx]))
            tmp_x.append(torch.tensor(x_gen[clt1][idx]))
            edge_index = torch.tensor(edge_list_gen[clt1][idx], dtype=torch.long)
            x = torch.tensor(x_gen[clt1][idx], dtype=torch.float)
            y = torch.tensor(y_gen[clt1][idx], dtype=torch.long)
            pyg_graph = Data(x=x, edge_index=edge_index, y=y)
            tmp_g.append(pyg_graph)
        edge_list_gen[clt1] = tmp_edge
        y_gen[clt1] = tmp_y
        x_gen[clt1] = tmp_x
        g_gen_pyg[clt1] = tmp_g

    torch.manual_seed(12345)
    dataset = dataset.shuffle()
    print('!!!', len(dataset))

    # Load pre-extracted (features, labels, edges) for train and test and
    # sanity-check that every edge endpoint has a feature row.
    data_dir = '/data/'
    file_name = args.dataset.lower() + '-train_feats_label_edge_list'
    file_path = os.path.join(data_dir, file_name)
    with open(file_path, 'rb') as f:
        feats_list_train0, label_list_train0, edge_list_train0 = pickle.load(f)
    for i in range(np.shape(label_list_train0)[0]):
        if np.max(edge_list_train0[i]) >= np.shape(feats_list_train0[i])[0]:
            print('error1')
    feats_list_train = [torch.tensor(ft_train) for ft_train in feats_list_train0]
    label_list_train = [torch.tensor(lb_train) for lb_train in label_list_train0]
    edge_list_train = [torch.tensor(eg_train) for eg_train in edge_list_train0]
    print('***', len(label_list_train))

    file_name = args.dataset.lower() + '-test_feats_label_edge_list'
    file_path = os.path.join(data_dir, file_name)
    with open(file_path, 'rb') as f:
        feats_list_test0, label_list_test0, edge_list_test0 = pickle.load(f)
    for i in range(np.shape(label_list_test0)[0]):
        if np.max(edge_list_test0[i]) >= np.shape(feats_list_test0[i])[0]:
            print('error2')
    feats_list_test = [torch.tensor(ft_test) for ft_test in feats_list_test0]
    label_list_test = [torch.tensor(lb_test) for lb_test in label_list_test0]
    edge_list_test = [torch.tensor(eg_test) for eg_test in edge_list_test0]

    # Build the shared global test set. For MUTAG only a prefix of the test
    # list is used (its length derived from the train split size).
    graph_global_test = []
    print(len(label_list_test))
    if args.data_group == 'MUTAG':
        for i in range(len(label_list_test[int(len(dataset) * args.split):])):
            edge_index = torch.tensor(edge_list_test[i], dtype=torch.long)
            x = torch.tensor(feats_list_test[i], dtype=torch.float)
            y = torch.tensor(label_list_test[i], dtype=torch.long)
            pyg_graph = Data(x=x, edge_index=edge_index, y=y)
            graph_global_test.append(pyg_graph)
        print(len(graph_global_test))
    else:
        for i in range(len(label_list_test)):
            edge_index = torch.tensor(edge_list_test[i], dtype=torch.long)
            x = torch.tensor(feats_list_test[i], dtype=torch.float)
            y = torch.tensor(label_list_test[i], dtype=torch.long)
            pyg_graph = Data(x=x, edge_index=edge_index, y=y)
            graph_global_test.append(pyg_graph)
    num_node_features = graph_global_test[0].num_node_features

    # Train dataset: build PyG graphs, then split contiguous 70/20/10 chunks
    # per client and append that client's generated graphs to its train set.
    graphs_train = []
    for jj in range(len(label_list_train)):
        edge_index = torch.tensor(edge_list_train[jj], dtype=torch.long)
        x = torch.tensor(feats_list_train[jj], dtype=torch.float)
        y = torch.tensor(label_list_train[jj], dtype=torch.long)
        pyg_graph = Data(x=x, edge_index=edge_index, y=y)
        graphs_train.append(pyg_graph)

    startup = 0
    division = int(len(label_list_train) * args.split * 0.7 / args.clients)
    division_val = int(len(label_list_train) * args.split * 0.2 / args.clients)
    division_test = int(len(label_list_train) * args.split * 0.1 / args.clients)
    splitedData = {}
    df = pd.DataFrame()
    # Deep copy so the per-client slices are independent of later mutation of
    # `graphs_train` inside the loop below.
    graphs_train_ = cp.deepcopy(graphs_train)
    for i in range(args.clients):
        ds = f'{i}-{args.data_group}'
        client_data_train = graphs_train_[startup:division + startup]
        client_data_val = graphs_train_[division + startup:division + startup + division_val]
        client_data_test = graphs_train_[division + startup + division_val:
                                         division + startup + division_val + division_test]
        startup = division + startup + division_val + division_test
        # FIX: original used np.shape(x_gen[i])[0], which fails on a ragged
        # list of variable-shape tensors in modern NumPy; len() is equivalent.
        for j in range(len(x_gen[i])):
            client_data_train.append(g_gen_pyg[i][j])
        graphs_train = client_data_train
        graphs_val = client_data_val
        graphs_test = client_data_test
        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)
        # NOTE(review): the global test set is re-encoded on every client
        # iteration, as in the original — confirm init_structure_encoding is
        # idempotent, otherwise this should be hoisted out of the loop.
        graph_global_test = init_structure_encoding(args, gs=graph_global_test, type_init=args.type_init)
        ds_train = graphs_train
        ds_val = graphs_val
        ds_test = graphs_test
        dataloader_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True)
        dataloader_val = DataLoader(ds_val, batch_size=args.batch_size, shuffle=True)
        dataloader_test = DataLoader(ds_test, batch_size=len(ds_test), shuffle=True)
        num_graph_labels = get_numGraphLabels(ds_train)
        splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                           num_node_features,
                           num_graph_labels,
                           len(ds_train))
        df = get_stats(df, ds, ds_train, graphs_val=ds_val, graphs_test=ds_test)

    # Single full-batch loader over the global test set.
    dataloader_global_test = ({'test': DataLoader(graph_global_test,
                                                  batch_size=len(graph_global_test),
                                                  shuffle=True)})
    return splitedData, df, dataloader_global_test