import os, sys
import argparse
import random
import copy
import torch
from tensorboardX import SummaryWriter
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
import setupGC_ours as setupGC
from training_ import *
#CUDA_VISIBLE_DEVICES=1, python main_oneDS_ours_.py --data_group mutag --alg CeFGC
def process_ours(args, clients, server, local_epoch):
print("Start training ...")
df = pd.DataFrame()
allAccs_local,allAccs_global= run_ours2(args, clients, server, local_epoch)
for k, v in allAccs_local.items():
df.loc[k, [f'train_acc', f'val_acc', f'loss',f'acc_sum',f'f1',f'precision',f'recall',f'auc']] = v
print(df)
if args.repeat is None:
outfile = os.path.join(outpath, f'accuracy_'+args.alg+'_GC_local.csv')
else:
outfile = os.path.join(outpath, f'{args.repeat}_accuracy_'+args.alg+'_GC_local.csv')
df.to_csv(outfile)
print(f"Wrote to file: {outfile}")
for k, v in allAccs_global.items():
df.loc[k, [f'train_acc', f'val_acc', f'loss',f'acc_sum',f'f1',f'precision',f'recall',f'auc']] = v
print(df)
if args.repeat is None:
outfile = os.path.join(outpath, f'accuracy_'+args.alg+'_GC_global.csv')
else:
outfile = os.path.join(outpath, f'{args.repeat}_accuracy_'+args.alg+'_GC_global.csv')
df.to_csv(outfile)
print(f"Wrote to file: {outfile}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='gpu',
help='CPU / GPU device.')
parser.add_argument('--alg', type=str, default='CeFGC',
help='Name of algorithms, one of the CeFGC, CeFGC*, CeFGC+, CeFGC*+')
parser.add_argument('--num_rounds', type=int, default=200,
help='number of rounds to simulate;')
parser.add_argument('--local_epoch', type=int, default=1,
help='number of local epochs;')
parser.add_argument('--lr', type=float, default=0.001,
help='learning rate for inner solver;')
parser.add_argument('--weight_decay', type=float, default=5e-4,
help='Weight decay (L2 loss on parameters).')
parser.add_argument('--nlayer', type=int, default=3,
help='Number of GINconv layers')
parser.add_argument('--hidden', type=int, default=64,
help='Number of hidden units.')
parser.add_argument('--dropout', type=float, default=0.5,
help='Dropout rate (1 - keep probability).')
parser.add_argument('--batch_size', type=int, default=128,
help='Batch size for node classification.')
parser.add_argument('--seed', help='seed for randomness;',
type=int, default=1)
parser.add_argument('--num_clients', help='number of clients',
type=int, default=3)
parser.add_argument('--datapath', type=str, default='./Data',
help='The input path of data.')
parser.add_argument('--outbase', type=str, default='./outputs-ours',
help='The base path for outputting.')
parser.add_argument('--repeat', help='index of repeating;',
type=int, default=None)
parser.add_argument('--data_group', help='specify the group of datasets',
type=str, default='chem')
parser.add_argument('--seq_length', help='the length of the gradient norm sequence',
type=int, default=5)
parser.add_argument('--n_rw', type=int, default=16,
help='Size of position encoding (random walk).')
parser.add_argument('--n_dg', type=int, default=16,
help='Size of position encoding (max degree).')
parser.add_argument('--n_ones', type=int, default=16,
help='Size of position encoding (ones).')
parser.add_argument('--type_init', help='the type of positional initialization',
type=str, default='rw_dg', choices=['rw', 'dg', 'rw_dg', 'ones'])
parser.add_argument('--convert_x', help='whether to convert original node features to one-hot degree features',
type=bool, default=False)
parser.add_argument('--overlap', help='whether clients have overlapped data',
type=bool, default=False)
parser.add_argument("--split", type=float, default="0.8", help="test/train dataset split percentage")
parser.add_argument("--clients", type=int, default=3, help="number of clients")
try:
args = parser.parse_args()
except IOError as msg:
parser.error(str(msg))
args.dataset=args.data_group
# set seeds
seed_dataSplit = 123
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# set device
args.device = "cuda" if torch.cuda.is_available() else "cpu"
# set output path
outpath = os.path.join(args.outbase, 'raw', args.data_group)
Path(outpath).mkdir(parents=True, exist_ok=True)
print(f"Output Path: {outpath}")
# preparing data
print("Preparing data ...")
# splitedData, df_stats = setupGC.prepareData_oneDS_(args, args.datapath, args.data_group, args.batch_size, seed=seed_dataSplit)
splitedData, df_stats, dataloader_global_test = setupGC.prepareData_oneDS_ours_(args, args.datapath,
args.data_group,
num_client=args.num_clients,
batchSize=args.batch_size,
convert_x=args.convert_x,
seed=seed_dataSplit,
overlap=args.overlap)
# splitedData, df_stats = setupGC.prepareData_multiDS(args, args.datapath, args.data_group, args.batch_size, seed=seed_dataSplit)
print("Done")
# save statistics of data on clients
if args.repeat is None:
outf = os.path.join(outpath, 'stats_trainData.csv')
else:
outf = os.path.join(outpath, f'{args.repeat}_stats_trainData.csv')
df_stats.to_csv(outf)
print(f"Wrote to {outf}")
args.n_se = args.n_rw + args.n_dg
init_clients, init_server, init_idx_clients = setupGC.setup_devices_(splitedData, dataloader_global_test,args)
print("\nDone setting up devices.")
# set summarywriter
if 'fedstar' in args.alg:
sw_path = os.path.join(args.outbase, 'raw', 'tensorboard', f'{args.data_group}_{args.alg}_{args.type_init}_{args.repeat}')
else:
sw_path = os.path.join(args.outbase, 'raw', 'tensorboard', f'{args.data_group}_{args.alg}_{args.repeat}')
summary_writer = SummaryWriter(sw_path)
process_ours(args, clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server), local_epoch=1000)
# if args.alg == 'CeFGC':
# process_ours(args, clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server), local_epoch=1000)
# elif args.alg == 'CeFGC*':
# process_ours(args, clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server), local_epoch=1000)
# elif args.alg == 'CeFGC+':
# process_ours(args, clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server), local_epoch=1000)
# elif args.alg == 'CeFGC*+':
# process_ours(args, clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server), local_epoch=1000)