import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def run_selftrain_GC(args, clients, server, local_epoch):
    """Train every client locally, with no aggregation, and report accuracies.

    Each client first calls ``download_from_server(args, server)``, then runs
    ``local_train(local_epoch)`` and ``evaluate()``.

    Args:
        args: Passed through unchanged to ``client.download_from_server``.
        clients: Iterable of client objects exposing ``download_from_server``,
            ``local_train``, ``evaluate``, ``name`` and ``train_stats``.
        server: Passed through unchanged to ``client.download_from_server``.
        local_epoch: Passed through unchanged to ``client.local_train``.

    Returns:
        Dict mapping ``client.name`` to a three-item list: the last entry of
        ``train_stats['trainingAccs']``, the last entry of
        ``train_stats['valAccs']``, and the accuracy returned by
        ``client.evaluate()``.
    """
    # all clients are initialized with the same weights
    for client in clients:
        client.download_from_server(args, server)

    allAccs = {}
    for client in clients:
        client.local_train(local_epoch)
        _loss, acc = client.evaluate()
        allAccs[client.name] = [
            client.train_stats['trainingAccs'][-1],
            client.train_stats['valAccs'][-1],
            acc,
        ]
        print(" > {} done.".format(client.name))
    return allAccs


def _collect_test_accuracy(clients):
    """Evaluate every client and return (and print) a frame of accuracies."""
    frame = pd.DataFrame()
    for client in clients:
        _loss, acc = client.evaluate()
        frame.loc[client.name, 'test_acc'] = acc
    # The original built ``frame.style.apply(...)`` and printed its ``.data``,
    # which is the unstyled DataFrame itself; printing the frame directly is
    # equivalent and avoids constructing a Styler that is never rendered.
    print(frame)
    return frame


def _run_federated_rounds(args, clients, server, num_rounds, local_epoch,
                          samp, frac, summary_writer, aggregate):
    """Shared round loop: local training, aggregation, download, logging."""
    for client in clients:
        client.download_from_server(args, server)

    # Random sampling is the only strategy available here. Previously
    # ``sampling_fn`` was bound only when ``samp`` was None, so any other
    # value raised NameError in round 2; now the caller's ``frac`` is used.
    sampling_fn = server.randomSample_clients
    if samp is None:
        frac = 1.0

    for c_round in range(1, num_rounds + 1):
        if c_round % 50 == 0:
            print(f" > round {c_round}")

        # The first round always trains every client.
        selected_clients = clients if c_round == 1 else sampling_fn(clients, frac)

        # NOTE(review): the original comment here read "only get weights of
        # graphconv layers"; that cannot be confirmed from this file.
        for client in selected_clients:
            client.local_train(local_epoch)

        aggregate(selected_clients)
        for client in selected_clients:
            client.download_from_server(args, server)

        # write to log files (skipped when no writer was supplied, which
        # previously crashed with AttributeError on None)
        if c_round % 5 == 0 and summary_writer is not None:
            for user_no, client in enumerate(clients, start=1):
                loss, acc = client.evaluate()
                summary_writer.add_scalar('Test/Acc/user' + str(user_no), acc, c_round)
                summary_writer.add_scalar('Test/Loss/user' + str(user_no), loss, c_round)

    return _collect_test_accuracy(clients)


def run_fedavg(args, clients, server, COMMUNICATION_ROUNDS, local_epoch,
               samp=None, frac=1.0, summary_writer=None):
    """Run communication rounds aggregated with ``server.aggregate_weights``.

    Args:
        args: Passed through unchanged to ``client.download_from_server``.
        clients: List of client objects (must support slicing by the sampler
            and expose ``download_from_server``, ``local_train``,
            ``evaluate`` and ``name``).
        server: Object exposing ``randomSample_clients(clients, frac)`` and
            ``aggregate_weights(selected_clients)``.
        COMMUNICATION_ROUNDS: Number of rounds to run.
        local_epoch: Passed through unchanged to ``client.local_train``.
        samp: When None, ``frac`` is forced to 1.0. Any other value keeps the
            caller's ``frac``; sampling is done by
            ``server.randomSample_clients`` in both cases.
        frac: Fraction handed to the sampler from round 2 onwards.
        summary_writer: Optional object with ``add_scalar(tag, value, step)``;
            when given, per-client test accuracy and loss are logged every
            fifth round.

    Returns:
        ``pandas.DataFrame`` indexed by client name with a ``test_acc`` column.
    """
    return _run_federated_rounds(
        args, clients, server, COMMUNICATION_ROUNDS, local_epoch,
        samp, frac, summary_writer, server.aggregate_weights,
    )


def run_fedstar(args, clients, server, COMMUNICATION_ROUNDS, local_epoch,
                samp=None, frac=1.0, summary_writer=None):
    """Run communication rounds aggregated with ``server.aggregate_weights_se``.

    Identical to ``run_fedavg`` except for the aggregation method invoked on
    the server after each round of local training.

    Args:
        args: Passed through unchanged to ``client.download_from_server``.
        clients: List of client objects (same requirements as ``run_fedavg``).
        server: Object exposing ``randomSample_clients(clients, frac)`` and
            ``aggregate_weights_se(selected_clients)``.
        COMMUNICATION_ROUNDS: Number of rounds to run.
        local_epoch: Passed through unchanged to ``client.local_train``.
        samp: When None, ``frac`` is forced to 1.0. Any other value keeps the
            caller's ``frac``.
        frac: Fraction handed to the sampler from round 2 onwards.
        summary_writer: Optional object with ``add_scalar(tag, value, step)``;
            when given, per-client test accuracy and loss are logged every
            fifth round.

    Returns:
        ``pandas.DataFrame`` indexed by client name with a ``test_acc`` column.
    """
    # A trailing statement that called ``server_.eval_gc_(self.model, ...)``
    # was removed: neither ``server_`` nor ``self`` exists in this scope, so
    # it raised NameError before the frame could be returned, and its results
    # were unused.
    return _run_federated_rounds(
        args, clients, server, COMMUNICATION_ROUNDS, local_epoch,
        samp, frac, summary_writer, server.aggregate_weights_se,
    )