"""Data preparation and device setup for federated graph classification.

Helpers here (a) load TUDatasets, (b) chunk a single dataset across
clients or prepare one split per dataset, (c) attach structure encodings,
and (d) build the `Client_GC` / `Server` objects used for training.
"""
import itertools
import random
from random import choices

import numpy as np
import pandas as pd
import scipy
import scipy.stats  # explicit: `import scipy` alone does not guarantee the stats submodule
import torch
from scipy.special import rel_entr
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import OneHotDegree
from torch_geometric.utils import erdos_renyi_graph, degree

from models_ import GIN, serverGIN, GIN_dc, serverGIN_dc
# from server import Server
from server_ import Server
# from client import Client_GC
from client_ours import Client_GC
from utils import (get_maxDegree, get_stats, split_data, get_numGraphLabels,
                   init_structure_encoding)
from data_oneDS_ours import *
# from data_oneDS_ours_6clients import *
# from data_oneDS_ours_9clients import *
# from data_oneDS_ours_12clients import *
# from data_oneDS_ours_15clients import *

# Algorithms whose clients/server use the structure-encoding (dc) GIN variants.
_DC_ALGS = {'fedstar', 'ours1', 'ours2', 'agg_mean', 'agg_cluster2'}


def _randChunk(graphs, num_client, overlap, seed=None):
    """Partition `graphs` into `num_client` chunks.

    Non-overlapping mode: each client gets min(50, total/num_client) graphs,
    leftovers are assigned to random clients. Overlapping mode: each client
    samples 50-149 graphs with replacement.

    Returns a list of `num_client` graph lists.
    """
    random.seed(seed)
    np.random.seed(seed)
    totalNum = len(graphs)
    minSize = min(50, int(totalNum / num_client))
    graphs_chunks = []
    if not overlap:
        for i in range(num_client):
            graphs_chunks.append(graphs[i * minSize:(i + 1) * minSize])
        # Scatter the remainder uniformly at random over the clients.
        for g in graphs[num_client * minSize:]:
            idx_chunk = np.random.randint(low=0, high=num_client, size=1)[0]
            graphs_chunks[idx_chunk].append(g)
    else:
        sizes = np.random.randint(low=50, high=150, size=num_client)
        for s in sizes:
            graphs_chunks.append(choices(graphs, k=s))
    return graphs_chunks


def _randChunk_(graphs, num_client, overlap, seed=None):
    """Partition `graphs` into `num_client` chunks plus a global test set.

    Each client gets 80%/num_client of the data; the remaining ~20% forms
    `graph_global_test`.

    Returns (graphs_chunks, graph_global_test).

    BUGFIX: `graph_global_test` used to be assigned only in the
    non-overlapping branch, so `overlap=True` raised NameError at return;
    it is now initialized to an empty list.
    NOTE(review): in the non-overlapping branch the leftover graphs are
    both appended to random client chunks AND returned as the global test
    set, i.e. test graphs leak into client data — confirm this is intended.
    """
    random.seed(seed)
    np.random.seed(seed)
    totalNum = len(graphs)
    minSize = int(0.8 * totalNum / num_client)
    graphs_chunks = []
    graph_global_test = []
    if not overlap:
        for i in range(num_client):
            graphs_chunks.append(graphs[i * minSize:(i + 1) * minSize])
        for g in graphs[num_client * minSize:]:
            idx_chunk = np.random.randint(low=0, high=num_client, size=1)[0]
            graphs_chunks[idx_chunk].append(g)
        graph_global_test = graphs[num_client * minSize:]
    else:
        sizes = np.random.randint(low=50, high=150, size=num_client)
        for s in sizes:
            graphs_chunks.append(choices(graphs, k=s))
    return graphs_chunks, graph_global_test


def _load_tudataset(datapath, data):
    """Load a TUDataset, one-hot-encoding node degrees for the social
    datasets that ship without node features."""
    if data == "COLLAB":
        return TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(491, cat=False))
    if data == "IMDB-BINARY":
        return TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(135, cat=False))
    if data == "IMDB-MULTI":
        return TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(88, cat=False))
    if "Letter" in data:
        # Letter-* store continuous node attributes.
        return TUDataset(f"{datapath}/TUDataset", data, use_node_attr=True)
    return TUDataset(f"{datapath}/TUDataset", data)


def _group_datasets(group):
    """Return the dataset names for a multi-dataset experiment group."""
    assert group in ['chem', "biochem", 'biochemsn', "biosncv"]
    if group == 'chem':
        return ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"]
    if group == 'biochem':
        return ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1",  # small molecules
                "ENZYMES", "DD", "PROTEINS"]                               # bioinformatics
    if group == 'biochemsn':
        return ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1",  # small molecules
                "ENZYMES", "DD", "PROTEINS",                               # bioinformatics
                "COLLAB", "IMDB-BINARY", "IMDB-MULTI"]                     # social networks
    # biosncv
    return ["ENZYMES", "DD", "PROTEINS",                                   # bioinformatics
            "COLLAB", "IMDB-BINARY", "IMDB-MULTI",                         # social networks
            "Letter-high", "Letter-low", "Letter-med"]                     # computer vision


def prepareData_oneDS(datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Split one TUDataset over `num_client` clients (train/val/test per client).

    Returns (splitedData, df) where splitedData maps '{idx}-{data}' to
    ({'train','val','test'} dataloaders, num_node_features,
    num_graph_labels, train size) and df holds per-client statistics.
    """
    tudataset = _load_tudataset(datapath, data)
    if convert_x:
        # Replace node features with a one-hot encoding of node degree.
        maxdegree = get_maxDegree(tudataset)
        tudataset = TUDataset(f"{datapath}/TUDataset", data,
                              transform=OneHotDegree(maxdegree, cat=False))
    graphs = [x for x in tudataset]
    print("  **", data, len(graphs))

    graphs_chunks = _randChunk(graphs, num_client, overlap, seed=seed)
    splitedData = {}
    df = pd.DataFrame()
    num_node_features = graphs[0].num_node_features
    for idx, chunks in enumerate(graphs_chunks):
        ds = f'{idx}-{data}'
        # 80/10/10 train/val/test split per client.
        ds_train, ds_vt = split_data(chunks, train=0.8, test=0.2, shuffle=True, seed=seed)
        ds_val, ds_test = split_data(ds_vt, train=0.5, test=0.5, shuffle=True, seed=seed)
        dataloader_train = DataLoader(ds_train, batch_size=batchSize, shuffle=True)
        dataloader_val = DataLoader(ds_val, batch_size=batchSize, shuffle=True)
        dataloader_test = DataLoader(ds_test, batch_size=batchSize, shuffle=True)
        num_graph_labels = get_numGraphLabels(ds_train)
        splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                           num_node_features, num_graph_labels, len(ds_train))
        df = get_stats(df, ds, ds_train, graphs_val=ds_val, graphs_test=ds_test)
    return splitedData, df


def prepareData_multiDS(args, datapath, group='chem', batchSize=32, seed=None):
    """One client per dataset in `group`; adds structure encodings.

    Returns (splitedData, df) keyed by dataset name.
    """
    splitedData = {}
    df = pd.DataFrame()
    for data in _group_datasets(group):
        tudataset = _load_tudataset(datapath, data)
        graphs = [x for x in tudataset]
        print("  **", data, len(graphs))

        graphs_train, graphs_valtest = split_data(graphs, test=0.2, shuffle=True, seed=seed)
        graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

        # Attach structure encodings (e.g. degree / random-walk based).
        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

        dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
        dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
        dataloader_test = DataLoader(graphs_test, batch_size=batchSize, shuffle=True)
        num_node_features = graphs[0].num_node_features
        num_graph_labels = get_numGraphLabels(graphs_train)
        splitedData[data] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                             num_node_features, num_graph_labels, len(graphs_train))
        df = get_stats(df, data, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)
    return splitedData, df


def prepareData_multiDS_multi(args, datapath, group='chem', batchSize=32, nc_per_ds=1, seed=None):
    """Like `prepareData_multiDS` but with `nc_per_ds` clients per dataset.

    BUGFIX: the previous default `group='small'` always failed the group
    assertion; the default is now 'chem' (any previously working caller
    passed `group` explicitly, so this is backward-compatible).
    """
    splitedData = {}
    df = pd.DataFrame()
    for data in _group_datasets(group):
        tudataset = _load_tudataset(datapath, data)
        graphs = [x for x in tudataset]
        print("  **", data, len(graphs))
        num_node_features = graphs[0].num_node_features

        graphs_chunks = _randChunk(graphs, nc_per_ds, overlap=False, seed=seed)
        for idx, chunks in enumerate(graphs_chunks):
            ds = f'{idx}-{data}'
            graphs_train, graphs_valtest = split_data(chunks, train=0.8, test=0.2, shuffle=True, seed=seed)
            graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

            graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
            graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
            graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

            dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
            dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
            dataloader_test = DataLoader(graphs_test, batch_size=batchSize, shuffle=True)
            num_graph_labels = get_numGraphLabels(graphs_train)
            splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                               num_node_features, num_graph_labels, len(graphs_train))
            df = get_stats(df, ds, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)
    return splitedData, df


def js_diver(P, Q):
    """Jensen-Shannon divergence (base 2) between distributions P and Q.

    P and Q must be array-likes supporting elementwise `+` (e.g. numpy
    arrays). `scipy.stats.entropy` normalizes its inputs, so using M = P+Q
    is equivalent to the conventional mixture M = (P+Q)/2.
    """
    M = P + Q
    return 0.5 * scipy.stats.entropy(P, M, base=2) + 0.5 * scipy.stats.entropy(Q, M, base=2)


def setup_devices(splitedData, args):
    """Build one `Client_GC` per entry of `splitedData` plus the `Server`.

    Returns (clients, server, idx_clients) where idx_clients maps client
    index to dataset key.
    """
    idx_clients = {}
    clients = []
    for idx, ds in enumerate(splitedData.keys()):
        idx_clients[idx] = ds
        dataloaders, num_node_features, num_graph_labels, train_size = splitedData[ds]
        if args.alg == 'fedstar':
            cmodel_gc = GIN_dc(num_node_features, args.n_se, args.hidden,
                               num_graph_labels, args.nlayer, args.dropout)
        else:
            cmodel_gc = GIN(num_node_features, args.hidden, num_graph_labels,
                            args.nlayer, args.dropout)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, cmodel_gc.parameters()),
                                     lr=args.lr, weight_decay=args.weight_decay)
        clients.append(Client_GC(cmodel_gc, idx, ds, train_size, dataloaders, optimizer, args))
    if args.alg == 'fedstar':
        smodel = serverGIN_dc(n_se=args.n_se, nlayer=args.nlayer, nhid=args.hidden)
    else:
        smodel = serverGIN(nlayer=args.nlayer, nhid=args.hidden)
    server = Server(smodel, args.device)
    return clients, server, idx_clients


def setup_devices_(splitedData, global_test_data, args):
    """Variant of `setup_devices` whose server also holds a global test set.

    Structure-encoding (dc) models are used for every algorithm in
    `_DC_ALGS`; plain GIN otherwise.
    """
    idx_clients = {}
    clients = []
    for idx, ds in enumerate(splitedData.keys()):
        idx_clients[idx] = ds
        dataloaders, num_node_features, num_graph_labels, train_size = splitedData[ds]
        if args.alg in _DC_ALGS:
            cmodel_gc = GIN_dc(num_node_features, args.n_se, args.hidden,
                               num_graph_labels, args.nlayer, args.dropout)
        else:
            cmodel_gc = GIN(num_node_features, args.hidden, num_graph_labels,
                            args.nlayer, args.dropout)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, cmodel_gc.parameters()),
                                     lr=args.lr, weight_decay=args.weight_decay)
        clients.append(Client_GC(cmodel_gc, idx, ds, train_size, dataloaders, optimizer, args))
    # NOTE: num_node_features / num_graph_labels below come from the last
    # client processed; assumes all clients share one dataset's dimensions.
    if args.alg in _DC_ALGS:
        smodel = serverGIN_dc(num_node_features, n_se=args.n_se, nlayer=args.nlayer,
                              nclass=num_graph_labels, nhid=args.hidden, dropout=args.dropout)
    else:
        smodel = serverGIN(nlayer=args.nlayer, nhid=args.hidden)
    server = Server(smodel, global_test_data, args.device)
    return clients, server, idx_clients


def prepareData_oneDS_(args, datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Split one dataset over clients, with structure encodings and a
    held-out global test set.

    ENZYMES gets special handling: graphs are shuffled first and each
    client's val/test sets are aliased to its train set.

    Returns (splitedData, df, dataloader_global_test).
    """
    tudataset = _load_tudataset(datapath, data)
    if convert_x:
        maxdegree = get_maxDegree(tudataset)
        tudataset = TUDataset(f"{datapath}/TUDataset", data,
                              transform=OneHotDegree(maxdegree, cat=False))
    graphs = [x for x in tudataset]
    print("  **", data, len(graphs))

    if data == 'ENZYMES':
        random.shuffle(graphs)
    graphs_chunks, graph_global_test = _randChunk_(graphs, num_client, overlap, seed=seed)

    # Encode the global test set once, up front (it used to be re-encoded
    # redundantly inside the per-client loop).
    graph_global_test = init_structure_encoding(args, gs=graph_global_test, type_init=args.type_init)

    splitedData = {}
    df = pd.DataFrame()
    num_node_features = graphs[0].num_node_features
    for idx, chunks in enumerate(graphs_chunks):
        ds = f'{idx}-{data}'
        if data == 'ENZYMES':
            # val/test deliberately alias the train split for ENZYMES.
            graphs_train, _ = split_data(chunks, train=0.8, test=0.2, shuffle=True, seed=seed)
            graphs_val = graphs_train
            graphs_test = graphs_train
        else:
            graphs_train, graphs_valtest = split_data(chunks, test=0.2, shuffle=True, seed=seed)
            graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

        dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
        dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
        # Local test set is evaluated in a single full batch.
        dataloader_test = DataLoader(graphs_test, batch_size=len(graphs_test), shuffle=True)
        num_graph_labels = get_numGraphLabels(graphs_train)
        splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                           num_node_features, num_graph_labels, len(graphs_train))
        df = get_stats(df, ds, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)

    dataloader_global_test = {'test': DataLoader(graph_global_test,
                                                 batch_size=len(graph_global_test), shuffle=True)}
    return splitedData, df, dataloader_global_test


def prepareData_oneDS_ours_(args, datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Load one dataset, determine its node-feature dimension, and delegate
    the client split to `data_oneDS_ours`.

    Returns (splitedData, df, dataloader_global_test).

    BUGFIX: `ft_dim` used to stay unbound for COLLAB and for any dataset
    not explicitly listed, raising NameError; it now falls back to the
    loaded graphs' `num_node_features`.
    """
    ft_dim = None
    if data == "COLLAB":
        tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(491, cat=False))
    elif data == "IMDB-BINARY":
        dataset = TUDataset(f"{datapath}/TUDataset", data)
        maxdegree = get_maxDegree(dataset)
        tudataset = TUDataset(f"{datapath}/TUDataset", data,
                              transform=OneHotDegree(maxdegree, cat=False))
        ft_dim = maxdegree + 1
    elif data == "IMDB-MULTI":
        dataset = TUDataset(f"{datapath}/TUDataset", data)
        maxdegree = get_maxDegree(dataset)
        tudataset = TUDataset(f"{datapath}/TUDataset", data,
                              transform=OneHotDegree(maxdegree, cat=False))
        ft_dim = maxdegree + 1
    elif args.dataset.lower() == 'enzymes':
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        ft_dim = 3
    elif args.dataset.lower() == 'mutag':
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        ft_dim = 7
    elif args.dataset.lower() == 'proteins':
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        ft_dim = 3
    else:
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
    if convert_x:
        # NOTE(review): convert_x does not update a previously set ft_dim;
        # confirm callers never combine convert_x with the named datasets.
        maxdegree = get_maxDegree(tudataset)
        tudataset = TUDataset(f"{datapath}/TUDataset", data,
                              transform=OneHotDegree(maxdegree, cat=False))
    graphs = [x for x in tudataset]
    if ft_dim is None:
        # Fallback covers COLLAB and any dataset not special-cased above.
        ft_dim = graphs[0].num_node_features

    splitedData, df, dataloader_global_test = data_oneDS_ours(args, tudataset, ft_dim)
    return splitedData, df, dataloader_global_test


def prepareData_multiDS_(args, datapath, group='chem', batchSize=32, seed=None):
    """One client per dataset in `group`, with structure encodings.

    Behaves identically to `prepareData_multiDS`; kept as a separate entry
    point for callers that import it by this name.
    """
    splitedData = {}
    df = pd.DataFrame()
    for data in _group_datasets(group):
        tudataset = _load_tudataset(datapath, data)
        graphs = [x for x in tudataset]
        print("  **", data, len(graphs))

        graphs_train, graphs_valtest = split_data(graphs, test=0.2, shuffle=True, seed=seed)
        graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

        dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
        dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
        dataloader_test = DataLoader(graphs_test, batch_size=batchSize, shuffle=True)
        num_node_features = graphs[0].num_node_features
        num_graph_labels = get_numGraphLabels(graphs_train)
        splitedData[data] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                             num_node_features, num_graph_labels, len(graphs_train))
        df = get_stats(df, data, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)
    return splitedData, df