import torch
import numpy as np
import random
import networkx as nx
from dtaidistance import dtw
from sklearn.metrics import f1_score, precision_score, recall_score,roc_auc_score
class Server():
    """Federated-learning server: holds the global model and aggregates client models.

    The methods below read three attributes from each client object: ``W`` (dict of
    parameter tensors keyed like ``self.W``), ``dW`` (dict of update tensors with the
    same keys) and ``train_size`` (the weight used when averaging).
    """

    def __init__(self, model, dataLoader, device):
        """Move ``model`` to ``device`` and keep handles to its parameters.

        Args:
            model: torch module holding the global parameters.
            dataLoader: mapping; ``dataLoader['test']`` is used by the ``evaluate*`` methods.
            device: device the model (and evaluation batches) are placed on.
        """
        self.model = model.to(device)
        # Name -> Parameter. These are the model's own tensors, so assigning
        # ``self.W[k].data`` changes the model in place.
        self.W = dict(self.model.named_parameters())
        self.model_cache = []
        self.dataLoader = dataLoader
        self.device = device

    def randomSample_clients(self, all_clients, frac):
        """Return ``int(len(all_clients) * frac)`` distinct clients drawn with ``random.sample``."""
        return random.sample(all_clients, int(len(all_clients) * frac))

    def _select_keys(self, params):
        """Restrict a client's ``params`` dict to the keys of ``self.W``, in ``self.W`` order."""
        return {k: params[k] for k in self.W.keys()}

    def _weighted_average_into(self, selected_clients, key_filter=None):
        """Overwrite ``self.W[k]`` with the ``train_size``-weighted mean of the clients' ``W[k]``.

        Only keys for which ``key_filter(k)`` is true are updated; every key is
        updated when ``key_filter`` is None.
        """
        total_size = sum(client.train_size for client in selected_clients)
        for k in self.W.keys():
            if key_filter is not None and not key_filter(k):
                continue
            weighted = torch.stack([torch.mul(client.W[k].data, client.train_size)
                                    for client in selected_clients])
            self.W[k].data = torch.div(torch.sum(weighted, dim=0), total_size).clone()

    def aggregate_weights(self, selected_clients):
        """Set every global parameter to the train-size-weighted mean over ``selected_clients``."""
        self._weighted_average_into(selected_clients)

    def aggregate_weights_per(self, selected_clients):
        """Weighted-average only parameters whose name contains ``'graph_convs'``."""
        self._weighted_average_into(selected_clients, lambda k: 'graph_convs' in k)

    def aggregate_weights_se(self, selected_clients):
        """Weighted-average only parameters whose name contains ``'_s'``."""
        self._weighted_average_into(selected_clients, lambda k: '_s' in k)

    def aggregate_weights_fe(self, selected_clients):
        """Weighted-average only parameters whose name does NOT contain ``'_s'``."""
        self._weighted_average_into(selected_clients, lambda k: '_s' not in k)

    def compute_pairwise_similarities(self, clients):
        """Return the ``pairwise_angles`` matrix of the clients' updates (restricted to ``self.W`` keys)."""
        return pairwise_angles([self._select_keys(client.dW) for client in clients])

    def compute_pairwise_distances(self, seqs, standardize=False):
        """Compute the DTW distance matrix between the sequences in ``seqs``.

        With ``standardize=True`` each sequence (row) is first divided by its own
        standard deviation, so only the trend matters. NOTE(review): a constant
        row has std 0 and would produce inf/nan.
        """
        if standardize:
            seqs = np.array(seqs)
            seqs = seqs / seqs.std(axis=1).reshape(-1, 1)
        return dtw.distance_matrix(seqs)

    def min_cut(self, similarity, idc):
        """Split ``idc`` in two with a Stoer-Wagner minimum cut of the similarity graph.

        Node ``i`` stands for ``idc[i]``. Every pair (i, j) is added as an edge
        weighted by ``similarity[i][j]``. Each undirected edge is written twice and
        the last write wins, so ``similarity`` is assumed symmetric.
        ``nx.stoer_wagner`` requires non-negative weights on a connected graph.

        Returns:
            Two NumPy arrays holding the ``idc`` entries on each side of the cut.
        """
        g = nx.Graph()
        n = len(similarity)
        for i in range(n):
            for j in range(n):
                g.add_edge(i, j, weight=similarity[i][j])
        _, partition = nx.stoer_wagner(g)
        c1 = np.array([idc[x] for x in partition[0]])
        c2 = np.array([idc[x] for x in partition[1]])
        return c1, c2

    def aggregate_clusterwise(self, client_clusters):
        """Within each cluster, add the train-size-weighted mean update to every member's weights.

        Client tensors are modified in place via ``reduce_add_average``.
        """
        for cluster in client_clusters:
            targets = [self._select_keys(client.W) for client in cluster]
            sources = [(self._select_keys(client.dW), client.train_size) for client in cluster]
            total_size = sum(client.train_size for client in cluster)
            reduce_add_average(targets=targets, sources=sources, total_size=total_size)

    def compute_max_update_norm(self, cluster):
        """Return the largest L2 norm of a single client's flattened update (``-inf`` for an empty cluster)."""
        max_dW = -np.inf
        for client in cluster:
            update_norm = torch.norm(flatten(self._select_keys(client.dW))).item()
            if update_norm > max_dW:
                max_dW = update_norm
        return max_dW

    def compute_mean_update_norm(self, cluster):
        """Return the L2 norm of the element-wise mean of the clients' flattened updates."""
        cluster_dWs = [flatten(self._select_keys(client.dW)) for client in cluster]
        return torch.norm(torch.mean(torch.stack(cluster_dWs), dim=0)).item()

    def cache_model(self, idcs, params, accuracies):
        """Append ``(idcs, cloned params, accuracies at idcs)`` to ``self.model_cache``."""
        self.model_cache += [(idcs,
                              {name: params[name].data.clone() for name in params},
                              [accuracies[i] for i in idcs])]

    def evaluate_(self):
        """Evaluate the global model on ``dataLoader['test']`` with ``eval_gc_``."""
        return eval_gc_(self.model, self.dataLoader['test'], self.device)

    def evaluate_loss(self):
        """Evaluate the global model on ``dataLoader['test']`` with ``eval_gc_loss``."""
        return eval_gc_loss(self.model, self.dataLoader['test'], self.device)

    def evaluate_loss_(self):
        """Evaluate the global model on ``dataLoader['test']`` with ``eval_gc_loss_``."""
        return eval_gc_loss_(self.model, self.dataLoader['test'], self.device)
def flatten(source):
    """Concatenate every tensor in the dict ``source`` (in dict order) into one 1-D tensor."""
    pieces = [torch.flatten(tensor) for tensor in source.values()]
    return torch.cat(pieces)
def pairwise_angles(sources):
    """Return the shifted cosine-similarity matrix of a list of tensor dicts.

    Entry ``[i, j]`` is ``cos(s_i, s_j) + 1``, where ``s_i = flatten(sources[i])``.
    The ``+ 1`` shifts values from [-1, 1] into [0, 2], presumably so they can
    serve as non-negative graph edge weights. The denominator is floored at
    1e-12 so an all-zero update does not divide by zero.

    Returns:
        A ``(len(sources), len(sources))`` NumPy array in torch's default dtype.
    """
    # Flatten and norm each source once instead of inside the O(n^2) double loop.
    flats = [flatten(source) for source in sources]
    norms = [torch.norm(flat) for flat in flats]
    angles = torch.zeros([len(sources), len(sources)])
    for i, s1 in enumerate(flats):
        for j, s2 in enumerate(flats):
            angles[i, j] = torch.true_divide(torch.sum(s1 * s2), max(norms[i] * norms[j], 1e-12)) + 1
    return angles.numpy()
def reduce_add_average(targets, sources, total_size):
    """Add the weighted average of ``sources`` to every tensor in ``targets``, in place.

    Args:
        targets: list of dicts mapping name -> tensor; each tensor's ``.data`` is
            incremented in place.
        sources: list of ``(update_dict, weight)`` pairs; ``update_dict[name]``
            must exist for every name that appears in any target.
        total_size: divisor applied to the weighted sum (typically the sum of the weights).
    """
    # The average for a given name does not depend on the target, so compute it
    # once and share it, instead of recomputing it for every target.
    averages = {}
    for target in targets:
        for name in target:
            if name not in averages:
                weighted = torch.stack([torch.mul(source[0][name].data, source[1]) for source in sources])
                averages[name] = torch.div(torch.sum(weighted, dim=0), total_size)
            target[name].data += averages[name]
def _eval_gc(model, test_loader, device, max_features=None):
    """Evaluate a graph classifier over every batch of ``test_loader``.

    Shared implementation of ``eval_gc_``, ``eval_gc_loss`` and ``eval_gc_loss_``.

    Args:
        model: module called as ``model(batch)``; must also provide ``model.loss(out, label)``.
        test_loader: iterable of batches exposing ``.y``, ``.num_graphs``,
            ``.to(device)`` and, when ``max_features`` is set, ``.x``.
        device: device each batch is moved to.
        max_features: if not None, keep only the first ``max_features`` columns of
            ``batch.x`` before the forward pass.

    Returns:
        ``(mean_loss, accuracy, f1, precision, recall, auc)``. Loss and accuracy are
        weighted by ``num_graphs``. F1/precision/recall are macro-averaged and AUC is
        macro one-vs-rest on one-hot targets. All six are computed over the whole
        loader, not just the last batch.

    Raises:
        ValueError: if the loader yields no graphs.
    """
    model.eval()
    total_loss = 0.
    correct = 0
    ngraphs = 0
    labels, preds, scores = [], [], []
    with torch.no_grad():
        for batch in test_loader:
            if max_features is not None:
                batch.x = batch.x[:, 0:max_features]
            batch = batch.to(device)
            out = model(batch)
            label = batch.y
            loss = model.loss(out, label)
            pred = out.argmax(dim=1)
            total_loss += loss.item() * batch.num_graphs
            correct += pred.eq(label).sum().item()
            ngraphs += batch.num_graphs
            # Keep per-batch results so the sklearn metrics cover the full test set.
            labels.append(label.detach().cpu())
            preds.append(pred.detach().cpu())
            scores.append(out.detach().cpu())
    if ngraphs == 0:
        raise ValueError("test_loader yielded no graphs; cannot compute evaluation metrics")
    y_true = torch.cat(labels).numpy()
    y_pred = torch.cat(preds).numpy()
    y_score = torch.cat(scores).numpy()
    f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    precision = precision_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    recall = recall_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    # One-hot targets and score columns must have the same width; use classes
    # 0..max(label) for both (previously the widths could disagree).
    n_seen = int(y_true.max()) + 1
    auc = roc_auc_score(np.eye(n_seen)[y_true], y_score[:, 0:n_seen])
    return total_loss / ngraphs, correct / ngraphs, f1, precision, recall, auc


def eval_gc_(model, test_loader, device):
    """Return ``(loss, acc, f1, precision, recall, auc)`` of ``model`` over ``test_loader``."""
    return _eval_gc(model, test_loader, device)


def eval_gc_loss(model, test_loader, device):
    """Same metrics as ``eval_gc_``: ``(loss, acc, f1, precision, recall, auc)``."""
    return _eval_gc(model, test_loader, device)


def eval_gc_loss_(model, test_loader, device):
    """Like ``eval_gc_`` but keeps only the first 3 node-feature columns of each batch."""
    return _eval_gc(model, test_loader, device, max_features=3)