import torch
import numpy as np
import random
import networkx as nx
from dtaidistance import dtw


class Server:
    """Federated-learning server: holds the global model and aggregates client state.

    Aggregation methods read these attributes from client objects:
    ``W`` (dict of parameter tensors keyed like ``self.W``), ``dW`` (dict of
    update tensors with the same keys) and ``train_size`` (used as the
    aggregation weight; presumably the client's local training-set size).
    """

    def __init__(self, model, device):
        self.model = model.to(device)
        # Remember the device so `evaluate_` can move test batches onto it.
        self.device = device
        # Parameter name -> Parameter. These are the model's own Parameter
        # objects, so assigning `self.W[k].data` updates the model in place.
        self.W = dict(self.model.named_parameters())
        # Entries appended by `cache_model`: (idcs, parameter snapshot, accuracies).
        self.model_cache = []

    def randomSample_clients(self, all_clients, frac):
        """Draw ``int(len(all_clients) * frac)`` distinct clients uniformly at random."""
        return random.sample(all_clients, int(len(all_clients) * frac))

    def _weighted_average_into(self, selected_clients, keep=None):
        """Set ``self.W[k]`` to the ``train_size``-weighted mean of the clients' ``W[k]``.

        Only keys for which ``keep(k)`` is true are updated (every key when
        ``keep`` is None). Shared implementation of the ``aggregate_weights*``
        methods, which differ only in which parameters they aggregate.
        """
        total_size = sum(client.train_size for client in selected_clients)
        for k in self.W.keys():
            if keep is not None and not keep(k):
                continue
            weighted = torch.stack(
                [torch.mul(client.W[k].data, client.train_size) for client in selected_clients]
            )
            self.W[k].data = torch.div(torch.sum(weighted, dim=0), total_size).clone()

    def aggregate_weights(self, selected_clients):
        """Replace every global parameter with the train_size-weighted client mean."""
        self._weighted_average_into(selected_clients)

    def aggregate_weights_per(self, selected_clients):
        """Weighted-average only parameters whose name contains ``'graph_convs'``."""
        self._weighted_average_into(selected_clients, lambda k: 'graph_convs' in k)

    def aggregate_weights_se(self, selected_clients):
        """Weighted-average only parameters whose name contains ``'_s'``."""
        # NOTE(review): plain substring match -- any parameter name containing
        # "_s" anywhere is included; confirm this matches the model's naming.
        self._weighted_average_into(selected_clients, lambda k: '_s' in k)

    def aggregate_weights_fe(self, selected_clients):
        """Weighted-average only parameters whose name does NOT contain ``'_s'``."""
        self._weighted_average_into(selected_clients, lambda k: '_s' not in k)

    def compute_pairwise_similarities(self, clients):
        """Return an (n, n) numpy array of cosine similarity + 1 between client updates.

        Each client's ``dW`` is restricted to the keys of ``self.W`` before comparison.
        """
        client_dWs = [{k: client.dW[k] for k in self.W} for client in clients]
        return pairwise_angles(client_dWs)

    def compute_pairwise_distances(self, seqs, standardize=False):
        """Compute the pairwise DTW distance matrix of ``seqs``.

        Args:
            seqs: sequences to compare (converted to a 2-D array when ``standardize``).
            standardize: if True, divide each sequence by its own standard
                deviation first, so only the trend/shape is compared.
        Returns:
            The matrix produced by ``dtw.distance_matrix``.
        """
        if standardize:
            seqs = np.array(seqs)
            stds = seqs.std(axis=1).reshape(-1, 1)
            # A constant sequence has zero std; leave it unscaled rather than
            # turning it into inf/nan, which would poison its DTW distances.
            stds[stds == 0] = 1
            seqs = seqs / stds
        return dtw.distance_matrix(seqs)

    def min_cut(self, similarity, idc):
        """Split clients in two with a Stoer-Wagner minimum cut on the similarity graph.

        Args:
            similarity: square matrix; ``similarity[i][j]`` is the weight of the
                edge between nodes i and j (``nx.stoer_wagner`` rejects negative
                weights and disconnected graphs).
            idc: sequence mapping node position -> client index.
        Returns:
            ``(c1, c2)``: numpy arrays of client indices on each side of the cut.
        """
        g = nx.Graph()
        n = len(similarity)
        for i in range(n):
            for j in range(n):
                # Self-loops carry no information for a cut (stoer_wagner drops them).
                if i != j:
                    g.add_edge(i, j, weight=similarity[i][j])
        _, partition = nx.stoer_wagner(g)
        c1 = np.array([idc[x] for x in partition[0]])
        c2 = np.array([idc[x] for x in partition[1]])
        return c1, c2

    def aggregate_clusterwise(self, client_clusters):
        """Within each cluster, add the train_size-weighted mean ``dW`` to every member's ``W``.

        Only keys present in ``self.W`` are touched; client tensors are updated in place.
        """
        for cluster in client_clusters:
            targs = [{k: client.W[k] for k in self.W} for client in cluster]
            sours = [({k: client.dW[k] for k in self.W}, client.train_size) for client in cluster]
            total_size = sum(client.train_size for client in cluster)
            reduce_add_average(targets=targs, sources=sours, total_size=total_size)

    def compute_max_update_norm(self, cluster):
        """Return the largest L2 norm of any client's flattened update (-inf if empty)."""
        max_dW = -np.inf
        for client in cluster:
            update_norm = torch.norm(flatten({k: client.dW[k] for k in self.W})).item()
            if update_norm > max_dW:
                max_dW = update_norm
        return max_dW

    def compute_mean_update_norm(self, cluster):
        """Return the L2 norm of the mean of the clients' flattened updates."""
        cluster_dWs = [flatten({k: client.dW[k] for k in self.W}) for client in cluster]
        return torch.norm(torch.mean(torch.stack(cluster_dWs), dim=0)).item()

    def cache_model(self, idcs, params, accuracies):
        """Append ``(idcs, cloned params, accuracies of idcs)`` to ``self.model_cache``."""
        snapshot = {name: params[name].data.clone() for name in params}
        self.model_cache.append((idcs, snapshot, [accuracies[i] for i in idcs]))

    def evaluate_(self):
        """Evaluate the global model on ``self.dataLoader['test']`` with ``eval_gc_``."""
        # NOTE(review): `self.dataLoader` is never assigned in this class; the
        # caller must set it before calling this method -- confirm.
        return eval_gc_(self.model, self.dataLoader['test'], self.device)


def flatten(source):
    """Concatenate every tensor in the dict ``source`` (iteration order) into one 1-D tensor."""
    return torch.cat([value.flatten() for value in source.values()])


def pairwise_angles(sources):
    """Return an (n, n) numpy array of cosine similarity + 1 between flattened dicts.

    Values are nominally in [0, 2]. The 1e-12 floor on the norm product avoids
    division by zero when an update is all zeros (its similarities become 1).
    """
    with torch.no_grad():
        # Flatten and take norms once per source instead of once per pair.
        flats = [flatten(source) for source in sources]
        norms = [torch.norm(flat) for flat in flats]
        angles = torch.zeros([len(sources), len(sources)])
        for i, (s1, n1) in enumerate(zip(flats, norms)):
            for j, (s2, n2) in enumerate(zip(flats, norms)):
                angles[i, j] = torch.true_divide(torch.sum(s1 * s2), max(n1 * n2, 1e-12)) + 1
    return angles.numpy()


def reduce_add_average(targets, sources, total_size):
    """Add the weighted average of ``sources`` to every tensor in ``targets``, in place.

    Args:
        targets: list of dicts name -> tensor; each ``target[name].data`` is incremented.
        sources: list of ``(dict name -> tensor, weight)`` pairs.
        total_size: divisor for the weighted sum (callers pass the sum of weights).
    """
    # The average depends only on `sources`, so compute it once per name rather
    # than once per target. Assumes source tensors are not aliases of target
    # tensors (true for the W/dW dicts built in aggregate_clusterwise).
    averages = {}
    for target in targets:
        for name in target:
            if name not in averages:
                weighted = torch.stack([torch.mul(source[0][name].data, source[1]) for source in sources])
                averages[name] = torch.div(torch.sum(weighted, dim=0), total_size)
            target[name].data += averages[name]


def eval_gc_(model, test_loader, device):
    """Evaluate a graph-classification model over every batch of ``test_loader``.

    Returns:
        ``(mean loss, accuracy, macro F1, macro precision, macro recall, ROC-AUC)``,
        all computed over every graph in the loader (not just the last batch).
    Raises:
        ValueError: if ``test_loader`` yields no graphs.
    """
    # NOTE(review): f1_score, precision_score, recall_score and roc_auc_score are
    # not imported in this file; they appear to be sklearn.metrics functions --
    # add that import where the project's dependencies are declared.
    model.eval()
    total_loss = 0.
    acc_sum = 0.
    ngraphs = 0
    all_labels, all_preds, all_scores = [], [], []
    for batch in test_loader:
        batch = batch.to(device)
        with torch.no_grad():
            out = model(batch)
            label = batch.y
            loss = model.loss(out, label)
        pred = out.argmax(dim=1)
        total_loss += loss.item() * batch.num_graphs
        acc_sum += pred.eq(label).sum().item()
        ngraphs += batch.num_graphs
        # Keep per-batch results so the sklearn-style metrics cover the whole set.
        all_labels.append(label.detach().cpu())
        all_preds.append(pred.detach().cpu())
        all_scores.append(out.detach().cpu())

    if ngraphs == 0:
        raise ValueError("test_loader yielded no graphs")

    y_true = torch.cat(all_labels).numpy()
    y_pred = torch.cat(all_preds).numpy()
    scores = torch.cat(all_scores).numpy()

    f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    precision = precision_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    recall = recall_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    # One-hot with one column per model output, so the indicator matrix always
    # has the same shape as `scores` even if the highest class is absent.
    auc = roc_auc_score(np.eye(scores.shape[1])[y_true], scores)
    return total_loss / ngraphs, acc_sum / ngraphs, f1, precision, recall, auc