from distFedPAQ.nodes import Node
from distFedPAQ.datasets import pack_data, data_splitter, generate_dummy_data
from distFedPAQ.utils.plots import custom_plot, graph
from distFedPAQ.utils.argument_parser import arg_parser
from distFedPAQ.functional.loss_functions import mse_loss, grad_mse_loss
from distFedPAQ.functional.tools import repeat

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from collections import defaultdict
from tqdm import tqdm


if __name__ == "__main__":
    # > PARAMS FETCHING
    (
        n,
        n_loc_update,
        n_ext_update,
        n_nodes_ext_ave,
        batch_size,
        lr,
        momentum,
        add_ones,
        trainer,
        P,
        single_times,
    ) = arg_parser()

    # > DATA LOADING AND PREPARATION
    X, y = load_diabetes(return_X_y=True)
    # X, y = generate_dummy_data(size=300)
    full_data, weight_size = pack_data(X, y, add_ones)
    train_data, test_data = train_test_split(full_data, test_size=0.2)

    # > NODES DEFINITION
    # #> single node baseline: trains on the whole (unsplit) training set
    single_node = Node(
        data=data_splitter(train_data, 1),
        weight_size=weight_size,
        grad_f=grad_mse_loss,
        batch_size=batch_size,
    )
    # #> multiple nodes: each one holds a shard of the training set
    nodes = np.array(
        [
            Node(
                data=data,
                weight_size=weight_size,
                grad_f=grad_mse_loss,
                batch_size=batch_size,
            )
            for data in data_splitter(train_data, n)
        ],
        dtype=object,  # arbitrary Python objects need an object array
    )

    # > OUTPUTS PREPARATION
    # Row 0 holds the single-node baseline, rows 1..n the distributed nodes;
    # column j is the loss after the j-th external update round.
    losses_train = np.empty((n + 1, n_ext_update + 1))
    losses_test = np.empty((n + 1, n_ext_update + 1))

    print("\n\n\n")

    # > TRAINING
    # #> Single node training
    print("--- SINGLE NODE TRAINING ---\n")
    single_node.local_update(n_loc_update, lr)
    losses_train[0, 0] = mse_loss(train_data, single_node.w)
    losses_test[0, 0] = mse_loss(test_data, single_node.w)
    for j in tqdm(range(n_ext_update)):
        for _ in range(single_times):
            single_node.local_update(n_loc_update, lr)
        losses_train[0, j + 1] = mse_loss(train_data, single_node.w)
        losses_test[0, j + 1] = mse_loss(test_data, single_node.w)

    # #> Multiple nodes training
    print(f"\n--- MULTIPLE ({n}) NODES TRAINING ---\n")
    # ##> first local update
    outputs = trainer(
        nodes=nodes,
        n_iters=repeat(n_loc_update, n),
        lrs=repeat(lr, n),
        eval_data=repeat([train_data, test_data], n),
        loss_fn=mse_loss,
    )
    for i, (loss_train, loss_test) in enumerate(outputs):
        losses_train[i + 1, 0] = loss_train
        losses_test[i + 1, 0] = loss_test

    # ave[i][j] counts how many times node i averaged with node j
    ave = defaultdict(lambda: defaultdict(lambda: 0))
    for j in tqdm(range(n_ext_update)):
        # ##> external averaging
        _i = np.random.randint(n)  # node _i selected for the external update
        available_neighbors = np.nonzero(P[_i])[0].shape[0]
        num_of_neighbor = min(
            n_nodes_ext_ave - 1, available_neighbors
        )  # take all neighbors if fewer are available than requested
        _j_selector = np.random.choice(
            n, size=num_of_neighbor, p=P[_i], replace=False
        )  # draw the nodes _j from node _i's neighbors, weighted by P[_i]
        for s in _j_selector:
            ave[_i][s] += 1
        nodes[_i].external_update(*nodes[_j_selector])

        # ##> local update
        outputs = trainer(
            nodes=nodes,
            n_iters=repeat(n_loc_update, n),
            lrs=repeat(lr, n),
            eval_data=repeat([train_data, test_data], n),
            loss_fn=mse_loss,
        )
        for i, (loss_train, loss_test) in enumerate(outputs):
            losses_train[i + 1, j + 1] = loss_train
            losses_test[i + 1, j + 1] = loss_test

    print("Number of times each node made an external update:")
    for _i, _js in sorted(ave.items(), key=lambda x: (-sum(x[1].values()), x[0])):
        _total = 0
        print(f"\n\tNode_{_i+1} ->")
        for _j, _val in sorted(_js.items(), key=lambda x: -x[1]):
            _total += _val
            print(f"\t\tNode_{_j+1} : {_val} times")
        print(f"\t\tNode_{_i+1} contacted its neighbors {_total} times.")
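    # NOTE (illustrative sketch, not part of distFedPAQ): the neighbor
    # sampling above passes P[_i] as `p=` to np.random.choice, so each row
    # of P must be a probability distribution over the n nodes
    # (non-negative entries summing to 1). One simple valid choice is a
    # ring topology where every node contacts its two neighbors with equal
    # probability; the hypothetical helper below builds such a matrix and
    # is never called by the script.
    def _example_ring_mixing_matrix(n_nodes):
        """Row-stochastic mixing matrix for a ring of n_nodes (n_nodes >= 3)."""
        P_ring = np.zeros((n_nodes, n_nodes))
        for node in range(n_nodes):
            P_ring[node, (node - 1) % n_nodes] = 0.5  # left neighbor
            P_ring[node, (node + 1) % n_nodes] = 0.5  # right neighbor
        return P_ring  # every row sums to 1; the diagonal stays 0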
    # > INSTRUCTIONS
    print(
        f"""
        --- INSTRUCTIONS FOR THE PLOT FUNCTION ---

        >>> plot(1,3,5)  # plot evaluations for nodes 1, 3 and 5. It can take as many
                         # values as there are nodes; note that 0 stands for the single node.
        >>> plot(start=2, stop=5)  # plot evaluations for nodes 2,3,4,5, i.e. from start to stop (inclusive)
        >>> plot(start=3)  # does nothing unless stop is specified; start defaults to 1.
        >>> plot(stop=5)  # plot evaluations for nodes 1,2,3,4,5, i.e. from 1 to stop (inclusive)
        >>> plot(stop=-1)  # plot evaluations from the start node to the LAST node, i.e. ALL nodes
        >>> plot(start=2, stop=-2)  # plot evaluations from the SECOND node to the SECOND-TO-LAST node
        >>> plot(ave=False)  # hide the plot of the average over the nodes
        >>> plot(single=False)  # hide the plot of the single node
        >>> plot(train=False)  # hide the evaluation on the train data
        >>> plot(test=False)  # hide the evaluation on the test data

        USER CAN USE ANY COMBINATION OF THESE, BUT FOLLOWING THIS ORDER IS RECOMMENDED, e.g.:

        >>> plot(1,2,3, ave=False, single=False, train=False, test=True)  # show evaluations of nodes 1, 2 and 3 on the test data
        >>> plot(start=1, stop=3, ave=False, single=False, train=False)  # the same as above

        NB: the single node was trained {single_times}x more than the other nodes before each evaluation.
        """
    )

    # > SETUP FOR PLOTTING
    title = f"""
    Value of $f$ for a local update per iteration
    nb of nodes = {n}
    batch size = {batch_size}
    nb iter for loc update = {n_loc_update}
    learning rate = {lr}
    """

    def plot(
        *indices,
        start=1,
        stop=None,
        single=True,
        ave=True,
        train=True,
        test=True,
    ):
        if train:
            custom_plot(
                losses_train,
                *indices,
                i_start=start,
                i_stop=stop,
                single=single,
                ave=ave,
                pref="train ",
                style1="-",
                style2="--",
            )
        if test:
            custom_plot(
                losses_test,
                *indices,
                i_start=start,
                i_stop=stop,
                single=single,
                ave=ave,
                pref="test ",
                style1="-.",
                style2=":",
            )
        # ##> LEGEND
        plt.xlabel("iteration")
        plt.legend()
        plt.title(title)
        plt.show()

    def plot_graph(show_contact=True):
        avg = ave if show_contact else None
        graph(P, ave=avg)
        plt.show()

    plot(stop=-1)
    plot_graph()
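    # Illustrative sanity check (an addition, not part of the original
    # script): by construction above, row 0 of the loss arrays is the
    # single-node baseline, rows 1..n are the distributed nodes, and column
    # j is the loss after the j-th external round, so the mean curve over
    # the distributed nodes can be compared numerically with the baseline.
    mean_nodes_test = losses_test[1:].mean(axis=0)  # per-round average over the n nodes
    print(
        f"final test loss -- single node: {losses_test[0, -1]:.4f}, "
        f"distributed average: {mean_nodes_test[-1]:.4f}"
    )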