import argparse
import copy
import os
import random
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from tensorboardX import SummaryWriter

# Make ../lib importable before pulling in the project-local modules.
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

import setupGC_ours as setupGC
from training_ import *  # provides run_ours2 (and originally np/pd)

# Usage: CUDA_VISIBLE_DEVICES=1 python main_oneDS_ours_.py --data_group mutag --alg CeFGC

# Metric columns written for every client row in the accuracy CSVs.
_METRIC_COLS = ['train_acc', 'val_acc', 'loss', 'acc_sum',
                'f1', 'precision', 'recall', 'auc']


def _write_results(accs, suffix, repeat, alg):
    """Dump one {client_name -> metric list} dict to a CSV under `outpath`.

    Args:
        accs: mapping from client/row name to a sequence of 8 metric values
              (order must match _METRIC_COLS).
        suffix: 'local' or 'global' — selects the output file name.
        repeat: repeat index (None -> no file-name prefix).
        alg: algorithm name embedded in the file name.

    NOTE(review): relies on the module-level `outpath` created in __main__,
    exactly as the original code did.
    """
    df = pd.DataFrame()
    for name, metrics in accs.items():
        df.loc[name, _METRIC_COLS] = metrics
    print(df)
    prefix = '' if repeat is None else f'{repeat}_'
    outfile = os.path.join(outpath, f'{prefix}accuracy_{alg}_GC_{suffix}.csv')
    df.to_csv(outfile)
    print(f"Wrote to file: {outfile}")


def process_ours(args, clients, server, local_epoch):
    """Run federated training and persist the local and global accuracy tables.

    Args:
        args: parsed CLI namespace (needs .repeat and .alg here).
        clients, server: device objects built by setupGC.setup_devices_.
        local_epoch: number of local epochs forwarded to run_ours2.
    """
    print("Start training ...")
    allAccs_local, allAccs_global = run_ours2(args, clients, server, local_epoch)
    # Each result set gets its own fresh DataFrame so the "global" CSV does
    # not inherit leftover rows from the "local" dump (the original reused a
    # single DataFrame across both writes).
    _write_results(allAccs_local, 'local', args.repeat, args.alg)
    _write_results(allAccs_global, 'global', args.repeat, args.alg)


def _str2bool(value):
    """Parse a boolean CLI value.

    Plain `type=bool` treats ANY non-empty string (including 'False') as
    True; this parser accepts the usual true/false spellings instead.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='gpu',
                        help='CPU / GPU device.')
    parser.add_argument('--alg', type=str, default='CeFGC',
                        help='Name of algorithms, one of the CeFGC, CeFGC*, CeFGC+, CeFGC*+')
    parser.add_argument('--num_rounds', type=int, default=200,
                        help='number of rounds to simulate;')
    parser.add_argument('--local_epoch', type=int, default=1,
                        help='number of local epochs;')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate for inner solver;')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='Weight decay (L2 loss on parameters).')
    parser.add_argument('--nlayer', type=int, default=3,
                        help='Number of GINconv layers')
    parser.add_argument('--hidden', type=int, default=64,
                        help='Number of hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='Dropout rate (1 - keep probability).')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size for node classification.')
    parser.add_argument('--seed', help='seed for randomness;', type=int, default=1)
    parser.add_argument('--num_clients', help='number of clients', type=int, default=3)
    parser.add_argument('--datapath', type=str, default='./Data',
                        help='The input path of data.')
    parser.add_argument('--outbase', type=str, default='./outputs-ours',
                        help='The base path for outputting.')
    parser.add_argument('--repeat', help='index of repeating;', type=int, default=None)
    parser.add_argument('--data_group', help='specify the group of datasets',
                        type=str, default='chem')
    parser.add_argument('--seq_length', help='the length of the gradient norm sequence',
                        type=int, default=5)
    parser.add_argument('--n_rw', type=int, default=16,
                        help='Size of position encoding (random walk).')
    parser.add_argument('--n_dg', type=int, default=16,
                        help='Size of position encoding (max degree).')
    parser.add_argument('--n_ones', type=int, default=16,
                        help='Size of position encoding (ones).')
    parser.add_argument('--type_init', help='the type of positional initialization',
                        type=str, default='rw_dg',
                        choices=['rw', 'dg', 'rw_dg', 'ones'])
    # type=bool is an argparse footgun (bool('False') is True) -> _str2bool.
    parser.add_argument('--convert_x', type=_str2bool, default=False,
                        help='whether to convert original node features to one-hot degree features')
    parser.add_argument('--overlap', type=_str2bool, default=False,
                        help='whether clients have overlapped data')
    # default was the string "0.8"; argparse coerced it, but a float is correct.
    parser.add_argument('--split', type=float, default=0.8,
                        help='test/train dataset split percentage')
    parser.add_argument('--clients', type=int, default=3, help='number of clients')

    # parse_args signals failure via SystemExit, so the old `except IOError`
    # guard was dead code and has been removed.
    args = parser.parse_args()
    args.dataset = args.data_group

    # set seeds
    seed_dataSplit = 123
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)  # silently a no-op when CUDA is absent

    # set device — intentionally overrides whatever --device was given,
    # exactly as the original did.
    args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # set output path
    outpath = os.path.join(args.outbase, 'raw', args.data_group)
    Path(outpath).mkdir(parents=True, exist_ok=True)
    print(f"Output Path: {outpath}")

    # preparing data
    print("Preparing data ...")
    splitedData, df_stats, dataloader_global_test = setupGC.prepareData_oneDS_ours_(
        args, args.datapath, args.data_group,
        num_client=args.num_clients, batchSize=args.batch_size,
        convert_x=args.convert_x, seed=seed_dataSplit, overlap=args.overlap)
    print("Done")

    # save statistics of data on clients
    if args.repeat is None:
        outf = os.path.join(outpath, 'stats_trainData.csv')
    else:
        outf = os.path.join(outpath, f'{args.repeat}_stats_trainData.csv')
    df_stats.to_csv(outf)
    print(f"Wrote to {outf}")

    # total positional-encoding size = random-walk part + max-degree part
    args.n_se = args.n_rw + args.n_dg
    init_clients, init_server, init_idx_clients = setupGC.setup_devices_(
        splitedData, dataloader_global_test, args)
    print("\nDone setting up devices.")

    # tensorboard summary writer (fedstar variants encode type_init in the path)
    if 'fedstar' in args.alg:
        sw_path = os.path.join(args.outbase, 'raw', 'tensorboard',
                               f'{args.data_group}_{args.alg}_{args.type_init}_{args.repeat}')
    else:
        sw_path = os.path.join(args.outbase, 'raw', 'tensorboard',
                               f'{args.data_group}_{args.alg}_{args.repeat}')
    summary_writer = SummaryWriter(sw_path)

    # Every CeFGC variant in the old commented-out dispatch made the identical
    # call, so the dispatch was removed as dead code.
    process_ours(args,
                 clients=copy.deepcopy(init_clients),
                 server=copy.deepcopy(init_server),
                 local_epoch=1000)