# CeFGC/lib/setupGC.py — federated graph-classification data setup utilities.
import random
from random import choices
import numpy as np
import pandas as pd

import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import OneHotDegree

from models_ import GIN, serverGIN, GIN_dc, serverGIN_dc
# from server import Server
from server_ import Server
from client import Client_GC
from utils import get_maxDegree, get_stats, split_data, get_numGraphLabels, init_structure_encoding

from scipy.special import rel_entr
import scipy
from torch_geometric.utils import erdos_renyi_graph, degree
import itertools
import time
import os
import pickle
import networkx as nx
from torch_geometric.data import Data


import copy as cp
def _randChunk(graphs, num_client, overlap, seed=None):
    random.seed(seed)
    np.random.seed(seed)

    totalNum = len(graphs)
    minSize = min(50, int(totalNum/num_client))
    graphs_chunks = []
    if not overlap:
        for i in range(num_client):
            graphs_chunks.append(graphs[i*minSize:(i+1)*minSize])
        for g in graphs[num_client*minSize:]:
            idx_chunk = np.random.randint(low=0, high=num_client, size=1)[0]
            graphs_chunks[idx_chunk].append(g)
    else:
        sizes = np.random.randint(low=50, high=150, size=num_client)
        for s in sizes:
            graphs_chunks.append(choices(graphs, k=s))
    return graphs_chunks


def _randChunk_(graphs, num_client, overlap, seed=None):
    random.seed(seed)
    np.random.seed(seed)
    # print(set(list(graphs.y)))
    totalNum = len(graphs)
    # totalNum_=int(totalNum*0.9)
    minSize = min(50, int(totalNum/num_client))
    graphs_chunks = []
    if not overlap:
        for i in range(num_client):
            graphs_chunks.append(graphs[i*minSize:(i+1)*minSize])
        for g in graphs[num_client*minSize:]:
            idx_chunk = np.random.randint(low=0, high=num_client, size=1)[0]
            graphs_chunks[idx_chunk].append(g)

        graph_global_test=graphs[num_client*minSize:]

    else:
        sizes = np.random.randint(low=50, high=150, size=num_client)
        for s in sizes:
            graphs_chunks.append(choices(graphs, k=s))
    return graphs_chunks,graph_global_test



def _randChunk2_(graphs, num_client, overlap, seed=None):
    random.seed(seed)
    np.random.seed(seed)
    # print(set(list(graphs.y)))
    totalNum = len(graphs)
    # totalNum_=int(totalNum*0.9)
    # minSize = min(100, int(totalNum/num_client))
    minSize = min(100, int(totalNum/num_client))
    # minSize = int(totalNum*0.6/num_client)

    graphs_chunks = []
    if not overlap:
        for i in range(num_client):
            graphs_chunks.append(graphs[i*minSize:(i+1)*minSize])
        for g in graphs[num_client*minSize:]:
            idx_chunk = np.random.randint(low=0, high=num_client, size=1)[0]
            graphs_chunks[idx_chunk].append(g)

        graph_global_test=graphs[num_client*minSize:]

    else:
        sizes = np.random.randint(low=50, high=150, size=num_client)
        for s in sizes:
            graphs_chunks.append(choices(graphs, k=s))
    return graphs_chunks,graph_global_test


def prepareData_oneDS(datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Load one TU dataset and split it across `num_client` clients.

    Returns (splitedData, df): splitedData maps '<idx>-<data>' to
    (dataloaders dict, num_node_features, num_graph_labels, train size); df
    accumulates per-client split statistics via `get_stats`.
    """
    # Social-network datasets ship without node features; encode node degree
    # one-hot (dimension = max degree observed in each dataset).
    degree_dims = {"COLLAB": 491, "IMDB-BINARY": 135, "IMDB-MULTI": 88}
    if data in degree_dims:
        tudataset = TUDataset(f"{datapath}/TUDataset", data,
                              pre_transform=OneHotDegree(degree_dims[data], cat=False))
    else:
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        if convert_x:
            maxdegree = get_maxDegree(tudataset)
            tudataset = TUDataset(f"{datapath}/TUDataset", data,
                                  transform=OneHotDegree(maxdegree, cat=False))
    graphs = list(tudataset)
    print("  **", data, len(graphs))

    splitedData = {}
    df = pd.DataFrame()
    num_node_features = graphs[0].num_node_features
    for idx, chunk in enumerate(_randChunk(graphs, num_client, overlap, seed=seed)):
        ds = f'{idx}-{data}'
        # 80/10/10 train/val/test split per client.
        ds_train, ds_vt = split_data(chunk, train=0.8, test=0.2, shuffle=True, seed=seed)
        ds_val, ds_test = split_data(ds_vt, train=0.5, test=0.5, shuffle=True, seed=seed)
        loaders = {
            'train': DataLoader(ds_train, batch_size=batchSize, shuffle=True),
            'val': DataLoader(ds_val, batch_size=batchSize, shuffle=True),
            'test': DataLoader(ds_test, batch_size=batchSize, shuffle=True),
        }
        splitedData[ds] = (loaders, num_node_features,
                           get_numGraphLabels(ds_train), len(ds_train))
        df = get_stats(df, ds, ds_train, graphs_val=ds_val, graphs_test=ds_test)

    return splitedData, df

def prepareData_multiDS(args, datapath, group='chem', batchSize=32, seed=None):
    """Load one TU dataset per client for a cross-dataset federation.

    Each dataset in the chosen `group` becomes one client with an 80/10/10
    train/val/test split, structure encodings attached, and its own loaders.

    Returns (splitedData, df) keyed by dataset name.
    """
    assert group in ['chem', "biochem", 'biochemsn', "biosncv"]

    chem = ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"]  # small molecules
    bio = ["ENZYMES", "DD", "PROTEINS"]                                # bioinformatics
    sn = ["COLLAB", "IMDB-BINARY", "IMDB-MULTI"]                       # social networks
    cv = ["Letter-high", "Letter-low", "Letter-med"]                   # computer vision
    datasets = {
        'chem': chem,
        'biochem': chem + bio,
        'biochemsn': chem + bio + sn,
        'biosncv': bio + sn + cv,
    }[group]

    # Degree one-hot dimensions for featureless social-network datasets.
    degree_dims = {"COLLAB": 491, "IMDB-BINARY": 135, "IMDB-MULTI": 88}

    splitedData = {}
    df = pd.DataFrame()
    for data in datasets:
        if data in degree_dims:
            tudataset = TUDataset(f"{datapath}/TUDataset", data,
                                  pre_transform=OneHotDegree(degree_dims[data], cat=False))
        elif "Letter" in data:
            tudataset = TUDataset(f"{datapath}/TUDataset", data, use_node_attr=True)
        else:
            tudataset = TUDataset(f"{datapath}/TUDataset", data)

        graphs = list(tudataset)
        print("  **", data, len(graphs))

        # 80/10/10 split, then attach structure encodings for fedstar-style models.
        graphs_train, graphs_valtest = split_data(graphs, test=0.2, shuffle=True, seed=seed)
        graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

        loaders = {
            'train': DataLoader(graphs_train, batch_size=batchSize, shuffle=True),
            'val': DataLoader(graphs_val, batch_size=batchSize, shuffle=True),
            'test': DataLoader(graphs_test, batch_size=batchSize, shuffle=True),
        }
        splitedData[data] = (loaders, graphs[0].num_node_features,
                             get_numGraphLabels(graphs_train), len(graphs_train))
        df = get_stats(df, data, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)

    return splitedData, df

def prepareData_multiDS_multi(args, datapath, group='chem', batchSize=32, nc_per_ds=1, seed=None):
    """Load multiple TU datasets and split each into `nc_per_ds` client chunks.

    Every dataset in `group` is chunked with `_randChunk` and each chunk
    becomes one client with an 80/10/10 split, structure encodings, and its
    own loaders.

    Returns (splitedData, df) keyed by '<chunk idx>-<dataset>'.

    Fix: the default `group` was 'small', which is not an accepted value, so
    calling with the default always failed the assertion; it now defaults to
    'chem'.
    """
    assert group in ['chem', "biochem", 'biochemsn', "biosncv"]

    chem = ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"]  # small molecules
    bio = ["ENZYMES", "DD", "PROTEINS"]                                # bioinformatics
    sn = ["COLLAB", "IMDB-BINARY", "IMDB-MULTI"]                       # social networks
    cv = ["Letter-high", "Letter-low", "Letter-med"]                   # computer vision
    datasets = {
        'chem': chem,
        'biochem': chem + bio,
        'biochemsn': chem + bio + sn,
        'biosncv': bio + sn + cv,
    }[group]

    # Degree one-hot dimensions for featureless social-network datasets.
    degree_dims = {"COLLAB": 491, "IMDB-BINARY": 135, "IMDB-MULTI": 88}

    splitedData = {}
    df = pd.DataFrame()
    for data in datasets:
        if data in degree_dims:
            tudataset = TUDataset(f"{datapath}/TUDataset", data,
                                  pre_transform=OneHotDegree(degree_dims[data], cat=False))
        elif "Letter" in data:
            tudataset = TUDataset(f"{datapath}/TUDataset", data, use_node_attr=True)
        else:
            tudataset = TUDataset(f"{datapath}/TUDataset", data)

        graphs = list(tudataset)
        print("  **", data, len(graphs))
        num_node_features = graphs[0].num_node_features
        graphs_chunks = _randChunk(graphs, nc_per_ds, overlap=False, seed=seed)
        for idx, chunks in enumerate(graphs_chunks):
            ds = f'{idx}-{data}'
            graphs_train, graphs_valtest = split_data(chunks, train=0.8, test=0.2, shuffle=True, seed=seed)
            graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

            graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
            graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
            graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

            dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
            dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
            dataloader_test = DataLoader(graphs_test, batch_size=batchSize, shuffle=True)
            num_graph_labels = get_numGraphLabels(graphs_train)
            splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                               num_node_features, num_graph_labels, len(graphs_train))
            df = get_stats(df, ds, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)

    return splitedData, df

def js_diver(P, Q):
    """Return the Jensen-Shannon divergence (base 2) between P and Q.

    `scipy.stats.entropy` normalizes its inputs, so the unnormalized mixture
    M = P + Q is equivalent to the conventional M = (P + Q) / 2. The result
    lies in [0, 1] for base 2.

    Fix: the module only does `import scipy`, which does not guarantee the
    `scipy.stats` submodule is loaded; import `entropy` explicitly.
    """
    from scipy.stats import entropy
    M = P + Q
    return 0.5 * entropy(P, M, base=2) + 0.5 * entropy(Q, M, base=2)

def setup_devices(splitedData, args):
    """Instantiate one graph-classification client per split plus the server.

    Returns (clients, server, idx_clients) where idx_clients maps the client
    index to its dataset key in `splitedData`.
    """
    idx_clients = {}
    clients = []
    for idx, ds in enumerate(splitedData):
        idx_clients[idx] = ds
        dataloaders, num_node_features, num_graph_labels, train_size = splitedData[ds]
        # fedstar clients use the structure-encoding ("dc") GIN variant.
        if args.alg == 'fedstar':
            cmodel_gc = GIN_dc(num_node_features, args.n_se, args.hidden,
                               num_graph_labels, args.nlayer, args.dropout)
        else:
            cmodel_gc = GIN(num_node_features, args.hidden, num_graph_labels,
                            args.nlayer, args.dropout)
        trainable = (p for p in cmodel_gc.parameters() if p.requires_grad)
        optimizer = torch.optim.Adam(trainable, lr=args.lr, weight_decay=args.weight_decay)
        clients.append(Client_GC(cmodel_gc, idx, ds, train_size, dataloaders, optimizer, args))

    # The server model holds only shared layers (no feature/class dims here).
    if args.alg == 'fedstar':
        smodel = serverGIN_dc(n_se=args.n_se, nlayer=args.nlayer, nhid=args.hidden)
    else:
        smodel = serverGIN(nlayer=args.nlayer, nhid=args.hidden)
    return clients, Server(smodel, args.device), idx_clients


def setup_devices_(splitedData, global_test_data, args):
    """Build clients plus a server that also holds a global test set.

    Unlike `setup_devices`, the fedstar server model here is a full
    classifier (it receives feature and class dimensions) and the server is
    handed `global_test_data` for centralized evaluation.

    Returns (clients, server, idx_clients).
    """
    idx_clients = {}
    clients = []
    for idx, ds in enumerate(splitedData):
        idx_clients[idx] = ds
        dataloaders, num_node_features, num_graph_labels, train_size = splitedData[ds]
        # fedstar clients use the structure-encoding ("dc") GIN variant.
        if args.alg == 'fedstar':
            cmodel_gc = GIN_dc(num_node_features, args.n_se, args.hidden,
                               num_graph_labels, args.nlayer, args.dropout)
        else:
            cmodel_gc = GIN(num_node_features, args.hidden, num_graph_labels,
                            args.nlayer, args.dropout)
        trainable = (p for p in cmodel_gc.parameters() if p.requires_grad)
        optimizer = torch.optim.Adam(trainable, lr=args.lr, weight_decay=args.weight_decay)
        clients.append(Client_GC(cmodel_gc, idx, ds, train_size, dataloaders, optimizer, args))

    # NOTE(review): the server model's dims come from the *last* client's
    # split — confirm all clients share num_node_features / num_graph_labels.
    if args.alg == 'fedstar':
        smodel = serverGIN_dc(num_node_features, n_se=args.n_se, nlayer=args.nlayer,
                              nclass=num_graph_labels, nhid=args.hidden, dropout=args.dropout)
    else:
        smodel = serverGIN(nlayer=args.nlayer, nhid=args.hidden)
    return clients, Server(smodel, global_test_data, args.device), idx_clients

def prepareData_oneDS_(args, datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Split one TU dataset across clients and build a global test loader.

    Each client gets train/val/test loaders with structure encodings attached
    (for ENZYMES the train split is reused as val and test); the leftover
    graphs from `_randChunk_` form a shared global test set evaluated in a
    single full batch.

    Returns (splitedData, df, dataloader_global_test).

    Fix: the structure encoding of the shared global test set was recomputed
    inside the per-client loop (num_client redundant passes); it is now
    computed once.
    """
    degree_dims = {"COLLAB": 491, "IMDB-BINARY": 135, "IMDB-MULTI": 88}
    if data in degree_dims:
        tudataset = TUDataset(f"{datapath}/TUDataset", data,
                              pre_transform=OneHotDegree(degree_dims[data], cat=False))
    else:
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        if convert_x:
            maxdegree = get_maxDegree(tudataset)
            tudataset = TUDataset(f"{datapath}/TUDataset", data,
                                  transform=OneHotDegree(maxdegree, cat=False))
    graphs = list(tudataset)
    # NOTE(review): this shuffle uses the unseeded global RNG (seeding happens
    # later inside _randChunk_), so client assignment is not reproducible
    # across runs — confirm intended.
    random.shuffle(graphs)

    graphs_chunks, graph_global_test = _randChunk_(graphs, num_client, overlap, seed=seed)
    graph_global_test = init_structure_encoding(args, gs=graph_global_test, type_init=args.type_init)

    splitedData = {}
    df = pd.DataFrame()
    num_node_features = graphs[0].num_node_features
    for idx, chunks in enumerate(graphs_chunks):
        ds = f'{idx}-{data}'
        if data == 'ENZYMES':
            # No held-out val/test for ENZYMES clients: reuse the train split.
            graphs_train, _ = split_data(chunks, train=0.8, test=0.2, shuffle=True, seed=seed)
            graphs_val = graphs_train
            graphs_test = graphs_train
        else:
            graphs_train, graphs_valtest = split_data(chunks, test=0.2, shuffle=True, seed=seed)
            graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

        dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
        dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
        # Each client's test split is evaluated in one full batch.
        dataloader_test = DataLoader(graphs_test, batch_size=len(graphs_test), shuffle=True)

        num_graph_labels = get_numGraphLabels(graphs_train)
        splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                           num_node_features, num_graph_labels, len(graphs_train))
        df = get_stats(df, ds, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)

    dataloader_global_test = {'test': DataLoader(graph_global_test,
                                                 batch_size=len(graph_global_test), shuffle=True)}

    return splitedData, df, dataloader_global_test


def prepareData_oneDS2_(args, datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Variant of `prepareData_oneDS_` with larger client chunks (cap 100).

    Uses `_randChunk2_` (per-client cap of 100 graphs instead of 50) and only
    shuffles ENZYMES. Each client gets train/val/test loaders with structure
    encodings; the chunking leftovers form a shared full-batch global test
    set.

    Returns (splitedData, df, dataloader_global_test).

    Fixes: the structure encoding of the shared global test set was
    recomputed inside the per-client loop (now computed once), and a debug
    `print(graphs_train)` that dumped the whole graph list was removed.
    """
    degree_dims = {"COLLAB": 491, "IMDB-BINARY": 135, "IMDB-MULTI": 88}
    if data in degree_dims:
        tudataset = TUDataset(f"{datapath}/TUDataset", data,
                              pre_transform=OneHotDegree(degree_dims[data], cat=False))
    else:
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        if convert_x:
            maxdegree = get_maxDegree(tudataset)
            tudataset = TUDataset(f"{datapath}/TUDataset", data,
                                  transform=OneHotDegree(maxdegree, cat=False))
    graphs = list(tudataset)
    if data == 'ENZYMES':
        # NOTE(review): shuffles via the unseeded global RNG — not reproducible.
        random.shuffle(graphs)

    graphs_chunks, graph_global_test = _randChunk2_(graphs, num_client, overlap, seed=seed)
    graph_global_test = init_structure_encoding(args, gs=graph_global_test, type_init=args.type_init)

    splitedData = {}
    df = pd.DataFrame()
    num_node_features = graphs[0].num_node_features
    for idx, chunks in enumerate(graphs_chunks):
        ds = f'{idx}-{data}'
        if data == 'ENZYMES':
            # No held-out val/test for ENZYMES clients: reuse the train split.
            graphs_train, _ = split_data(chunks, train=0.8, test=0.2, shuffle=True, seed=seed)
            graphs_val = graphs_train
            graphs_test = graphs_train
        else:
            graphs_train, graphs_valtest = split_data(chunks, test=0.2, shuffle=True, seed=seed)
            graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

        dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
        dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
        # Each client's test split is evaluated in one full batch.
        dataloader_test = DataLoader(graphs_test, batch_size=len(graphs_test), shuffle=True)

        num_graph_labels = get_numGraphLabels(graphs_train)
        splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                           num_node_features, num_graph_labels, len(graphs_train))
        df = get_stats(df, ds, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)

    dataloader_global_test = {'test': DataLoader(graph_global_test,
                                                 batch_size=len(graph_global_test), shuffle=True)}

    return splitedData, df, dataloader_global_test

def prepareData_oneDS_time(args, datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Build per-client loaders from pre-pickled (feats, labels, edges) lists.

    Loads '<dataset>-train_feats_label_edge_list' and
    '<dataset>-test_feats_label_edge_list' pickles, converts them to PyG
    `Data` graphs, gives each of `args.clients` clients a disjoint 70/10/20
    slice of the training pool (scaled by `args.split`), and returns a
    full-batch global test loader built from the test pickle.

    Returns (splitedData, df, dataloader_global_test).

    Fixes vs. the original:
      * the training pool was overwritten by the first client's slice
        (`graphs_train = client_data_train`) and the slice offset (`startup`)
        was never advanced, so every later client drew overlapping data from
        a shrinking list; clients now receive disjoint, advancing slices.
      * the global test set's structure encoding was recomputed on every
        client iteration; it is now computed once.
      * per-graph debug prints of the growing test-set length were removed.
    """
    # The TU dataset is loaded only so len(dataset) can size the test slice below.
    degree_dims = {"COLLAB": 491, "IMDB-BINARY": 135, "IMDB-MULTI": 88}
    if data in degree_dims:
        tudataset = TUDataset(f"{datapath}/TUDataset", data,
                              pre_transform=OneHotDegree(degree_dims[data], cat=False))
    else:
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        if convert_x:
            maxdegree = get_maxDegree(tudataset)
            tudataset = TUDataset(f"{datapath}/TUDataset", data,
                                  transform=OneHotDegree(maxdegree, cat=False))
    graphs = list(tudataset)
    print("  **", data, len(graphs))
    if data == 'ENZYMES':
        random.shuffle(graphs)

    args.dataset = args.data_group
    # NOTE(review): hard-coded absolute path — should come from args/datapath.
    data_dir = '/mnt/diskLv/luo/Federated-Learning-on-Graphs-main/Graph_Classification/federated/'

    file_path = os.path.join(data_dir, args.dataset.lower() + '-train_feats_label_edge_list')
    with open(file_path, 'rb') as f:
        feats_list_train0, label_list_train0, edge_list_train0 = pickle.load(f)
    # Sanity check: every edge endpoint must index an existing node.
    for i in range(np.shape(label_list_train0)[0]):
        if np.max(edge_list_train0[i]) >= np.shape(feats_list_train0[i])[0]:
            print('error1')
    feats_list_train = [torch.tensor(ft_train) for ft_train in feats_list_train0]
    label_list_train = [torch.tensor(lb_train) for lb_train in label_list_train0]
    edge_list_train = [torch.tensor(eg_train) for eg_train in edge_list_train0]
    print('***', len(label_list_train))

    file_path = os.path.join(data_dir, args.dataset.lower() + '-test_feats_label_edge_list')
    with open(file_path, 'rb') as f:
        feats_list_test0, label_list_test0, edge_list_test0 = pickle.load(f)
    for i in range(np.shape(label_list_test0)[0]):
        if np.max(edge_list_test0[i]) >= np.shape(feats_list_test0[i])[0]:
            print('error2')
    feats_list_test = [torch.tensor(ft_test) for ft_test in feats_list_test0]
    label_list_test = [torch.tensor(lb_test) for lb_test in label_list_test0]
    edge_list_test = [torch.tensor(eg_test) for eg_test in edge_list_test0]
    print(len(label_list_test))

    dataset = tudataset
    # NOTE(review): for ENZYMES/MUTAG only the *count* comes from the tail
    # slice, but the graphs themselves are taken from the front of the test
    # list — confirm this asymmetry is intended.
    if args.data_group in ('ENZYMES', 'MUTAG'):
        n_test = len(label_list_test[int(len(dataset) * args.split):])
    else:
        n_test = len(label_list_test)
    graph_global_test = []
    for i in range(n_test):
        graph_global_test.append(Data(
            x=torch.tensor(feats_list_test[i], dtype=torch.float),
            edge_index=torch.tensor(edge_list_test[i], dtype=torch.long),
            y=torch.tensor(label_list_test[i], dtype=torch.long)))
    graph_global_test = init_structure_encoding(args, gs=graph_global_test, type_init=args.type_init)

    num_node_features = graph_global_test[0].num_node_features

    # Build the full training pool once; clients slice it below.
    train_pool = []
    for jj in range(len(label_list_train)):
        train_pool.append(Data(
            x=torch.tensor(feats_list_train[jj], dtype=torch.float),
            edge_index=torch.tensor(edge_list_train[jj], dtype=torch.long),
            y=torch.tensor(label_list_train[jj], dtype=torch.long)))

    n_pool = len(label_list_train)
    division = int(n_pool * args.split * 0.7 / args.clients)       # train graphs per client
    division_val = int(n_pool * args.split * 0.1 / args.clients)   # val graphs per client
    division_test = int(n_pool * args.split * 0.2 / args.clients)  # test graphs per client
    print(division)
    print('division_val', division_val)

    splitedData = {}
    df = pd.DataFrame()
    startup = 0
    for i in range(args.clients):
        ds = f'{i}-{args.data_group}'

        ds_train = train_pool[startup:startup + division]
        ds_val = train_pool[startup + division:startup + division + division_val]
        ds_test = train_pool[startup + division + division_val:
                             startup + division + division_val + division_test]
        startup += division + division_val + division_test  # advance to the next client's disjoint slice

        ds_train = init_structure_encoding(args, gs=ds_train, type_init=args.type_init)
        ds_val = init_structure_encoding(args, gs=ds_val, type_init=args.type_init)
        ds_test = init_structure_encoding(args, gs=ds_test, type_init=args.type_init)

        dataloader_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True)
        dataloader_val = DataLoader(ds_val, batch_size=args.batch_size, shuffle=True)
        # Each client's test split is evaluated in one full batch.
        dataloader_test = DataLoader(ds_test, batch_size=len(ds_test), shuffle=True)

        num_graph_labels = get_numGraphLabels(ds_train)
        splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                           num_node_features, num_graph_labels, len(ds_train))
        df = get_stats(df, ds, ds_train, graphs_val=ds_val, graphs_test=ds_test)

    dataloader_global_test = {'test': DataLoader(graph_global_test,
                                                 batch_size=len(graph_global_test), shuffle=True)}

    return splitedData, df, dataloader_global_test


def prepareData_oneDS_ours_103_(args, datapath, data, num_client, batchSize, convert_x=False, seed=None, overlap=False):
    """Split one TU dataset across clients with a held-out global test set.

    Like `prepareData_oneDS_` but only ENZYMES is shuffled before chunking.
    Each client gets train/val/test loaders with structure encodings (for
    ENZYMES the train split is reused as val and test); the chunking
    leftovers form a shared full-batch global test set.

    Returns (splitedData, df, dataloader_global_test).

    Fix: the structure encoding of the shared global test set was recomputed
    inside the per-client loop (num_client redundant passes); it is now
    computed once.
    """
    degree_dims = {"COLLAB": 491, "IMDB-BINARY": 135, "IMDB-MULTI": 88}
    if data in degree_dims:
        tudataset = TUDataset(f"{datapath}/TUDataset", data,
                              pre_transform=OneHotDegree(degree_dims[data], cat=False))
    else:
        tudataset = TUDataset(f"{datapath}/TUDataset", data)
        if convert_x:
            maxdegree = get_maxDegree(tudataset)
            tudataset = TUDataset(f"{datapath}/TUDataset", data,
                                  transform=OneHotDegree(maxdegree, cat=False))
    graphs = list(tudataset)
    print("  **", data, len(graphs))
    if data == 'ENZYMES':
        # NOTE(review): shuffles via the unseeded global RNG — not reproducible.
        random.shuffle(graphs)

    graphs_chunks, graph_global_test = _randChunk_(graphs, num_client, overlap, seed=seed)
    graph_global_test = init_structure_encoding(args, gs=graph_global_test, type_init=args.type_init)

    splitedData = {}
    df = pd.DataFrame()
    num_node_features = graphs[0].num_node_features
    for idx, chunks in enumerate(graphs_chunks):
        ds = f'{idx}-{data}'
        if data == 'ENZYMES':
            # No held-out val/test for ENZYMES clients: reuse the train split.
            graphs_train, _ = split_data(chunks, train=0.8, test=0.2, shuffle=True, seed=seed)
            graphs_val = graphs_train
            graphs_test = graphs_train
        else:
            graphs_train, graphs_valtest = split_data(chunks, test=0.2, shuffle=True, seed=seed)
            graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

        dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
        dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
        # Each client's test split is evaluated in one full batch.
        dataloader_test = DataLoader(graphs_test, batch_size=len(graphs_test), shuffle=True)

        num_graph_labels = get_numGraphLabels(graphs_train)
        splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                           num_node_features, num_graph_labels, len(graphs_train))
        df = get_stats(df, ds, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)

    dataloader_global_test = {'test': DataLoader(graph_global_test,
                                                 batch_size=len(graph_global_test), shuffle=True)}

    return splitedData, df, dataloader_global_test




def prepareData_multiDS_(args, datapath, group='chem', batchSize=32, seed=None):
    """Load one TU dataset per client, split 80/10/10, and wrap in loaders.

    Every dataset of the chosen ``group`` becomes one client: its graphs are
    split into train/val/test, structure encodings are attached, and the three
    splits are wrapped into DataLoaders.

    Args:
        args: namespace providing ``type_init`` for init_structure_encoding.
        datapath: root directory holding/receiving the TUDataset downloads.
        group: dataset collection — 'chem', 'biochem', 'biochemsn', 'biosncv'.
        batchSize: mini-batch size for all three loaders.
        seed: RNG seed forwarded to split_data.

    Returns:
        (splitedData, df): ``splitedData`` maps dataset name to
        (loaders dict, num_node_features, num_graph_labels, train-set size);
        ``df`` accumulates per-dataset statistics via get_stats.
    """
    assert group in ['chem', "biochem", 'biochemsn', "biosncv"]

    _CHEM = ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"]   # small molecules
    _BIO = ["ENZYMES", "DD", "PROTEINS"]                                 # bioinformatics
    _SN = ["COLLAB", "IMDB-BINARY", "IMDB-MULTI"]                        # social networks
    _CV = ["Letter-high", "Letter-low", "Letter-med"]                    # computer vision
    group_to_datasets = {
        'chem': _CHEM,
        'biochem': _CHEM + _BIO,
        'biochemsn': _CHEM + _BIO + _SN,
        'biosncv': _BIO + _SN + _CV,
    }
    datasets = group_to_datasets[group]

    # Social-network datasets carry no node attributes; substitute a one-hot
    # degree encoding capped at the dataset's maximum degree.
    degree_caps = {"COLLAB": 491, "IMDB-BINARY": 135, "IMDB-MULTI": 88}

    splitedData = {}
    df = pd.DataFrame()
    for name in datasets:
        root = f"{datapath}/TUDataset"
        if name in degree_caps:
            tudataset = TUDataset(root, name, pre_transform=OneHotDegree(degree_caps[name], cat=False))
        elif "Letter" in name:
            # Letter-* datasets ship continuous node attributes.
            tudataset = TUDataset(root, name, use_node_attr=True)
        else:
            tudataset = TUDataset(root, name)

        graphs = list(tudataset)
        print("  **", name, len(graphs))

        # 80% train, then the remaining 20% split evenly into val/test.
        graphs_train, graphs_valtest = split_data(graphs, test=0.2, shuffle=True, seed=seed)
        graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

        dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
        dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
        dataloader_test = DataLoader(graphs_test, batch_size=batchSize, shuffle=True)

        num_node_features = graphs[0].num_node_features
        num_graph_labels = get_numGraphLabels(graphs_train)

        splitedData[name] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                             num_node_features, num_graph_labels, len(graphs_train))

        df = get_stats(df, name, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)

    return splitedData, df

def prepareData_multiDS_protein_(args,datapath,num_client,group='small',batchSize=32, convert_x=False, seed=None,overlap=False):
    """Build per-client federated splits from one or more TU datasets.

    Each dataset of the chosen ``group`` is chunked across ``num_client``
    clients by ``_randChunk2_``; the held-out portion of every dataset is
    pooled into a single global test loader.

    Args:
        args: namespace providing ``type_init`` for init_structure_encoding.
        datapath: root directory for TUDataset downloads.
        num_client: number of client chunks per dataset.
        group: dataset-collection key (see GROUPS below).
        batchSize: mini-batch size for the train/val loaders.
        convert_x: replace node features by one-hot degree when True
            (generic TUDataset branch only).
        seed: RNG seed forwarded to chunking and splitting.
        overlap: forwarded to the chunker (previously silently ignored).

    Returns:
        (splitedData, df, dataloader_global_test): ``splitedData`` maps
        '<clientIdx>-<dataset>' to (loaders dict, num_node_features,
        num_graph_labels, train-set size); ``dataloader_global_test`` wraps
        the pooled held-out graphs in a single batch.

    Raises:
        ValueError: if ``group`` is not a known collection (previously an
            unknown group fell through to a NameError on ``datasets``).
    """
    GROUPS = {
        'molecules': ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"],
        'molecules_tiny': ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"],
        'small': ["MUTAG", "ENZYMES"],
        'PROTEINS1': ["MUTAG", "PROTEINS"],
        'PROTEINS2': ["MUTAG", "ENZYMES", "PROTEINS"],
        'PROTEINS3': ["ENZYMES", "PROTEINS"],
        'mix': ["MUTAG", "IMDB-BINARY", "IMDB-MULTI"],
        'mix_tiny': ["MUTAG", "IMDB-BINARY", "IMDB-MULTI"],
        'biochem': ["MUTAG", "ENZYMES"],
        'biochem_tiny': ["MUTAG", "ENZYMES"],
        'social': ["IMDB-BINARY", "IMDB-MULTI"],
        'mix_all': ["MUTAG", "IMDB-BINARY", "IMDB-MULTI", "ENZYMES"],
    }
    if group not in GROUPS:
        raise ValueError(f"unknown data group: {group}")
    datasets = GROUPS[group]

    splitedData = {}
    df = pd.DataFrame()
    graph_global_test = []
    for data in datasets:
        # Social-network datasets carry no node attributes -> one-hot degree.
        if data == "COLLAB":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(491, cat=False))
        elif data == "IMDB-BINARY":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(135, cat=False))
        elif data == "IMDB-MULTI":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(88, cat=False))
        else:
            tudataset = TUDataset(f"{datapath}/TUDataset", data)
            if convert_x:
                maxdegree = get_maxDegree(tudataset)
                tudataset = TUDataset(f"{datapath}/TUDataset", data, transform=OneHotDegree(maxdegree, cat=False))

        graphs = [x for x in tudataset]

        # Split into per-client chunks plus a held-out global-test portion.
        # `overlap` is now honoured instead of being hard-coded to False.
        graphs_chunks, graph_global_test_ = _randChunk2_(graphs, num_client, overlap=overlap, seed=seed)
        num_node_features = graphs[0].num_node_features
        for idx, chunks in enumerate(graphs_chunks):
            ds = f'{idx}-{data}'
            ds_tvt = chunks

            if data == 'ENZYMES':
                # ENZYMES: validate and test on the training chunk itself
                # (kept from the original experimental setup).
                graphs_train, _ = split_data(ds_tvt, train=0.8, test=0.2, shuffle=True, seed=seed)
                graphs_val = graphs_train
                graphs_test = graphs_train
            else:
                # 80/10/10 split of the client chunk. (The original also ran
                # a second, discarded pair of split_data calls here.)
                graphs_train, graphs_valtest = split_data(ds_tvt, test=0.2, shuffle=True, seed=seed)
                graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

            graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
            graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
            graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

            dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
            dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
            # Test loader delivers the whole split in one batch.
            dataloader_test = DataLoader(graphs_test, batch_size=len(graphs_test), shuffle=True)

            num_graph_labels = get_numGraphLabels(graphs_train)
            splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                               num_node_features, num_graph_labels, len(graphs_train))
            df = get_stats(df, ds, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)

        graph_global_test.append(graph_global_test_)

    # Flatten the per-dataset held-out lists into one global test loader.
    graph_global_test = list(itertools.chain.from_iterable(graph_global_test))
    dataloader_global_test = {'test': DataLoader(graph_global_test, batch_size=len(graph_global_test), shuffle=True)}

    return splitedData, df, dataloader_global_test



def prepareData_multiDS_protein_loss_(args,datapath,num_client,group='small',batchSize=32, convert_x=False, seed=None,overlap=False):
    """Prepare federated train/val/test loaders plus a pooled global test set.

    For every dataset of the chosen ``group``, data is either (a) chunked
    across ``num_client`` clients via ``_randChunk2_`` (all groups except
    ``mix_all``) or (b), for ``mix_all``, rebuilt from pre-pickled
    feats/label/edge lists on disk and sliced sequentially per client.

    Args:
        args: experiment namespace; this function reads ``type_init``,
            ``split``, ``clients``, ``num_clients`` and ``data_group``.
        datapath: root directory for TUDataset downloads.
        num_client: number of client chunks per dataset (non-``mix_all`` path).
        group: dataset-collection key; if it matches none of the options
            below, ``datasets`` stays undefined and a NameError follows.
        batchSize: batch size for the train/val loaders.
        convert_x: in the generic TUDataset branch only, replace node
            features by a one-hot degree encoding when True.
        seed: RNG seed forwarded to chunking/splitting.
        overlap: accepted but not forwarded — the chunker below is always
            called with ``overlap=False``.

    Returns:
        (splitedData, df, dataloader_global_test): ``splitedData`` maps
        '<clientIdx>-<dataset>' keys to (loaders dict, feature dim,
        num_graph_labels, train-set size); ``df`` accumulates per-client
        stats; ``dataloader_global_test`` batches the pooled held-out graphs.
    """
    # assert group in ['molecules', 'molecules_tiny', 'small', 'mix', "mix_tiny", "biochem", "biochem_tiny",'social']
    # print('###',group)
    # Map the group key to its dataset list and the node-feature width
    # (ft_dim) used when truncating pickled features in the mix_all branch.
    if group == 'molecules' or group == 'molecules_tiny':
        datasets = ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"]
        ft_dim=3
    if group == 'small':
        datasets = ["MUTAG",                  # small molecules
                    "ENZYMES"]                                # bioinformatics
        ft_dim = 3
    if group == 'PROTEINS1':
        datasets = ["MUTAG",                  # small molecules
                    "PROTEINS"]
        ft_dim = 3
    if group == 'PROTEINS2':
        datasets = ["MUTAG",                  # small molecules
                    "ENZYMES",
                    "PROTEINS"]
        ft_dim = 3
    if group == 'PROTEINS3':
        datasets = ["ENZYMES",
                    "PROTEINS"]                  # small molecules
        ft_dim = 3

    if group == 'mix' or group == 'mix_tiny':
        datasets = ["MUTAG",   # small molecules
                    "IMDB-BINARY", "IMDB-MULTI"]                      # social networks
        ft_dim = 3
    if group == 'biochem' or group == 'biochem_tiny':
        datasets = ["MUTAG",   # small molecules
                    "ENZYMES"]                               # bioinformatics
        ft_dim = 3
    if group == 'social' :
        datasets = ["IMDB-BINARY", "IMDB-MULTI"]
        ft_dim = 89
    if group == 'mix_all' :
        datasets = ["MUTAG","ENZYMES","IMDB-BINARY", "IMDB-MULTI"]
        ft_dim = 3

    splitedData = {}
    df = pd.DataFrame()
    graph_global_test = []
    tudataset_list = []
    # tudataset_num = 0
    for data in datasets:

        # Load each TU dataset. IMDB-* datasets have no node attributes, so a
        # one-hot degree transform supplies features at access time.
        if data == "COLLAB":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(491, cat=False))
            tudataset_list.append(tudataset)
            tudataset_num = len(tudataset)
        elif data == "IMDB-BINARY":

            # tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(135, cat=False))
            dataset = TUDataset(f"{datapath}/TUDataset", data)
            maxdegree = get_maxDegree(dataset)
            tudataset = TUDataset(f"{datapath}/TUDataset", data, transform=OneHotDegree(maxdegree, cat=False))
            # ft_dim = maxdegree + 1
            tudataset_list.append(tudataset)
            tudataset_num = len(tudataset)


        elif data == "IMDB-MULTI":
            # tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(88, cat=False))
            dataset = TUDataset(f"{datapath}/TUDataset", data)
            maxdegree = get_maxDegree(dataset)
            tudataset = TUDataset(f"{datapath}/TUDataset", data, transform=OneHotDegree(maxdegree, cat=False))
            # ft_dim = maxdegree + 1
            tudataset_list.append(tudataset)
            tudataset_num = len(tudataset)

        elif data.lower() == 'enzymes':
            tudataset = TUDataset(f"{datapath}/TUDataset", data)
            # ft_dim = 3
            tudataset_list.append(tudataset)
            tudataset_num = len(tudataset)

        elif data.lower() == 'mutag':
            tudataset = TUDataset(f"{datapath}/TUDataset", data)
            # ft_dim = 7
            tudataset_list.append(tudataset)
            tudataset_num = len(tudataset)

        elif data.lower() == 'proteins':
            tudataset = TUDataset(f"{datapath}/TUDataset", data)
            # ft_dim = 3
            tudataset_list.append(tudataset)
            tudataset_num = len(tudataset)

        else:
            tudataset = TUDataset(f"{datapath}/TUDataset", data)
            if convert_x:
                maxdegree = get_maxDegree(tudataset)
                tudataset = TUDataset(f"{datapath}/TUDataset", data, transform=OneHotDegree(maxdegree, cat=False))
            tudataset_num = len(tudataset)

            tudataset_list.append(tudataset)
        graphs = [x for x in tudataset]
        # print("  **", data, len(graphs))
        # graphs=[]
        # for g_x in tudataset:
        #     # print(g_x)
        #     g_x.x=g_x.x[:,0:ft_dim]
        #     g_x.num_nodes=np.shape(g_x.x)[0]
        #     g_x.edge_attr=g_x.edge_attr
        #     print(g_x)
        #     graphs.append(g_x)
        #     # exit()
        if group != 'mix_all':
            # Standard path: chunk the dataset across clients, keep a
            # held-out portion for the pooled global test set.
            graphs_chunks, graph_global_test_ = _randChunk2_(graphs, num_client, overlap=False,seed=seed)
            # splitedData = {}
            # df = pd.DataFrame()
            # NOTE: reports the group's hard-coded ft_dim as the feature
            # width rather than graphs[0].num_node_features.
            num_node_features = ft_dim
            for idx, chunks in enumerate(graphs_chunks):
                ds = f'{idx}-{data}'
                ds_tvt = chunks
                # print(ds_tvt)

                if data == 'ENZYMES':
                    # ENZYMES: val and test reuse the train split itself.
                    ds_train, ds_vt = split_data(ds_tvt, train=0.8, test=0.2, shuffle=True, seed=seed)
                    ds_val = ds_train
                    ds_test = ds_train
                    graphs_train = ds_train
                    graphs_val = ds_val
                    graphs_test = ds_test
                else:
                    # NOTE(review): the first two split_data results are
                    # overwritten by the graphs_* splits just below — the
                    # ds_train/ds_val/ds_test values computed here are unused.
                    ds_train, ds_vt = split_data(ds_tvt, train=0.8, test=0.2, shuffle=True, seed=seed)
                    ds_val, ds_test = split_data(ds_vt, train=0.5, test=0.5, shuffle=True, seed=seed)
                    graphs_train, graphs_valtest = split_data(ds_tvt, test=0.2, shuffle=True, seed=seed)
                    graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

                graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
                graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
                graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

                ds_train = graphs_train
                ds_val = graphs_val
                ds_test = graphs_test

                dataloader_train = DataLoader(ds_train, batch_size=batchSize, shuffle=True)
                dataloader_val = DataLoader(ds_val, batch_size=batchSize, shuffle=True)
                # Test loader delivers the whole split in one batch.
                dataloader_test = DataLoader(ds_test, batch_size=len(ds_test), shuffle=True)

                num_graph_labels = get_numGraphLabels(ds_train)
                splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                                   num_node_features, num_graph_labels, len(ds_train))
                df = get_stats(df, ds, ds_train, graphs_val=ds_val, graphs_test=ds_test)

            graph_global_test.append(graph_global_test_)


        else:
            # mix_all path: ignore the TUDataset objects loaded above and
            # rebuild graphs from pre-pickled (feats, labels, edges) files.
            # NOTE(review): this inner loop rebinds `data` and iterates the
            # whole `datasets` list in one pass of the outer loop.
            j_data=0
            for data in datasets:

                # HACK: hard-coded absolute path to the pickled splits.
                data_dir = '/mnt/diskLv/luo/Federated-Learning-on-Graphs-main/Graph_Classification/federated/'
                file_name = data.lower() + '-train_feats_label_edge_list'
                file_path = os.path.join(data_dir, file_name)
                # feature_set = set()
                with open(file_path, 'rb') as f:
                    feats_list_train0, label_list_train0, edge_list_train0 = pickle.load(f)

                    # Sanity check: every edge endpoint must index a valid node.
                    for i in range(np.shape(label_list_train0)[0]):

                        if np.max(edge_list_train0[i]) >= np.shape(feats_list_train0[i])[0]:
                            print('error1')
                    # Truncate node features to the first ft_dim columns.
                    feats_list_train = [torch.tensor(ft_train[:, 0:ft_dim]) for ft_train in feats_list_train0]
                    label_list_train = [torch.tensor(lb_train) for lb_train in label_list_train0]
                    edge_list_train = [torch.tensor(eg_train) for eg_train in edge_list_train0]

                print('***', len(label_list_train))

                file_name = data.lower() + '-test_feats_label_edge_list'
                file_path = os.path.join(data_dir, file_name)
                # feature_set = set()
                with open(file_path, 'rb') as f:
                    feats_list_test0, label_list_test0, edge_list_test0 = pickle.load(f)
                    # Same edge-index sanity check for the test portion.
                    for i in range(np.shape(label_list_test0)[0]):

                        if np.max(edge_list_test0[i]) >= np.shape(feats_list_test0[i])[0]:
                            print('error2')
                    # if len(label_list_train)==len(label_list_test0):
                    #     print('error3')

                    # print(np.shape(feats_list_test0[0]))
                    feats_list_test = [torch.tensor(ft_test[:, 0:ft_dim]) for ft_test in feats_list_test0]
                    label_list_test = [torch.tensor(lb_test) for lb_test in label_list_test0]
                    edge_list_test = [torch.tensor(eg_test) for eg_test in edge_list_test0]
                # print(len(label_list_test))

                # Add a subset of each dataset's test graphs to the pooled
                # global test set. NOTE(review): the slice only determines the
                # COUNT of graphs taken; indices always start at 0, so the
                # leading graphs are used — confirm this is intended.
                if data == 'MUTAG':
                    print(len(label_list_test))
                    for i in range(len(label_list_test[int(tudataset_num * args.split):])):
                        edge_index = torch.tensor(edge_list_test[i], dtype=torch.long)
                        x = torch.tensor(feats_list_test[i], dtype=torch.float)
                        y = torch.tensor(label_list_test[i], dtype=torch.long)
                        pyg_graph = Data(x=x, edge_index=edge_index, y=y)

                        graph_global_test.append(pyg_graph)
                        # print(len(graph_global_test))
                elif data == 'ENZYMES':
                    print(len(label_list_test))
                    for i in range(len(label_list_test[int(len(label_list_test) * 0.5):])):
                        edge_index = torch.tensor(edge_list_test[i], dtype=torch.long)
                        x = torch.tensor(feats_list_test[i], dtype=torch.float)
                        y = torch.tensor(label_list_test[i], dtype=torch.long)
                        pyg_graph = Data(x=x, edge_index=edge_index, y=y)
                        # print(y)
                        graph_global_test.append(pyg_graph)

                    # exit()
                else:
                    # Remaining datasets contribute all of their test graphs.
                    print(len(label_list_test))
                    for i in range(len(label_list_test)):
                        edge_index = torch.tensor(edge_list_test[i], dtype=torch.long)
                        x = torch.tensor(feats_list_test[i], dtype=torch.float)
                        y = torch.tensor(label_list_test[i], dtype=torch.long)
                        pyg_graph = Data(x=x, edge_index=edge_index, y=y)

                        graph_global_test.append(pyg_graph)
                        # print(len(graph_global_test))

                # num_node_features =graph_global_test[0].num_node_features

                # Train Dataset split to Clients

                # Rebuild PyG Data objects for the training graphs.
                graphs_train = []
                for jj in range(len(label_list_train)):
                    edge_index = torch.tensor(edge_list_train[jj], dtype=torch.long)
                    x = torch.tensor(feats_list_train[jj], dtype=torch.float)
                    y = torch.tensor(label_list_train[jj], dtype=torch.long)
                    pyg_graph = Data(x=x, edge_index=edge_index, y=y)

                    graphs_train.append(pyg_graph)

                # exit()
                # Sequentially slice the training graphs into per-client
                # train/val/test windows of `division*` graphs each.
                startup = 0
                Client_list = []
                division = int(len(label_list_train) * args.split * 0.8 / args.clients)
                # print(division)

                division_val = int(len(label_list_train) * args.split * 0.1 / args.clients)
                division_test = int(len(label_list_train) * args.split * 0.1 / args.clients)

                # NOTE(review): args.clients is overwritten here AFTER the
                # division sizes above were computed with its previous value —
                # confirm this ordering is intentional.
                args.clients=args.num_clients

                graphs_train_ = cp.deepcopy(graphs_train)

                for i in range(args.clients):

                    # Client key offset assumes 3 clients per dataset —
                    # presumably matching args.clients; verify.
                    ds = f'{i + j_data * 3}-{args.data_group}'

                    client_data_train = graphs_train_[startup:division + startup]
                    client_data_val = graphs_train_[division + startup:division + startup + division_val]

                    # print('~~~',len(graphs_train_),startup)

                    if (args.data_group == 'PROTEINS3' and data == 'ENZYMES') or (
                            args.data_group == 'mix_all' and data == 'ENZYMES'):

                        # print('999999')
                        # ENZYMES special case: every client trains (and
                        # tests) on the full training set.
                        client_data_train = graphs_train_
                        client_data_test = client_data_train

                        # exit()
                    else:
                        client_data_test = graphs_train_[
                                           division + startup + division_val:division + startup + division_val + division_test]

                    # Advance the window to the next client's slice.
                    startup=division + startup + division_val + division_test

                    graphs_train = client_data_train
                    graphs_val = client_data_val
                    graphs_test = client_data_test

                    graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
                    graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
                    graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

                    ds_train=graphs_train
                    ds_val=graphs_val
                    ds_test=graphs_test

                    # print('@@@',len(ds_train))

                    dataloader_train = DataLoader(ds_train, batch_size=batchSize, shuffle=True)
                    dataloader_val = DataLoader(ds_val, batch_size=batchSize, shuffle=True)
                    dataloader_test = DataLoader(ds_test, batch_size=len(ds_test), shuffle=True)

                    num_graph_labels = get_numGraphLabels(ds_train)


                    splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                                       ft_dim, num_graph_labels, len(ds_train))
                    df = get_stats(df, ds, ds_train, graphs_val=ds_val, graphs_test=ds_test)

                j_data += 1

        # graph_global_test = list(itertools.chain.from_iterable(graph_global_test))
        #
    # Attach structure encodings to the pooled global test graphs.
    # NOTE(review): on the non-mix_all path graph_global_test is a list of
    # per-dataset LISTS (the flattening below is commented out) — confirm
    # init_structure_encoding handles that shape.
    graph_global_test=init_structure_encoding(args, gs=graph_global_test, type_init=args.type_init)
    print(len(graph_global_test))

    # graph_global_test = list(itertools.chain.from_iterable(graph_global_test))

    # dataloader_global_test = ({'test': DataLoader(graph_global_test, batch_size=len(graph_global_test), shuffle=True)})
    dataloader_global_test = ({'test': DataLoader(graph_global_test, batch_size=batchSize, shuffle=True)})

    print(len(dataloader_global_test['test']))

    # exit()
    return splitedData, df,dataloader_global_test

def prepareData_multiDS_protein_time(args,datapath,num_client,group='small',batchSize=32, convert_x=False, seed=None,overlap=False):
    """Timing variant of the multi-dataset federated data preparation.

    Chunks every dataset of the chosen ``group`` across ``num_client``
    clients, splits each client chunk 80/10/10 into train/val/test, and pools
    the held-out portion of every dataset into one global test loader.

    Args:
        args: namespace providing ``type_init`` for init_structure_encoding.
        datapath: root directory for TUDataset downloads.
        num_client: number of client chunks per dataset.
        group: dataset-collection key (see GROUPS below).
        batchSize: mini-batch size for the train/val loaders.
        convert_x: replace node features by one-hot degree when True
            (generic TUDataset branch only).
        seed: RNG seed forwarded to chunking and splitting.
        overlap: forwarded to the chunker (previously silently ignored).

    Returns:
        (splitedData, df, dataloader_global_test): ``splitedData`` maps
        '<clientIdx>-<dataset>' to (loaders dict, num_node_features,
        num_graph_labels, train-set size); ``dataloader_global_test`` wraps
        the pooled held-out graphs in a single batch.

    Raises:
        ValueError: if ``group`` is not a known collection (previously an
            unknown group fell through to a NameError on ``datasets``).
    """
    GROUPS = {
        'molecules': ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"],
        'molecules_tiny': ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"],
        'small': ["MUTAG", "ENZYMES"],
        'PROTEINS1': ["MUTAG", "PROTEINS"],
        'PROTEINS2': ["MUTAG", "ENZYMES", "PROTEINS"],
        'PROTEINS3': ["ENZYMES", "PROTEINS"],
        'mix': ["MUTAG", "IMDB-BINARY", "IMDB-MULTI"],
        'mix_tiny': ["MUTAG", "IMDB-BINARY", "IMDB-MULTI"],
        'biochem': ["MUTAG", "ENZYMES"],
        'biochem_tiny': ["MUTAG", "ENZYMES"],
        'social': ["IMDB-BINARY", "IMDB-MULTI"],
        'mix_all': ["MUTAG", "IMDB-BINARY", "IMDB-MULTI", "ENZYMES"],
    }
    if group not in GROUPS:
        raise ValueError(f"unknown data group: {group}")
    datasets = GROUPS[group]

    splitedData = {}
    df = pd.DataFrame()
    graph_global_test = []
    for data in datasets:
        # Social-network datasets carry no node attributes -> one-hot degree.
        if data == "COLLAB":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(491, cat=False))
        elif data == "IMDB-BINARY":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(135, cat=False))
        elif data == "IMDB-MULTI":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(88, cat=False))
        else:
            tudataset = TUDataset(f"{datapath}/TUDataset", data)
            if convert_x:
                maxdegree = get_maxDegree(tudataset)
                tudataset = TUDataset(f"{datapath}/TUDataset", data, transform=OneHotDegree(maxdegree, cat=False))

        graphs = [x for x in tudataset]

        # Split into per-client chunks plus a held-out global-test portion.
        graphs_chunks, graph_global_test_ = _randChunk2_(graphs, num_client, overlap=overlap, seed=seed)
        num_node_features = graphs[0].num_node_features
        for idx, chunks in enumerate(graphs_chunks):
            ds = f'{idx}-{data}'
            ds_tvt = chunks

            # FIX: split the client chunk, not the full dataset. The old code
            # split `graphs` here, so every client received identical
            # train/val/test data while the chunk splits were discarded
            # (mirrors prepareData_multiDS_protein_, which splits ds_tvt).
            graphs_train, graphs_valtest = split_data(ds_tvt, test=0.2, shuffle=True, seed=seed)
            graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

            graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
            graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
            graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

            dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
            dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
            # Test loader delivers the whole split in one batch.
            dataloader_test = DataLoader(graphs_test, batch_size=len(graphs_test), shuffle=True)

            num_graph_labels = get_numGraphLabels(graphs_train)
            splitedData[ds] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                               num_node_features, num_graph_labels, len(graphs_train))
            df = get_stats(df, ds, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)

        graph_global_test.append(graph_global_test_)

    # Flatten the per-dataset held-out lists into one global test loader.
    graph_global_test = list(itertools.chain.from_iterable(graph_global_test))
    dataloader_global_test = {'test': DataLoader(graph_global_test, batch_size=len(graph_global_test), shuffle=True)}

    return splitedData, df, dataloader_global_test



def prepareData_multiDS_protein(datapath, group='small', batchSize=32, convert_x=False, seed=None):
    """Build per-dataset train/val/test dataloaders plus one pooled global test loader.

    For each TUDataset named by ``group``, 20% of the graphs are held out for a
    global test pool shared across datasets; the remainder is split into
    train / val / test.  Train/val/test graphs are run through
    ``init_structure_encoding`` before being wrapped in DataLoaders.

    Args:
        datapath: root directory; datasets are cached under ``{datapath}/TUDataset``.
        group: which predefined collection of TU datasets to load
            (asserted against the allowed group names below).
        batchSize: batch size for the per-dataset train/val/test loaders.
        convert_x: if True (for datasets without a special-case transform),
            replace node features with a one-hot degree encoding.
        seed: random seed forwarded to every ``split_data`` call.

    Returns:
        Tuple ``(splitedData, df, dataloader_global_test)`` where
        ``splitedData`` maps dataset name to
        ``({'train','val','test'} loaders, num_node_features, num_graph_labels,
        num_train_graphs)``, ``df`` accumulates per-dataset stats via
        ``get_stats``, and ``dataloader_global_test`` is
        ``{'test': DataLoader}`` over the pooled held-out graphs of EVERY
        dataset, served as a single batch.

    NOTE(review): this function reads a module-level ``args`` (for
    ``args.type_init``) at the ``init_structure_encoding`` calls — confirm
    ``args`` is defined globally before calling.
    """
    assert group in ['molecules', 'molecules_tiny', 'small', 'mix', "mix_tiny", "biochem", "biochem_tiny", 'social']

    if group == 'molecules' or group == 'molecules_tiny':
        datasets = ["MUTAG", "BZR", "COX2", "DHFR", "PTC_MR", "AIDS", "NCI1"]
    if group == 'small':
        datasets = ["MUTAG",                  # small molecules
                    "ENZYMES"]                # bioinformatics
    if group == 'mix' or group == 'mix_tiny':
        datasets = ["MUTAG",                  # small molecules
                    "ENZYMES",                # bioinformatics
                    "IMDB-BINARY", "IMDB-MULTI"]  # social networks
    if group == 'biochem' or group == 'biochem_tiny':
        datasets = ["MUTAG",                  # small molecules
                    "ENZYMES"]                # bioinformatics
    if group == 'social':
        datasets = ["IMDB-BINARY", "IMDB-MULTI"]

    splitedData = {}
    df = pd.DataFrame()
    # Pooled held-out graphs across ALL datasets (fed to one global test loader).
    global_test_graphs = []
    for data in datasets:
        # Social/collaboration datasets ship without node features; encode the
        # node degree one-hot, capped at each dataset's known max degree.
        if data == "COLLAB":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(491, cat=False))
        elif data == "IMDB-BINARY":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(135, cat=False))
        elif data == "IMDB-MULTI":
            tudataset = TUDataset(f"{datapath}/TUDataset", data, pre_transform=OneHotDegree(88, cat=False))
        else:
            tudataset = TUDataset(f"{datapath}/TUDataset", data)
            if convert_x:
                maxdegree = get_maxDegree(tudataset)
                tudataset = TUDataset(f"{datapath}/TUDataset", data, transform=OneHotDegree(maxdegree, cat=False))

        graphs = [x for x in tudataset]

        # Hold out 20% of this dataset for the pooled global test set.
        graphs_train, graph_global_test_ = split_data(graphs, test=0.2, shuffle=True, seed=seed)

        graphs_train, graphs_valtest = split_data(graphs_train, test=0.2, shuffle=True, seed=seed)
        graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)
        if group.endswith('tiny'):
            # Tiny variant: shrink to 150 graphs, then re-split from scratch.
            graphs, _ = split_data(graphs, train=150, shuffle=True, seed=seed)
            graphs_train, graphs_valtest = split_data(graphs, test=0.2, shuffle=True, seed=seed)
            graphs_val, graphs_test = split_data(graphs_valtest, train=0.5, test=0.5, shuffle=True, seed=seed)

        num_node_features = graphs[0].num_node_features
        num_graph_labels = get_numGraphLabels(graphs_train)

        # NOTE(review): `args` is a free (module-level) name here — TODO confirm
        # it is set before this function runs.
        graphs_train = init_structure_encoding(args, gs=graphs_train, type_init=args.type_init)
        graphs_val = init_structure_encoding(args, gs=graphs_val, type_init=args.type_init)
        graphs_test = init_structure_encoding(args, gs=graphs_test, type_init=args.type_init)

        dataloader_train = DataLoader(graphs_train, batch_size=batchSize, shuffle=True)
        dataloader_val = DataLoader(graphs_val, batch_size=batchSize, shuffle=True)
        dataloader_test = DataLoader(graphs_test, batch_size=batchSize, shuffle=True)

        splitedData[data] = ({'train': dataloader_train, 'val': dataloader_val, 'test': dataloader_test},
                             num_node_features, num_graph_labels, len(graphs_train))

        df = get_stats(df, data, graphs_train, graphs_val=graphs_val, graphs_test=graphs_test)

        # BUGFIX: the original kept an `ii` counter that was never incremented,
        # so `graph_global_test` was overwritten on every iteration and only the
        # LAST dataset's held-out graphs survived (the dead else-branch would
        # also have crashed: torch.cat cannot concatenate Python lists of Data
        # objects).  Accumulate the held-out graphs instead, matching the
        # list-and-chain approach used by the sibling prepareData_* function.
        global_test_graphs.extend(graph_global_test_)

    # Serve the whole pooled hold-out as a single batch.
    dataloader_global_test = ({'test': DataLoader(global_test_graphs, batch_size=len(global_test_graphs), shuffle=True)})

    return splitedData, df, dataloader_global_test