import random
from random import choices
import numpy as np
import pandas as pd
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import OneHotDegree
from models_ import GIN, serverGIN, GIN_dc, serverGIN_dc
# from server import Server
from server_ import Server
from client_ours import Client_GC
from utils import get_maxDegree, get_stats, split_data, get_numGraphLabels, init_structure_encoding
from scipy.special import rel_entr
import scipy
from torch_geometric.utils import erdos_renyi_graph, degree
import itertools
from torch_geometric.utils.convert import from_networkx
import os
import pickle
import networkx as nx
from torch_geometric.data import Data
import copy as cp
def data_oneDS_ours(args,dataset,ft_dim):
    """Split one graph dataset across federated clients and mix in generated graphs.

    Builds, for each client, train/val/test DataLoaders from a pickled
    train split, augments each client's training set with synthetic graphs
    sampled for the *other* clients (loaded from per-class pickle
    directories chosen by ``args.alg`` / ``args.data_group``), and builds a
    shared global test loader.

    Parameters
    ----------
    args : namespace-like object; fields read here: split, clients, alg,
        data_group, seed, dataset, batch_size, type_init.
    dataset : dataset object supporting ``len()`` and ``.shuffle()``
        (e.g. a torch_geometric ``TUDataset``); only its length and a
        shuffle are used directly — graph contents come from pickles.
    ft_dim : int, feature dimension for the random 0/1 node features
        assigned to generated graphs.

    Returns
    -------
    (splitedData, df, dataloader_global_test) where
    splitedData[client_name] = ({'train','val','test' dataloaders},
    num_node_features, num_graph_labels, train_size), df is a stats
    DataFrame from ``get_stats``, and dataloader_global_test wraps the
    global test graphs in a single batch.
    """
    # print(len(dataset))
    # Per-client share of the dataset, then further divided by 4: this is the
    # cap on how many generated graphs are taken per (class, source-client).
    division = int(len(dataset) * args.split / args.clients)
    # print(division)
    # division=int(division/4)
    division = int(division / 4)
    # print(division)
    # Resolve the directories holding pickled generated graphs. Layout:
    # graph_s_gen_dirs_list[class][source_client] -> directory of pickles.
    # NOTE(review): if args.alg is neither 'CeFGC*' nor 'CeFGC' (or
    # data_group is unrecognized), graph_s_gen_dirs_list / num_data_classes
    # are never bound and the loop below raises NameError — confirm callers
    # only pass these two algorithms.
    if args.alg == 'CeFGC*':
        if args.data_group=='MUTAG':
            res='./data/mutag_'
            graph_s_gen_dir0_0 = res+"/62-28-0-0-3/sample/sample_data_0_0/"
            graph_s_gen_dir0_1 = res+"/62-28-0-0-3/sample/sample_data_0_1/"
            graph_s_gen_dir1_0 = res+"/62-28-1-0-3/sample/sample_data_1_0/"
            graph_s_gen_dir1_1 = res+"/62-28-1-0-3/sample/sample_data_1_1/"
            graph_s_gen_dir2_0 = res+"/62-26-2-0-3/sample/sample_data_2_0/"
            graph_s_gen_dir2_1 = res+"/62-26-2-0-3/sample/sample_data_2_1/"
            # dirsK = class-K sample dirs, one per source client (3 clients).
            graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
            graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
            num_data_classes=2
            graph_s_gen_dirs_list=[graph_s_gen_dirs0,graph_s_gen_dirs1]
        elif args.data_group=='IMDB-BINARY':
            res = './data/imdb-binary_'
            graph_s_gen_dir0_0 = res+"/266-72-0-0-3/sample/sample_data_0_0/"
            graph_s_gen_dir0_1 = res+"/266-72-0-0-3/sample/sample_data_0_1/"
            graph_s_gen_dir1_0 = res+"/266-65-1-0-3/sample/sample_data_1_0/"
            graph_s_gen_dir1_1 = res+"/266-65-1-0-3/sample/sample_data_1_1/"
            graph_s_gen_dir2_0 = res+"/266-56-2-0-3/sample/sample_data_2_0/"
            graph_s_gen_dir2_1 = res+"/266-56-2-0-3/sample/sample_data_2_1/"
            graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
            graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
            num_data_classes = 2
            graph_s_gen_dirs_list=[graph_s_gen_dirs0,graph_s_gen_dirs1]
    elif args.alg == 'CeFGC':
        if args.data_group == 'MUTAG':
            res='./data/mutag'
            graph_s_gen_dir0_0 = res+"/62-28-0-0/sample/sample_data/"
            graph_s_gen_dir0_1 = res+"/62-21-0-1/sample/sample_data/"
            graph_s_gen_dir1_0 = res+"/62-28-1-0/sample/sample_data/"
            graph_s_gen_dir1_1 = res+"/62-24-1-1/sample/sample_data/"
            graph_s_gen_dir2_0 = res+"/62-26-2-0/sample/sample_data/"
            graph_s_gen_dir2_1 = res+"/62-22-2-1/sample/sample_data/"
            graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
            graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
            num_data_classes = 2
            graph_s_gen_dirs_list = [graph_s_gen_dirs0, graph_s_gen_dirs1]
        elif args.data_group == 'IMDB-BINARY':
            res = './data/imdb-binary'
            # NOTE(review): unlike every other branch, these paths have no '/'
            # between res and the subdirectory, producing e.g.
            # './data/imdb-binary266-49-0-0-3/...'. Verify against the actual
            # on-disk layout — this may be a path-construction bug.
            graph_s_gen_dir0_0 = res+"266-49-0-0-3/sample/sample_data/"
            graph_s_gen_dir0_1 = res+"266-72-0-1-3/sample/sample_data/"
            graph_s_gen_dir1_0 = res+"266-59-1-0-3/sample/sample_data/"
            graph_s_gen_dir1_1 = res+"266-65-1-1-3/sample/sample_data/"
            graph_s_gen_dir2_0 = res+"266-55-2-0-3/sample/sample_data/"
            graph_s_gen_dir2_1 = res+"266-56-2-1-3/sample/sample_data/"
            graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
            graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
            num_data_classes = 2
            graph_s_gen_dirs_list = [graph_s_gen_dirs0, graph_s_gen_dirs1]
    # elif args.alg == 'CeFGC*+':
    #     if args.data_group == 'MUTAG':
    #         res = './data/mutag'
    #         graph_s_gen_dir0_0 = res + "/62-28-0-0-3/sample/sample_data_0_0/"
    #         graph_s_gen_dir0_1 = res + "/62-28-0-0-3/sample/sample_data_0_1/"
    #         graph_s_gen_dir1_0 = res + "/62-28-1-0-3/sample/sample_data_1_0/"
    #         graph_s_gen_dir1_1 = res + "/62-28-1-0-3/sample/sample_data_1_1/"
    #         graph_s_gen_dir2_0 = res + "/62-26-2-0-3/sample/sample_data_2_0/"
    #         graph_s_gen_dir2_1 = res + "/62-26-2-0-3/sample/sample_data_2_1/"
    #
    #         graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
    #         graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
    #         num_data_classes = 2
    #         graph_s_gen_dirs_list = [graph_s_gen_dirs0, graph_s_gen_dirs1]
    #
    #
    #     elif args.data_group == 'IMDB-BINARY':
    #         res = './data/imdb-binary'
    #         graph_s_gen_dir0_0 = res + "/266-72-0-0-3/sample/sample_data_0_0/"
    #         graph_s_gen_dir0_1 = res + "/266-72-0-0-3/sample/sample_data_0_1/"
    #         graph_s_gen_dir1_0 = res + "/266-65-1-0-3/sample/sample_data_1_0/"
    #         graph_s_gen_dir1_1 = res + "/266-65-1-0-3/sample/sample_data_1_1/"
    #         graph_s_gen_dir2_0 = res + "/266-56-2-0-3/sample/sample_data_2_0/"
    #         graph_s_gen_dir2_1 = res + "/266-56-2-0-3/sample/sample_data_2_1/"
    #         graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
    #         graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
    #         num_data_classes = 2
    #         graph_s_gen_dirs_list = [graph_s_gen_dirs0, graph_s_gen_dirs1]
    #
    # elif args.alg == 'CeFGC+':
    #     if args.data_group == 'MUTAG':
    #         res = './data/mutag'
    #         graph_s_gen_dir0_0 = res + "/62-28-0-0/sample/sample_data/"
    #         graph_s_gen_dir0_1 = res + "/62-21-0-1/sample/sample_data/"
    #         graph_s_gen_dir1_0 = res + "/62-28-1-0/sample/sample_data/"
    #         graph_s_gen_dir1_1 = res + "/62-24-1-1/sample/sample_data/"
    #         graph_s_gen_dir2_0 = res + "/62-26-2-0/sample/sample_data/"
    #         graph_s_gen_dir2_1 = res + "/62-22-2-1/sample/sample_data/"
    #
    #         graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
    #         graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
    #         num_data_classes = 2
    #         graph_s_gen_dirs_list = [graph_s_gen_dirs0, graph_s_gen_dirs1]
    #
    #
    #     elif args.data_group == 'IMDB-BINARY':
    #         res = './data/imdb-binary'
    #         graph_s_gen_dir0_0 = res + "266-49-0-0-3/sample/sample_data/"
    #         graph_s_gen_dir0_1 = res + "266-72-0-1-3/sample/sample_data/"
    #         graph_s_gen_dir1_0 = res + "266-59-1-0-3/sample/sample_data/"
    #         graph_s_gen_dir1_1 = res + "266-65-1-1-3/sample/sample_data/"
    #         graph_s_gen_dir2_0 = res + "266-55-2-0-3/sample/sample_data/"
    #         graph_s_gen_dir2_1 = res + "266-56-2-1-3/sample/sample_data/"
    #         graph_s_gen_dirs0 = [graph_s_gen_dir0_0, graph_s_gen_dir1_0, graph_s_gen_dir2_0]
    #         graph_s_gen_dirs1 = [graph_s_gen_dir0_1, graph_s_gen_dir1_1, graph_s_gen_dir2_1]
    #         num_data_classes = 2
    #         graph_s_gen_dirs_list = [graph_s_gen_dirs0, graph_s_gen_dirs1]
    # Per-class accumulators, each keyed by class index; each value is a list
    # with one entry per loaded pickle file (expected: one per source client):
    #   graph_s_gen_list_dic : raw generated graphs (networkx-style objects)
    #   edge_list_gen_dic    : 2xE numpy edge arrays per graph
    #   num_gen_dic          : number of graphs kept (capped by `division`)
    #   x_list_dic           : random binary node-feature matrices per graph
    #   nd_gen_dic           : node counts per graph
    graph_s_gen_list_dic={}
    edge_list_gen_dic={}
    num_gen_dic={}
    x_list_dic={}
    nd_gen_dic={}
    for iclass in range(num_data_classes):
        # graph_s_gen_list0 = []
        graph_s_gen_list_dic[iclass] = []
        # graph_s_gen_list1 = []
        # edge_list_gen0=[]
        # edge_list_gen1 = []
        edge_list_gen_dic[iclass] = []
        num_gen_dic[iclass]=[]
        # num_gen0 = []
        # num_gen1 = []
        nd_gen_dic[iclass] = []
        x_list_dic[iclass]=[]
        # x_list0 = []
        # x_list1 = []
        random.seed(args.seed)
        for graph_s_gen_dir0 in graph_s_gen_dirs_list[iclass]:
            graph_s_gen_dir = graph_s_gen_dir0
            files = os.listdir(graph_s_gen_dir)
            # print(files)
            # NOTE(review): the later clt2 indexing assumes exactly one
            # pickle file per directory (one accumulator entry per source
            # client) — confirm each sample_data dir holds a single file.
            for file in files:
                graph_s_gen_dir_f = os.path.join(graph_s_gen_dir, file)
                # NOTE(review): file handle is never closed; consider a
                # `with` block. latin1 encoding suggests py2-era pickles.
                f2 = open(graph_s_gen_dir_f, 'rb')
                graph_s_gen = pickle.load(f2, encoding='latin1')
                random.seed(args.seed)
                # graph_s_gen = random.sample(graph_s_gen, division)
                #
                # print(len(graph_s_gen))
                graph_s_gen_list_dic[iclass].append(graph_s_gen)
                g_edges_list = []
                nd_gen_num = []
                x_list = []
                gen_num = 0
                for g_ in graph_s_gen:
                    g_edges = g_.edges()
                    # Edge list as a 2 x E array (row 0 = sources, row 1 = targets).
                    g_edges = np.array([[nd1, nd2] for (nd1, nd2) in g_edges]).T
                    # Skip degenerate single-node graphs.
                    if g_.number_of_nodes() == 1:
                        continue
                    # If any edge endpoint index >= reported node count, trust
                    # the edge indices and size the graph as max index + 1.
                    if (np.max(g_edges)) > g_.number_of_nodes() or (np.max(g_edges)) == g_.number_of_nodes():
                        num_tmp = np.max(g_edges) + 1
                    else:
                        num_tmp = g_.number_of_nodes()
                    g_edges_list.append(g_edges)
                    nd_gen_num.append(num_tmp)
                    # Random 0/1 node features of width ft_dim (np.random is
                    # NOT seeded here — only `random` is; runs may differ).
                    x_list.append(np.random.randint(0, 2, (num_tmp, ft_dim)))
                    gen_num += 1
                    # Cap the number of generated graphs taken from this file.
                    if gen_num >= division:
                        # print('****')
                        break
                edge_list_gen_dic[iclass].append(g_edges_list)
                num_gen_dic[iclass].append(gen_num)
                nd_gen_dic[iclass].append(nd_gen_num)
                x_list_dic[iclass].append(x_list)
    # For each client clt1, collect the generated graphs of every OTHER
    # client clt2 (all classes), label them by class, shuffle, and build
    # PyG Data objects. g_gen_pyg[clt1] is what gets appended to clt1's
    # training set below.
    g_gen = {}
    y_gen = {}
    x_gen = {}
    edge_list_gen = {}
    g_gen_pyg = {}
    for clt1 in range(args.clients):
        g_gen[clt1] = []
        g_gen_pyg[clt1] = []
        edge_list_gen[clt1] = []
        x_gen[clt1] = []
        y_gen[clt1] = []
        num_gen = 0
        for clt2 in range(args.clients):
            if clt2 != clt1:
                for jclass in range(num_data_classes):
                    g_gen[clt1].append(list(graph_s_gen_list_dic[jclass][clt2]))
                    # g_gen[clt1].append(list(graph_s_gen_list1[clt2]))
                    edge_list_gen[clt1].append(list(edge_list_gen_dic[jclass][clt2]))
                    # edge_list_gen[clt1].append(list(edge_list_gen1[clt2]))
                    # Class label repeated once per kept graph from (jclass, clt2).
                    y_gen[clt1].append([jclass] * num_gen_dic[jclass][clt2])
                    # y_gen[clt1].append([1] * num_gen1[clt2])
                    x_gen[clt1].append(list(x_list_dic[jclass][clt2]))
                    # x_gen[clt1].append(list(x_list1[clt2]))
                    num_gen += num_gen_dic[jclass][clt2]
                    # num_gen += num_gen1[clt2]
        # print(list(itertools.chain.from_iterable(edge_list_gen[clt1])))
        # Flatten the per-(client, class) sublists into one flat list each.
        edge_list_gen[clt1] = list(itertools.chain.from_iterable(edge_list_gen[clt1]))
        y_gen[clt1] = list(itertools.chain.from_iterable(y_gen[clt1]))
        x_gen[clt1] = list(itertools.chain.from_iterable(x_gen[clt1]))
        # Shuffle graph order (uses `random`, seeded with args.seed above).
        idxs = list(range(num_gen))
        random.shuffle(idxs)
        tmp_edge = []
        tmp_y = []
        tmp_x = []
        tmp_g=[]
        # print(np.shape(np.array(y_gen[clt1])),np.shape(np.array(x_gen[clt1])))
        # print(edge_list_gen[clt1][0])
        for idx in idxs:
            tmp_edge.append(torch.tensor(edge_list_gen[clt1][idx]))
            tmp_y.append(torch.tensor(y_gen[clt1][idx]))
            tmp_x.append(torch.tensor(x_gen[clt1][idx]))
            edge_index = torch.tensor(edge_list_gen[clt1][idx], dtype=torch.long)
            x = torch.tensor(x_gen[clt1][idx], dtype=torch.float)
            y = torch.tensor(y_gen[clt1][idx], dtype=torch.long)
            pyg_graph = Data(x=x, edge_index=edge_index, y=y)
            tmp_g.append(pyg_graph)
        # Replace the per-client lists with their shuffled tensor versions.
        edge_list_gen[clt1] = tmp_edge
        y_gen[clt1] = tmp_y
        # print(y_gen[clt1])
        x_gen[clt1] = tmp_x
        # print('***',x_gen[clt1])
        # exit()
        g_gen_pyg[clt1]=tmp_g
    torch.manual_seed(12345)
    dataset = dataset.shuffle()
    print('!!!', len(dataset))
    # NOTE(review): absolute '/data/' here, while the generated-graph roots
    # above are relative './data/...' — confirm this is intentional.
    data_dir = '/data/'
    # Load the pre-pickled train split: (features, labels, edge lists),
    # one entry per graph.
    file_name = args.dataset.lower() + '-train_feats_label_edge_list'
    file_path = os.path.join(data_dir, file_name)
    # feature_set = set()
    with open(file_path, 'rb') as f:
        feats_list_train0, label_list_train0, edge_list_train0 = pickle.load(f)
    # Sanity check: every edge endpoint must index a valid feature row.
    for i in range(np.shape(label_list_train0)[0]):
        if np.max(edge_list_train0[i]) >= np.shape(feats_list_train0[i])[0]:
            print('error1')
    feats_list_train = [torch.tensor(ft_train) for ft_train in feats_list_train0]
    label_list_train = [torch.tensor(lb_train) for lb_train in label_list_train0]
    edge_list_train = [torch.tensor(eg_train) for eg_train in edge_list_train0]
    print('***', len(label_list_train))
    # Load the pre-pickled test split the same way.
    file_name = args.dataset.lower() + '-test_feats_label_edge_list'
    file_path = os.path.join(data_dir, file_name)
    # feature_set = set()
    with open(file_path, 'rb') as f:
        feats_list_test0, label_list_test0, edge_list_test0 = pickle.load(f)
    for i in range(np.shape(label_list_test0)[0]):
        if np.max(edge_list_test0[i]) >= np.shape(feats_list_test0[i])[0]:
            print('error2')
    feats_list_test = [torch.tensor(ft_test) for ft_test in feats_list_test0]
    label_list_test = [torch.tensor(lb_test) for lb_test in label_list_test0]
    edge_list_test = [torch.tensor(eg_test) for eg_test in edge_list_test0]
    # Build the shared global test set as PyG Data objects.
    graph_global_test=[]
    print(len(label_list_test))
    if args.data_group == 'MUTAG':
        # NOTE(review): the range length is taken from a *tail slice* of
        # label_list_test but indexing starts at 0 — so this keeps the FIRST
        # len(tail) test graphs, not the tail itself. Verify this truncation
        # is the intended MUTAG-specific behavior.
        for i in range(len(label_list_test[int(len(dataset) * args.split):])):
            edge_index = torch.tensor(edge_list_test[i], dtype=torch.long)
            x = torch.tensor(feats_list_test[i], dtype=torch.float)
            y = torch.tensor(label_list_test[i], dtype=torch.long)
            pyg_graph = Data(x=x, edge_index=edge_index,y=y)
            graph_global_test.append(pyg_graph)
        print(len(graph_global_test))
    else:
        for i in range(len(label_list_test)):
            edge_index = torch.tensor(edge_list_test[i], dtype=torch.long)
            x = torch.tensor(feats_list_test[i], dtype=torch.float)
            y = torch.tensor(label_list_test[i], dtype=torch.long)
            pyg_graph = Data(x=x, edge_index=edge_index, y=y)
            graph_global_test.append(pyg_graph)
    # print(len(graph_global_test))
    num_node_features =graph_global_test[0].num_node_features
    # Train Dataset split to Clients
    graphs_train = []
    for jj in range(len(label_list_train)):
        edge_index = torch.tensor(edge_list_train[jj], dtype=torch.long)
        x = torch.tensor(feats_list_train[jj], dtype=torch.float)
        y = torch.tensor(label_list_train[jj], dtype=torch.long)
        pyg_graph = Data(x=x, edge_index=edge_index,y=y)
        # graph_global_test.append(pyg_graph)
        graphs_train.append(pyg_graph)
    # exit()
    # Per-client 70/20/10 split of the client's share of the train pickle.
    startup = 0
    Client_list = []
    division = int(len(label_list_train) * args.split*0.7 / args.clients)
    # print(division)
    division_val = int(len(label_list_train) * args.split*0.2/ args.clients)
    division_test = int(len(label_list_train) * args.split*0.1/ args.clients)
    splitedData = {}
    df = pd.DataFrame()
    # Deep copy so client slices are independent of the list that
    # `graphs_train` is rebound to inside the loop below.
    graphs_train_=cp.deepcopy(graphs_train)
    for i in range(args.clients):
        ds = f'{i}-{args.data_group}'
        # Consecutive, non-overlapping slices per client.
        client_data_train=graphs_train_[startup:division + startup]
        client_data_val = graphs_train_[division + startup:division + startup + division_val]
        client_data_test = graphs_train_[division + startup+division_val:division + startup+division_val+division_test]
        startup=division + startup+division_val+division_test
        # Augment this client's training data with the generated graphs
        # built from the OTHER clients' samples.
        for j in range(np.shape(x_gen[i])[0]):
            client_data_train.append(g_gen_pyg[i][j])
            # client_data_x.append(x_gen[i][j])
            # client_data_y.append(y_gen[i][j])
            # client_data_edge.append(edge_list_gen[i][j])
        # NOTE(review): rebinds `graphs_train` (the full list) to this
        # client's subset; harmless here because slicing uses graphs_train_,
        # but confusing — a fresh local name would be clearer.
        graphs_train=client_data_train
        graphs_val=client_data_val
        graphs_test=client_data_test
        # graph_global_test=graph_global_test
        # Attach structure encodings (project helper from utils).
        # NOTE(review): graph_global_test is re-encoded on EVERY client
        # iteration — presumably idempotent; confirm against
        # init_structure_encoding's behavior.
        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)
        graph_global_test=init_structure_encoding(args, gs=graph_global_test, type_init=args.type_init)
        # exit()
        # ds_val, ds_test = split_data(ds_vt, train=0.5, test=0.5, shuffle=True, seed=seed)
        ds_train=graphs_train
        ds_val= graphs_val
        ds_test=graphs_test
        dataloader_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True)
        dataloader_val = DataLoader(ds_val, batch_size=args.batch_size, shuffle=True)
        # Test loader is a single full batch.
        dataloader_test = DataLoader(ds_test, batch_size=len(ds_test), shuffle=True)
        num_graph_labels = get_numGraphLabels(ds_train)
        splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                           num_node_features, num_graph_labels, len(ds_train))
        df = get_stats(df, ds, ds_train, graphs_val=ds_val, graphs_test=ds_test)
    # ds_global_test=split_data(graph_global_test, test=2, shuffle=True, seed=seed)
    # print(graph_global_test)
    # Global test set as one full-size batch, shared by all clients.
    dataloader_global_test = ({'test': DataLoader(graph_global_test, batch_size=len(graph_global_test), shuffle=True)})
    return splitedData, df, dataloader_global_test