import random
from random import choices
import numpy as np
import pandas as pd
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import OneHotDegree
from models_ import GIN, serverGIN, GIN_dc, serverGIN_dc
# from server import Server
from server_ import Server
# from client import Client_GC
from client_ours import Client_GC
from utils import get_maxDegree, get_stats, split_data, get_numGraphLabels, init_structure_encoding
from scipy.special import rel_entr
import scipy
from torch_geometric.utils import erdos_renyi_graph, degree
import itertools
from data_oneDS_ours import *
# from data_oneDS_ours_6clients import *
# from data_oneDS_ours_9clients import *
# from data_oneDS_ours_12clients import *
# from data_oneDS_ours_15clients import *
def _randChunk(graphs, num_client, overlap, seed=None):
random.seed(seed)
np.random.seed(seed)
totalNum = len(graphs)
minSize = min(50, int(totalNum/num_client))
graphs_chunks = []
if not overlap:
for i in range(num_client):
graphs_chunks.append(graphs[i*minSize:(i+1)*minSize])
for g in graphs[num_client*minSize:]:
idx_chunk = np.random.randint(low=0, high=num_client, size=1)[0]
graphs_chunks[idx_chunk].append(g)
else:
sizes = np.random.randint(low=50, high=150, size=num_client)
for s in sizes:
graphs_chunks.append(choices(graphs, k=s))
return graphs_chunks
def _randChunk_(graphs, num_client, overlap, seed=None):
random.seed(seed)
np.random.seed(seed)
# print(set(list(graphs.y)))
totalNum = len(graphs)
# totalNum_=int(totalNum*0.9)
# minSize = min(50, int(totalNum/num_client))
minSize = int(0.8*totalNum/num_client)
graphs_chunks = []
if not overlap:
for i in range(num_client):
graphs_chunks.append(graphs[i*minSize:(i+1)*minSize])
for g in graphs[num_client*minSize:]:
idx_chunk = np.random.randint(low=0, high=num_client, size=1)[0]
graphs_chunks[idx_chunk].append(g)
graph_global_test=graphs[num_client*minSize:]
else:
sizes = np.random.randint(low=50, high=150, size=num_client)
for s in sizes:
graphs_chunks.append(choices(graphs, k=s))
return graphs_chunks,graph_global_test
def prepareData_oneDS(datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Split a single TU dataset across `num_client` clients and build loaders.

    Returns:
        splitedData: {client_key: (loaders dict, n_node_features, n_graph_labels, n_train)}
        df: per-client split statistics from get_stats().
    """
    root = f"{datapath}/TUDataset"
    # Social-network datasets carry no node features; substitute one-hot degree.
    degree_caps = {"COLLAB": 491, "IMDB-BINARY": 135, "IMDB-MULTI": 88}
    if data in degree_caps:
        tudataset = TUDataset(root, data, pre_transform=OneHotDegree(degree_caps[data], cat=False))
    else:
        tudataset = TUDataset(root, data)
        if convert_x:
            maxdegree = get_maxDegree(tudataset)
            tudataset = TUDataset(root, data, transform=OneHotDegree(maxdegree, cat=False))
    graphs = list(tudataset)
    print("  **", data, len(graphs))
    graphs_chunks = _randChunk(graphs, num_client, overlap, seed=seed)
    splitedData = {}
    df = pd.DataFrame()
    num_node_features = graphs[0].num_node_features
    for client_idx, chunk in enumerate(graphs_chunks):
        ds = f'{client_idx}-{data}'
        # 80/10/10 train/val/test split per client.
        train_set, valtest = split_data(chunk, train=0.8, test=0.2, shuffle=True, seed=seed)
        val_set, test_set = split_data(valtest, train=0.5, test=0.5, shuffle=True, seed=seed)
        loaders = {
            'train': DataLoader(train_set, batch_size=batchSize, shuffle=True),
            'val': DataLoader(val_set, batch_size=batchSize, shuffle=True),
            'test': DataLoader(test_set, batch_size=batchSize, shuffle=True),
        }
        num_graph_labels = get_numGraphLabels(train_set)
        splitedData[ds] = (loaders, num_node_features, num_graph_labels, len(train_set))
        df = get_stats(df, ds, train_set, graphs_val=val_set, graphs_test=test_set)
    return splitedData, df
def prepareData_multiDS(args, datapath, group='chem', batchSize=32, seed=None):
    """Build per-dataset train/val/test loaders for a multi-dataset federation.

    Each dataset in the chosen group becomes one client. Structure encodings
    (args.type_init) are attached to every split before loader construction.

    Returns:
        splitedData: {dataset: (loaders dict, n_node_features, n_graph_labels, n_train)}
        df: per-dataset split statistics from get_stats().
    """
    group_map = {
        'chem': ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"],
        'biochem': ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1",   # small molecules
                    "ENZYMES", "DD", "PROTEINS"],                               # bioinformatics
        'biochemsn': ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1",  # small molecules
                      "ENZYMES", "DD", "PROTEINS",                               # bioinformatics
                      "COLLAB", "IMDB-BINARY", "IMDB-MULTI"],                    # social networks
        'biosncv': ["ENZYMES", "DD", "PROTEINS",                                # bioinformatics
                    "COLLAB", "IMDB-BINARY", "IMDB-MULTI",                      # social networks
                    "Letter-high", "Letter-low", "Letter-med"],                 # computer vision
    }
    assert group in group_map
    splitedData = {}
    df = pd.DataFrame()
    root = f"{datapath}/TUDataset"
    # Social-network datasets carry no node features; substitute one-hot degree.
    degree_caps = {"COLLAB": 491, "IMDB-BINARY": 135, "IMDB-MULTI": 88}
    for data in group_map[group]:
        if data in degree_caps:
            tudataset = TUDataset(root, data, pre_transform=OneHotDegree(degree_caps[data], cat=False))
        elif "Letter" in data:
            tudataset = TUDataset(root, data, use_node_attr=True)
        else:
            tudataset = TUDataset(root, data)
        graphs = list(tudataset)
        print("  **", data, len(graphs))
        # 80/10/10 split, then attach structure encodings to each part.
        train_gs, valtest_gs = split_data(graphs, test=0.2, shuffle=True, seed=seed)
        val_gs, test_gs = split_data(valtest_gs, train=0.5, test=0.5, shuffle=True, seed=seed)
        train_gs = init_structure_encoding(args, gs=train_gs, type_init=args.type_init)
        val_gs = init_structure_encoding(args, gs=val_gs, type_init=args.type_init)
        test_gs = init_structure_encoding(args, gs=test_gs, type_init=args.type_init)
        loaders = {
            'train': DataLoader(train_gs, batch_size=batchSize, shuffle=True),
            'val': DataLoader(val_gs, batch_size=batchSize, shuffle=True),
            'test': DataLoader(test_gs, batch_size=batchSize, shuffle=True),
        }
        splitedData[data] = (loaders, graphs[0].num_node_features,
                             get_numGraphLabels(train_gs), len(train_gs))
        df = get_stats(df, data, train_gs, graphs_val=val_gs, graphs_test=test_gs)
    return splitedData, df
def prepareData_multiDS_multi(args, datapath, group='chem', batchSize=32, nc_per_ds=1, seed=None):
    """Split every dataset of a group across `nc_per_ds` clients each.

    Fix: the previous default `group='small'` was not in the supported set and
    always failed the assert below, so no caller could have relied on it;
    'chem' is now the default. Explicit callers are unaffected.

    Args:
        args: namespace providing type_init (structure-encoding scheme).
        datapath: root directory holding the TUDataset folder.
        group: one of 'chem', 'biochem', 'biochemsn', 'biosncv'.
        batchSize: mini-batch size for all loaders.
        nc_per_ds: number of clients carved out of each dataset.
        seed: RNG seed forwarded to chunking and splitting.

    Returns:
        splitedData: {client_key: (loaders dict, n_node_features, n_graph_labels, n_train)}
        df: per-client split statistics from get_stats().
    """
    assert group in ['chem', "biochem", 'biochemsn', "biosncv"]
    if group == 'chem':
        datasets = ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"]
    elif group == 'biochem':
        datasets = ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1",  # small molecules
                    "ENZYMES", "DD", "PROTEINS"]  # bioinformatics
    elif group == 'biochemsn':
        datasets = ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1",  # small molecules
                    "ENZYMES", "DD", "PROTEINS",  # bioinformatics
                    "COLLAB", "IMDB-BINARY", "IMDB-MULTI"]  # social networks
    elif group == 'biosncv':
        datasets = ["ENZYMES", "DD", "PROTEINS",  # bioinformatics
                    "COLLAB", "IMDB-BINARY", "IMDB-MULTI",  # social networks
                    "Letter-high", "Letter-low", "Letter-med"]  # computer vision
    splitedData = {}
    df = pd.DataFrame()
    for data in datasets:
        # Social-network datasets carry no node features; substitute one-hot degree.
        if data == "COLLAB":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(491, cat=False))
        elif data == "IMDB-BINARY":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(135, cat=False))
        elif data == "IMDB-MULTI":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(88, cat=False))
        elif "Letter" in data:
            tudataset = TUDataset(f"{datapath}/TUDataset", data, use_node_attr=True)
        else:
            tudataset = TUDataset(f"{datapath}/TUDataset", data)
        graphs = [x for x in tudataset]
        print("  **", data, len(graphs))
        num_node_features = graphs[0].num_node_features
        graphs_chunks = _randChunk(graphs, nc_per_ds, overlap=False, seed=seed)
        for idx, chunks in enumerate(graphs_chunks):
            ds = f'{idx}-{data}'
            # 80/10/10 split per client, then attach structure encodings.
            graphs_train, graphs_valtest = split_data(chunks, train=0.8, test=0.2, shuffle=True, seed=seed)
            graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)
            graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
            graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
            graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)
            dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
            dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
            dataloader_test = DataLoader(graphs_test, batch_size=batchSize, shuffle=True)
            num_graph_labels = get_numGraphLabels(graphs_train)
            splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                               num_node_features, num_graph_labels, len(graphs_train))
            df = get_stats(df, ds, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)
    return splitedData, df
def js_diver(P, Q):
    """Jensen-Shannon divergence between distributions P and Q, in bits.

    Args:
        P, Q: array-like non-negative weight vectors of equal length; they
            need not be normalised, since entropy() normalises internally.

    Returns:
        float in [0, 1] — the base-2 JS divergence.
    """
    # Fix: import the submodule explicitly. A bare `import scipy` (as at the
    # top of this file) does not guarantee `scipy.stats` is loaded on all
    # scipy versions.
    from scipy.stats import entropy
    # entropy() renormalises both arguments, so P+Q represents the same
    # distribution as the usual mixture (P+Q)/2.
    M = P + Q
    return 0.5 * entropy(P, M, base=2) + 0.5 * entropy(Q, M, base=2)
def setup_devices(splitedData, args):
    """Instantiate one Client_GC per data split plus the central Server.

    Uses the structure-decoupled GIN variant when args.alg == 'fedstar',
    the plain GIN otherwise.

    Returns:
        (clients, server, idx_clients) where idx_clients maps client index
        to its dataset key.
    """
    idx_clients = {}
    clients = []
    for client_id, ds in enumerate(splitedData):
        idx_clients[client_id] = ds
        dataloaders, num_node_features, num_graph_labels, train_size = splitedData[ds]
        if args.alg == 'fedstar':
            local_model = GIN_dc(num_node_features, args.n_se, args.hidden, num_graph_labels, args.nlayer, args.dropout)
        else:
            local_model = GIN(num_node_features, args.hidden, num_graph_labels, args.nlayer, args.dropout)
        trainable = (p for p in local_model.parameters() if p.requires_grad)
        optimizer = torch.optim.Adam(trainable, lr=args.lr, weight_decay=args.weight_decay)
        clients.append(Client_GC(local_model, client_id, ds, train_size, dataloaders, optimizer, args))
    if args.alg == 'fedstar':
        smodel = serverGIN_dc(n_se=args.n_se, nlayer=args.nlayer, nhid=args.hidden)
    else:
        smodel = serverGIN(nlayer=args.nlayer, nhid=args.hidden)
    return clients, Server(smodel, args.device), idx_clients
def setup_devices_(splitedData, global_test_data, args):
    """Instantiate clients plus a Server that also holds a global test set.

    Fix: the repeated `args.alg == ... or args.alg == ...` chains (used twice,
    and easy to let drift apart) are replaced by a single membership set.

    Returns:
        (clients, server, idx_clients) where idx_clients maps client index
        to its dataset key.
    """
    # Algorithms that use the structure-decoupled GIN variant.
    dc_algs = {'fedstar', 'ours1', 'ours2', 'agg_mean', 'agg_cluster2'}
    idx_clients = {}
    clients = []
    for idx, ds in enumerate(splitedData.keys()):
        idx_clients[idx] = ds
        dataloaders, num_node_features, num_graph_labels, train_size = splitedData[ds]
        if args.alg in dc_algs:
            cmodel_gc = GIN_dc(num_node_features, args.n_se, args.hidden, num_graph_labels, args.nlayer, args.dropout)
        else:
            cmodel_gc = GIN(num_node_features, args.hidden, num_graph_labels, args.nlayer, args.dropout)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, cmodel_gc.parameters()),
                                     lr=args.lr, weight_decay=args.weight_decay)
        clients.append(Client_GC(cmodel_gc, idx, ds, train_size, dataloaders, optimizer, args))
    # NOTE(review): the server model reuses num_node_features / num_graph_labels
    # from the LAST client iterated above — fine when all splits share one
    # dataset, but worth confirming for multi-dataset use.
    if args.alg in dc_algs:
        smodel = serverGIN_dc(num_node_features, n_se=args.n_se, nlayer=args.nlayer,
                              nclass=num_graph_labels, nhid=args.hidden, dropout=args.dropout)
    else:
        smodel = serverGIN(nlayer=args.nlayer, nhid=args.hidden)
    server = Server(smodel, global_test_data, args.device)
    return clients, server, idx_clients
def prepareData_oneDS_(args, datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Split one TU dataset across clients and build a shared global test loader.

    Fix: init_structure_encoding() was previously applied to the global test
    pool INSIDE the per-client loop, re-encoding the same graphs num_client
    times; it is now applied once, before the loop.

    Returns:
        splitedData: {client_key: (loaders dict, n_node_features, n_graph_labels, n_train)}
        df: per-client split statistics from get_stats().
        dataloader_global_test: {'test': DataLoader} over the global test pool
            (one full-size batch).
    """
    # Social-network datasets carry no node features; substitute one-hot degree.
    if data == "COLLAB":
        tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(491, cat=False))
    elif data == "IMDB-BINARY":
        tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(135, cat=False))
    elif data == "IMDB-MULTI":
        tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(88, cat=False))
    else:
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        if convert_x:
            maxdegree = get_maxDegree(tudataset)
            tudataset = TUDataset(f"{datapath}/TUDataset", data, transform=OneHotDegree(maxdegree, cat=False))
    graphs = [x for x in tudataset]
    print("  **", data, len(graphs))
    if data == 'ENZYMES':
        random.shuffle(graphs)
    graphs_chunks, graph_global_test = _randChunk_(graphs, num_client, overlap, seed=seed)
    # Encode the shared global test pool once, up front (see docstring fix note).
    graph_global_test = init_structure_encoding(args, gs=graph_global_test, type_init=args.type_init)
    splitedData = {}
    df = pd.DataFrame()
    num_node_features = graphs[0].num_node_features
    for idx, chunks in enumerate(graphs_chunks):
        ds = f'{idx}-{data}'
        if data == 'ENZYMES':
            # NOTE(review): for ENZYMES, val and test alias the train split
            # (no held-out client data) — presumably intentional; confirm.
            graphs_train, _ = split_data(chunks, train=0.8, test=0.2, shuffle=True, seed=seed)
            graphs_val = graphs_train
            graphs_test = graphs_train
        else:
            # 80/10/10 train/val/test split per client.
            graphs_train, graphs_valtest = split_data(chunks, test=0.2, shuffle=True, seed=seed)
            graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)
        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)
        dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
        dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
        # Single full-size batch so test metrics cover the whole split at once.
        dataloader_test = DataLoader(graphs_test, batch_size=len(graphs_test), shuffle=True)
        num_graph_labels = get_numGraphLabels(graphs_train)
        splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                           num_node_features, num_graph_labels, len(graphs_train))
        df = get_stats(df, ds, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)
    dataloader_global_test = ({'test': DataLoader(graph_global_test, batch_size=len(graph_global_test), shuffle=True)})
    return splitedData, df, dataloader_global_test
def prepareData_oneDS_ours_(args, datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Load one TU dataset and delegate client splitting to data_oneDS_ours().

    Fix: `ft_dim` (the node-feature dimensionality passed to data_oneDS_ours)
    was left undefined on the COLLAB branch and on the generic fallback
    branch, raising NameError. It is now set on every path. A duplicated dead
    `graphs = [...]` build was also removed.

    Returns:
        (splitedData, df, dataloader_global_test) as produced by data_oneDS_ours().
    """
    if data == "COLLAB":
        tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(491, cat=False))
        ft_dim = 491 + 1  # one-hot over degrees 0..491
    elif data == "IMDB-BINARY":
        dataset = TUDataset(f"{datapath}/TUDataset", data)
        maxdegree = get_maxDegree(dataset)
        tudataset = TUDataset(f"{datapath}/TUDataset", data, transform=OneHotDegree(maxdegree, cat=False))
        ft_dim = maxdegree + 1
    elif data == "IMDB-MULTI":
        dataset = TUDataset(f"{datapath}/TUDataset", data)
        maxdegree = get_maxDegree(dataset)
        tudataset = TUDataset(f"{datapath}/TUDataset", data, transform=OneHotDegree(maxdegree, cat=False))
        ft_dim = maxdegree + 1
    # NOTE(review): these branches test args.dataset while the ones above test
    # `data` — presumably the two always agree; confirm against the callers.
    elif args.dataset.lower() == 'enzymes':
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        ft_dim = 3
    elif args.dataset.lower() == 'mutag':
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        ft_dim = 7
    elif args.dataset.lower() == 'proteins':
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        ft_dim = 3
    else:
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        if convert_x:
            maxdegree = get_maxDegree(tudataset)
            tudataset = TUDataset(f"{datapath}/TUDataset", data, transform=OneHotDegree(maxdegree, cat=False))
        # Fix: previously undefined here — derive from the dataset itself.
        ft_dim = tudataset.num_node_features
    splitedData, df, dataloader_global_test = data_oneDS_ours(args, tudataset, ft_dim)
    return splitedData, df, dataloader_global_test
def prepareData_multiDS_(args, datapath, group='chem', batchSize=32, seed=None):
    """Per-dataset loaders for a multi-dataset federation.

    NOTE: this is functionally identical to prepareData_multiDS above (one
    client per dataset, 80/10/10 splits, structure encodings attached).

    Returns:
        splitedData: {dataset: (loaders dict, n_node_features, n_graph_labels, n_train)}
        df: per-dataset split statistics from get_stats().
    """
    assert group in ['chem', "biochem", 'biochemsn', "biosncv"]
    group_to_sets = {
        'chem': ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"],
        'biochem': ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1",   # small molecules
                    "ENZYMES", "DD", "PROTEINS"],                               # bioinformatics
        'biochemsn': ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1",  # small molecules
                      "ENZYMES", "DD", "PROTEINS",                               # bioinformatics
                      "COLLAB", "IMDB-BINARY", "IMDB-MULTI"],                    # social networks
        'biosncv': ["ENZYMES", "DD", "PROTEINS",                                # bioinformatics
                    "COLLAB", "IMDB-BINARY", "IMDB-MULTI",                      # social networks
                    "Letter-high", "Letter-low", "Letter-med"],                 # computer vision
    }
    splitedData = {}
    df = pd.DataFrame()
    for ds_name in group_to_sets[group]:
        root = f"{datapath}/TUDataset"
        # Social-network datasets carry no node features; substitute one-hot degree.
        if ds_name == "COLLAB":
            tudataset = TUDataset(root, ds_name, pre_transform=OneHotDegree(491, cat=False))
        elif ds_name == "IMDB-BINARY":
            tudataset = TUDataset(root, ds_name, pre_transform=OneHotDegree(135, cat=False))
        elif ds_name == "IMDB-MULTI":
            tudataset = TUDataset(root, ds_name, pre_transform=OneHotDegree(88, cat=False))
        elif "Letter" in ds_name:
            tudataset = TUDataset(root, ds_name, use_node_attr=True)
        else:
            tudataset = TUDataset(root, ds_name)
        all_graphs = [g for g in tudataset]
        print("  **", ds_name, len(all_graphs))
        tr, rest = split_data(all_graphs, test=0.2, shuffle=True, seed=seed)
        va, te = split_data(rest, train=0.5, test=0.5, shuffle=True, seed=seed)
        tr = init_structure_encoding(args, gs=tr, type_init=args.type_init)
        va = init_structure_encoding(args, gs=va, type_init=args.type_init)
        te = init_structure_encoding(args, gs=te, type_init=args.type_init)
        splitedData[ds_name] = (
            {'train': DataLoader(tr, batch_size=batchSize, shuffle=True),
             'val': DataLoader(va, batch_size=batchSize, shuffle=True),
             'test': DataLoader(te, batch_size=batchSize, shuffle=True)},
            all_graphs[0].num_node_features,
            get_numGraphLabels(tr),
            len(tr),
        )
        df = get_stats(df, ds_name, tr, graphs_val=va, graphs_test=te)
    return splitedData, df