# CeFGC/lib/server_.py
import torch
import numpy as np
import random
import networkx as nx
from dtaidistance import dtw
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score


class Server():
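    """Central server for (clustered) federated graph classification.

    Keeps a reference to the global model parameters in `self.W`, performs
    size-weighted aggregation of client parameters/updates, and provides the
    similarity, min-cut, and update-norm utilities used to split clients into
    clusters.
    """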
    def __init__(self, model, dataLoader, device):
        self.model = model.to(device)
        self.W = {key: value for key, value in self.model.named_parameters()}
        self.model_cache = []
        self.dataLoader = dataLoader
        self.device = device


    def randomSample_clients(self, all_clients, frac):
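        """Uniformly sample a fraction `frac` of `all_clients` for the current round."""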
        return random.sample(all_clients, int(len(all_clients) * frac))

    def aggregate_weights(self, selected_clients):
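        """Size-weighted (FedAvg-style) aggregation of all model parameters across the selected clients."""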
        total_size = sum(client.train_size for client in selected_clients)
        for k in self.W.keys():
            # weight each client's parameters by its train_size, then average
            stacked = torch.stack([client.W[k].data * client.train_size for client in selected_clients])
            self.W[k].data = torch.div(torch.sum(stacked, dim=0), total_size).clone()

    def aggregate_weights_per(self, selected_clients):
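        """Size-weighted aggregation restricted to parameters whose names contain 'graph_convs'; the remaining parameters are left untouched."""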
        total_size = sum(client.train_size for client in selected_clients)
        for k in self.W.keys():
            if 'graph_convs' in k:
                # weight each client's parameters by its train_size, then average
                stacked = torch.stack([client.W[k].data * client.train_size for client in selected_clients])
                self.W[k].data = torch.div(torch.sum(stacked, dim=0), total_size).clone()

    def aggregate_weights_se(self, selected_clients):
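        """Size-weighted aggregation restricted to parameters whose names contain '_s' (presumably the structure-related part of the model)."""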
        total_size = sum(client.train_size for client in selected_clients)
        for k in self.W.keys():
            if '_s' in k:
                # weight each client's parameters by its train_size, then average
                stacked = torch.stack([client.W[k].data * client.train_size for client in selected_clients])
                self.W[k].data = torch.div(torch.sum(stacked, dim=0), total_size).clone()

    def aggregate_weights_fe(self, selected_clients):
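        """Size-weighted aggregation restricted to parameters whose names do not contain '_s' (presumably the feature-related part of the model)."""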
        total_size = sum(client.train_size for client in selected_clients)
        for k in self.W.keys():
            if '_s' not in k:
                # weight each client's parameters by its train_size, then average
                stacked = torch.stack([client.W[k].data * client.train_size for client in selected_clients])
                self.W[k].data = torch.div(torch.sum(stacked, dim=0), total_size).clone()


    def compute_pairwise_similarities(self, clients):
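        """Pairwise similarity of the clients' parameter updates `dW` (cosine similarity shifted into [0, 2]; see `pairwise_angles`)."""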
        client_dWs = []
        for client in clients:
            dW = {}
            for k in self.W.keys():
                dW[k] = client.dW[k]
            client_dWs.append(dW)
        return pairwise_angles(client_dWs)

    def compute_pairwise_distances(self, seqs, standardize=False):
        """ computes DTW distances """
        if standardize:
            # standardize each sequence so the distance reflects trends rather than scale
            seqs = np.array(seqs)
            seqs = seqs / seqs.std(axis=1).reshape(-1, 1)
        return dtw.distance_matrix(seqs)

    def min_cut(self, similarity, idc):
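        """Split the clients indexed by `idc` into two groups via a Stoer-Wagner minimum cut on the fully connected similarity graph."""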
        g = nx.Graph()
        for i in range(len(similarity)):
            for j in range(len(similarity)):
                g.add_edge(i, j, weight=similarity[i][j])
        cut, partition = nx.stoer_wagner(g)
        c1 = np.array([idc[x] for x in partition[0]])
        c2 = np.array([idc[x] for x in partition[1]])
        return c1, c2

    def aggregate_clusterwise(self, client_clusters):
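        """Within each cluster, add the size-weighted average of the members' updates to every member's weights."""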
        for cluster in client_clusters:
            targs = []
            sours = []
            total_size = 0
            for client in cluster:
                W = {}
                dW = {}
                for k in self.W.keys():
                    W[k] = client.W[k]
                    dW[k] = client.dW[k]
                targs.append(W)
                sours.append((dW, client.train_size))
                total_size += client.train_size
            # apply the train_size-weighted average of the cluster's updates to every member
            reduce_add_average(targets=targs, sources=sours, total_size=total_size)

    def compute_max_update_norm(self, cluster):
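        """Return the largest L2 norm of the update `dW` among the clients in the cluster."""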
        max_dW = -np.inf
        for client in cluster:
            dW = {}
            for k in self.W.keys():
                dW[k] = client.dW[k]
            update_norm = torch.norm(flatten(dW)).item()
            if update_norm > max_dW:
                max_dW = update_norm
        return max_dW
        # return np.max([torch.norm(flatten(client.dW)).item() for client in cluster])

    def compute_mean_update_norm(self, cluster):
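        """Return the L2 norm of the mean update `dW` over the clients in the cluster."""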
        cluster_dWs = []
        for client in cluster:
            dW = {}
            for k in self.W.keys():
                dW[k] = client.dW[k]
            cluster_dWs.append(flatten(dW))

        return torch.norm(torch.mean(torch.stack(cluster_dWs), dim=0)).item()

    def cache_model(self, idcs, params, accuracies):
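        """Cache a parameter snapshot together with the client indices and their accuracies."""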
        self.model_cache += [(idcs,
                              {name: params[name].data.clone() for name in params},
                              [accuracies[i] for i in idcs])]


    def evaluate_(self):
        return eval_gc_(self.model, self.dataLoader['test'], self.device)

    def evaluate_loss(self):
        return eval_gc_loss(self.model, self.dataLoader['test'], self.device)

    def evaluate_loss_(self):
        return eval_gc_loss_(self.model, self.dataLoader['test'], self.device)


def flatten(source):
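    """Concatenate all tensors of a parameter dict into a single 1-D vector."""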
    return torch.cat([value.flatten() for value in source.values()])

def pairwise_angles(sources):
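    """Pairwise cosine similarities of the flattened parameter updates, shifted by +1 so all values are non-negative."""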
    flat = [flatten(source) for source in sources]
    angles = torch.zeros([len(sources), len(sources)])
    for i, s1 in enumerate(flat):
        for j, s2 in enumerate(flat):
            angles[i, j] = torch.true_divide(torch.sum(s1 * s2), max(torch.norm(s1) * torch.norm(s2), 1e-12)) + 1

    return angles.numpy()

def reduce_add_average(targets, sources, total_size):
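    """Add the size-weighted average of the source updates to every target parameter dict in place."""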
    for target in targets:
        for name in target:
            tmp = torch.div(torch.sum(torch.stack([torch.mul(source[0][name].data, source[1]) for source in sources]), dim=0), total_size).clone()
            target[name].data += tmp


def eval_gc_(model, test_loader, device):
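    """Evaluate graph classification on `test_loader`; returns (loss, accuracy, macro-F1, macro-precision, macro-recall, AUC)."""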
    model.eval()

    total_loss = 0.
    acc_sum = 0.
    ngraphs = 0
    y_true, y_pred, y_score = [], [], []

    for batch in test_loader:
        batch.to(device)
        with torch.no_grad():
            out = model(batch)
            label = batch.y
            loss = model.loss(out, label)
            pred = out.argmax(dim=1)
        total_loss += loss.item() * batch.num_graphs
        acc_sum += pred.eq(label).sum().item()
        ngraphs += batch.num_graphs
        y_true.append(label.detach().cpu().numpy())
        y_pred.append(pred.detach().cpu().numpy())
        y_score.append(out.detach().cpu().numpy())

    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    y_score = np.concatenate(y_score)

    # macro-averaged metrics over the whole test set rather than the last batch only
    f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    precision = precision_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    recall = recall_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    # one-vs-rest AUC; assumes every class up to the largest observed label has a score column
    auc = roc_auc_score(np.eye(y_score.shape[1])[y_true], y_score[:, :y_true.max() + 1])

    return total_loss / ngraphs, acc_sum / ngraphs, f1, precision, recall, auc

def eval_gc_loss(model, test_loader, device):
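    """Evaluate graph classification on `test_loader` (same metrics as `eval_gc_`)."""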
    model.eval()

    total_loss = 0.
    acc_sum = 0.
    ngraphs = 0
    y_true, y_pred, y_score = [], [], []

    for batch in test_loader:
        batch.to(device)
        with torch.no_grad():
            out = model(batch)
            label = batch.y
            loss = model.loss(out, label)
            pred = out.argmax(dim=1)
        total_loss += loss.item() * batch.num_graphs
        acc_sum += pred.eq(label).sum().item()
        ngraphs += batch.num_graphs
        y_true.append(label.detach().cpu().numpy())
        y_pred.append(pred.detach().cpu().numpy())
        y_score.append(out.detach().cpu().numpy())

    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    y_score = np.concatenate(y_score)

    # macro-averaged metrics over the whole test set rather than the last batch only
    f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    precision = precision_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    recall = recall_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    # one-vs-rest AUC; assumes every class up to the largest observed label has a score column
    auc = roc_auc_score(np.eye(y_score.shape[1])[y_true], y_score[:, :y_true.max() + 1])

    return total_loss / ngraphs, acc_sum / ngraphs, f1, precision, recall, auc


def eval_gc_loss_(model, test_loader, device):
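    """Variant of `eval_gc_loss` that keeps only the first three node-feature columns before the forward pass."""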
    model.eval()

    total_loss = 0.
    acc_sum = 0.
    ngraphs = 0
    y_true, y_pred, y_score = [], [], []

    for batch in test_loader:
        # keep only the first three node-feature columns
        batch.x = batch.x[:, 0:3]
        batch.to(device)
        with torch.no_grad():
            out = model(batch)
            label = batch.y
            loss = model.loss(out, label)
            pred = out.argmax(dim=1)
        total_loss += loss.item() * batch.num_graphs
        acc_sum += pred.eq(label).sum().item()
        ngraphs += batch.num_graphs
        y_true.append(label.detach().cpu().numpy())
        y_pred.append(pred.detach().cpu().numpy())
        y_score.append(out.detach().cpu().numpy())

    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    y_score = np.concatenate(y_score)

    # macro-averaged metrics over the whole test set rather than the last batch only
    f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    precision = precision_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    recall = recall_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    # one-vs-rest AUC; assumes every class up to the largest observed label has a score column
    auc = roc_auc_score(np.eye(y_score.shape[1])[y_true], y_score[:, :y_true.max() + 1])

    return total_loss / ngraphs, acc_sum / ngraphs, f1, precision, recall, auc
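

if __name__ == '__main__':
    # Minimal usage sketch, not part of the training pipeline: it assumes a
    # hypothetical _DummyClient exposing the attributes the Server relies on
    # (W, dW, train_size); the real client class lives elsewhere in the repo.
    import torch.nn as nn

    class _DummyClient:
        def __init__(self, model, train_size):
            # parameter dict and a random "update" with matching shapes
            self.W = {k: v for k, v in model.named_parameters()}
            self.dW = {k: torch.randn_like(v) for k, v in self.W.items()}
            self.train_size = train_size

    server = Server(nn.Linear(4, 2), dataLoader=None, device='cpu')
    clients = [_DummyClient(nn.Linear(4, 2), train_size=10 * (i + 1)) for i in range(4)]

    # similarity of client updates -> min-cut split into two clusters
    similarities = server.compute_pairwise_similarities(clients)
    c1, c2 = server.min_cut(similarities, list(range(len(clients))))
    print('cluster 1:', c1, 'cluster 2:', c2)

    # cluster-wise update aggregation, then plain size-weighted aggregation
    server.aggregate_clusterwise([[clients[i] for i in c1], [clients[i] for i in c2]])
    server.aggregate_weights(clients)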