# CeFGC / lib / training_.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time

def run_selftrain_GC(args, clients, server, local_epoch):
    """Train each client purely locally, starting from the server's weights.

    No aggregation happens: every client downloads the server model once,
    trains for ``local_epoch`` epochs and is evaluated on its own.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``train_stats``,
            ``download_from_server``, ``local_train`` and ``evaluate``.
        server: object the clients download their initial weights from.
        local_epoch: number of local training epochs per client.

    Returns:
        dict mapping client name to
        ``[last training acc, last validation acc, test acc]``.
    """
    # Identical starting point for every client.
    for c in clients:
        c.download_from_server(args, server)

    results = {}
    for c in clients:
        c.local_train(local_epoch)
        _, test_acc = c.evaluate()
        stats = c.train_stats
        results[c.name] = [stats['trainingAccs'][-1], stats['valAccs'][-1], test_acc]
        print("  > {} done.".format(c.name))
    return results


def run_ours(args, clients, server, local_epoch):
    """Run a single federated round and report metrics before and after aggregation.

    Every client downloads the server weights, trains locally and is
    evaluated. The server then aggregates all clients, they download the
    aggregated weights, and every client is evaluated again.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``train_stats``,
            ``download_from_server``, ``local_train`` and ``evaluate``.
        server: object providing ``aggregate_weights``.
        local_epoch: number of local training epochs per client.

    Returns:
        ``(local_results, global_results)``: two dicts mapping client name to
        ``[last training acc, last validation acc, test acc, test loss]``.
        The first is measured after local training and the second after
        aggregation. The training/validation entries come from the local
        training run in both cases.
    """
    def snapshot(c):
        # evaluate() is unpacked as (loss, acc).
        test_loss, test_acc = c.evaluate()
        return [c.train_stats['trainingAccs'][-1], c.train_stats['valAccs'][-1], test_acc, test_loss]

    # Identical starting point for every client.
    for c in clients:
        c.download_from_server(args, server)

    local_results = {}
    for c in clients:
        c.local_train(local_epoch)
        local_results[c.name] = snapshot(c)
        print("  > {} done.".format(c.name))

    server.aggregate_weights(clients)
    for c in clients:
        c.download_from_server(args, server)

    global_results = {c.name: snapshot(c) for c in clients}
    return local_results, global_results


def run_ours2(args, clients, server, local_epoch):
    """Run a single federated round, reporting the extended metrics from ``evaluate_``.

    Every client downloads the server weights, trains locally and is
    evaluated. The server then aggregates all clients, they download the
    aggregated weights, and every client is evaluated again. The server
    model is also evaluated via ``server.evaluate_()``.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``train_stats``,
            ``download_from_server``, ``local_train`` and ``evaluate_``.
        server: object providing ``aggregate_weights`` and ``evaluate_``.
        local_epoch: number of local training epochs per client.

    Returns:
        ``(allAccs_local, allAccs_global)``: dicts mapping a name to
        ``[train acc, val acc, total_loss, acc_sum, f1, precision, recall, auc]``.
        ``allAccs_global`` also has a ``'global'`` entry for the server model.
        The server has no training/validation history, so that entry holds
        NaN in the first two slots.
    """
    def metrics_row(client):
        # evaluate_() is unpacked as (total_loss, acc_sum, f1, precision, recall, auc).
        total_loss, acc_sum, f1, precision, recall, auc = client.evaluate_()
        return [client.train_stats['trainingAccs'][-1], client.train_stats['valAccs'][-1],
                total_loss, acc_sum, f1, precision, recall, auc]

    # all clients are initialized with the same weights
    for client in clients:
        client.download_from_server(args, server)

    allAccs_local = {}
    for client in clients:
        client.local_train(local_epoch)
        allAccs_local[client.name] = metrics_row(client)
        print("  > {} done.".format(client.name))

    server.aggregate_weights(clients)
    for client in clients:
        client.download_from_server(args, server)

    allAccs_global = {client.name: metrics_row(client) for client in clients}

    total_loss, acc_sum, f1, precision, recall, auc = server.evaluate_()
    # Previously the loss was copied into the train/val accuracy slots, so it
    # posed as an accuracy; NaN marks those columns as not applicable.
    allAccs_global['global'] = [np.nan, np.nan, total_loss, acc_sum, f1, precision, recall, auc]

    return allAccs_local, allAccs_global


def run_ours_time(args, clients, server, local_epoch):
    """Time one download / local-train / aggregate / download cycle.

    Every client downloads the server weights and runs ``local_train_time``.
    The server then aggregates all clients and they download the result. No
    evaluation is performed.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``download_from_server``
            and ``local_train_time``.
        server: object providing ``aggregate_weights``.
        local_epoch: number of local training epochs per client.

    Returns:
        float: elapsed seconds for the whole cycle, which is also printed as
        ``time_dur``. (The previous version printed it but returned ``None``.)
    """
    # perf_counter is monotonic and high-resolution. time.time() is wall-clock
    # time and can jump (e.g. NTP adjustments), which corrupts interval measurements.
    start_t = time.perf_counter()

    # all clients are initialized with the same weights
    for client in clients:
        client.download_from_server(args, server)

    for client in clients:
        client.local_train_time(local_epoch)
        print("  > {} done.".format(client.name))

    server.aggregate_weights(clients)
    for client in clients:
        client.download_from_server(args, server)

    time_dur = time.perf_counter() - start_t
    print('time_dur', time_dur)
    return time_dur


def run_ours_loss(args, clients, server, local_epoch):
    """Run a single federated round and return test metrics plus loss histories.

    Every client downloads the server weights, trains locally and is
    evaluated. The server then aggregates all clients, they download the
    aggregated weights, and every client is evaluated again with
    ``evaluate()`` and then ``evaluate_()``.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``train_stats``,
            ``download_from_server``, ``local_train``, ``evaluate`` and
            ``evaluate_``.
        server: object providing ``aggregate_weights``.
        local_epoch: number of local training epochs per client.

    Returns:
        ``(frame, allLosses, losses)``:
        - frame: DataFrame indexed by client name with columns ``test_acc``,
          ``test_loss``, ``f1``, ``precision``, ``recall``, ``auc``, taken
          from ``evaluate_()`` after aggregation;
        - allLosses: client name -> ``[last training loss, last validation
          loss, last test loss]`` from ``train_stats``;
        - losses: list with each client's ``train_stats['trainingLosses']``
          object, in client order.
    """
    for c in clients:
        c.download_from_server(args, server)

    local_accs = {}
    for c in clients:
        c.local_train(local_epoch)
        test_loss, test_acc = c.evaluate()
        stats = c.train_stats
        local_accs[c.name] = [stats['trainingAccs'][-1], stats['valAccs'][-1], test_acc, test_loss]
        print("  > {} done.".format(c.name))

    server.aggregate_weights(clients)
    for c in clients:
        c.download_from_server(args, server)

    # The accuracy snapshots below are not returned. The evaluate() pass is
    # still run once per client after aggregation.
    global_accs = {}
    last_losses = {}
    loss_histories = []
    for c in clients:
        test_loss, test_acc = c.evaluate()
        stats = c.train_stats
        global_accs[c.name] = [stats['trainingAccs'][-1], stats['valAccs'][-1], test_acc, test_loss]
        last_losses[c.name] = [stats['trainingLosses'][-1], stats['valLosses'][-1], stats['testLosses'][-1]]
        loss_histories.append(stats['trainingLosses'])

    columns = ('test_acc', 'test_loss', 'f1', 'precision', 'recall', 'auc')
    frame = pd.DataFrame()
    for c in clients:
        loss, acc, f1, precision, recall, auc = c.evaluate_()
        for col, val in zip(columns, (acc, loss, f1, precision, recall, auc)):
            frame.loc[c.name, col] = val

    return frame, last_losses, loss_histories



def run_fedavg(args, clients, server, COMMUNICATION_ROUNDS, local_epoch, samp=None, frac=1.0, summary_writer=None):
    """Run multi-round FedAvg training and return each client's final test accuracy.

    Each round, the selected clients train locally. The server aggregates
    them with ``aggregate_weights`` and the selected clients download the
    result. Round 1 always uses every client.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``download_from_server``,
            ``local_train`` and ``evaluate`` (unpacked as ``(loss, acc)``).
        server: object providing ``aggregate_weights`` and
            ``randomSample_clients(clients, frac)``.
        COMMUNICATION_ROUNDS: number of train/aggregate rounds.
        local_epoch: number of local training epochs per round.
        samp: client-sampling strategy for rounds >= 2. The options are:
            ``None`` (default) uses ``server.randomSample_clients`` with
            ``frac`` forced to 1.0; a callable is invoked as
            ``samp(clients, frac)``; any other value uses
            ``server.randomSample_clients`` with the given ``frac``.
        frac: sampling fraction passed to the sampler (ignored when
            ``samp`` is ``None``).
        summary_writer: optional writer with ``add_scalar(tag, value, step)``.
            Every 5 rounds, the test accuracy/loss of every client is logged
            to it. Logging is skipped when it is ``None``.

    Returns:
        pandas.DataFrame indexed by client name with a ``test_acc`` column.
    """
    for client in clients:
        client.download_from_server(args, server)

    # Resolve the sampler up front. Previously a non-None ``samp`` left
    # ``sampling_fn`` unbound, which raised NameError in round 2.
    if samp is None:
        sampling_fn, frac = server.randomSample_clients, 1.0
    elif callable(samp):
        sampling_fn = samp
    else:
        sampling_fn = server.randomSample_clients

    for c_round in range(1, COMMUNICATION_ROUNDS + 1):
        if c_round % 50 == 0:
            print(f"  > round {c_round}")

        selected_clients = clients if c_round == 1 else sampling_fn(clients, frac)

        for client in selected_clients:
            client.local_train(local_epoch)

        server.aggregate_weights(selected_clients)
        for client in selected_clients:
            client.download_from_server(args, server)

        # write to log files (only when a writer was supplied)
        if summary_writer is not None and c_round % 5 == 0:
            for idx, client in enumerate(clients, start=1):
                loss, acc = client.evaluate()
                summary_writer.add_scalar(f'Test/Acc/user{idx}', acc, c_round)
                summary_writer.add_scalar(f'Test/Loss/user{idx}', loss, c_round)

    frame = pd.DataFrame()
    for client in clients:
        loss, acc = client.evaluate()
        frame.loc[client.name, 'test_acc'] = acc

    # ``frame.style.apply(...).data`` only returned ``frame`` unchanged and
    # needlessly required Jinja2, so print the frame directly.
    print(frame)
    return frame


def run_fedstar(args, clients, server, COMMUNICATION_ROUNDS, local_epoch, samp=None, frac=1.0, summary_writer=None):
    """Run multi-round FedStar training and return per-client plus global test metrics.

    Each round, the selected clients train locally. The server aggregates
    them with ``aggregate_weights_se`` and the selected clients download the
    result. Round 1 always uses every client.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``download_from_server``,
            ``local_train`` and ``evaluate_`` (unpacked as
            ``(loss, acc, f1, precision, recall, auc)``).
        server: object providing ``aggregate_weights_se``,
            ``randomSample_clients(clients, frac)`` and ``evaluate_``.
        COMMUNICATION_ROUNDS: number of train/aggregate rounds.
        local_epoch: number of local training epochs per round.
        samp: client-sampling strategy for rounds >= 2. The options are:
            ``None`` (default) uses ``server.randomSample_clients`` with
            ``frac`` forced to 1.0; a callable is invoked as
            ``samp(clients, frac)``; any other value uses
            ``server.randomSample_clients`` with the given ``frac``.
        frac: sampling fraction passed to the sampler (ignored when
            ``samp`` is ``None``).
        summary_writer: optional writer with ``add_scalar(tag, value, step)``.
            Every 5 rounds, all six test metrics of every client are logged
            to it. Logging is skipped when it is ``None``.

    Returns:
        pandas.DataFrame indexed by client name plus a final ``'global'`` row
        (from ``server.evaluate_()``), with columns ``test_acc``,
        ``test_loss``, ``f1``, ``precision``, ``recall``, ``auc``.
    """
    metric_cols = ['test_acc', 'test_loss', 'f1', 'precision', 'recall', 'auc']

    for client in clients:
        client.download_from_server(args, server)

    # Resolve the sampler up front. Previously a non-None ``samp`` left
    # ``sampling_fn`` unbound, which raised NameError in round 2.
    if samp is None:
        sampling_fn, frac = server.randomSample_clients, 1.0
    elif callable(samp):
        sampling_fn = samp
    else:
        sampling_fn = server.randomSample_clients

    for c_round in range(1, COMMUNICATION_ROUNDS + 1):
        if c_round % 50 == 0:
            print(f"  > round {c_round}")

        selected_clients = clients if c_round == 1 else sampling_fn(clients, frac)

        # A leftover debug timer + exit() used to terminate the process here,
        # before the first aggregation ever ran.
        for client in selected_clients:
            client.local_train(local_epoch)

        server.aggregate_weights_se(selected_clients)
        for client in selected_clients:
            client.download_from_server(args, server)

        # write to log files (only when a writer was supplied)
        if summary_writer is not None and c_round % 5 == 0:
            for idx, client in enumerate(clients, start=1):
                loss, acc, f1, precision, recall, auc = client.evaluate_()
                for tag, value in (('Acc', acc), ('Loss', loss), ('f1', f1),
                                   ('precision', precision), ('recall', recall), ('auc', auc)):
                    summary_writer.add_scalar(f'Test/{tag}/user{idx}', value, c_round)

    frame = pd.DataFrame()
    for client in clients:
        loss, acc, f1, precision, recall, auc = client.evaluate_()
        for col, value in zip(metric_cols, (acc, loss, f1, precision, recall, auc)):
            frame.loc[client.name, col] = value

    # ``frame.style.apply(...).data`` only returned ``frame`` unchanged and
    # needlessly required Jinja2, so print the frame directly.
    print(frame)

    loss, acc, f1, precision, recall, auc = server.evaluate_()
    frame_glob = pd.DataFrame(np.array([[acc, loss, f1, precision, recall, auc]]),
                              index=['global'], columns=metric_cols)

    frame = pd.concat([frame, frame_glob], axis=0)
    print('***', frame)
    return frame

def run_fedstar_time(args, clients, server, COMMUNICATION_ROUNDS, local_epoch, samp=None, frac=1.0, summary_writer=None):
    """Run multi-round FedStar training and return per-client test metrics.

    This is the FedStar loop without the final server-model (``'global'``)
    evaluation. Despite the name, no timing is performed in this function.
    Each round, the selected clients train locally. The server aggregates
    them with ``aggregate_weights_se`` and the selected clients download the
    result. Round 1 always uses every client.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``download_from_server``,
            ``local_train`` and ``evaluate_`` (unpacked as
            ``(loss, acc, f1, precision, recall, auc)``).
        server: object providing ``aggregate_weights_se`` and
            ``randomSample_clients(clients, frac)``.
        COMMUNICATION_ROUNDS: number of train/aggregate rounds.
        local_epoch: number of local training epochs per round.
        samp: client-sampling strategy for rounds >= 2. The options are:
            ``None`` (default) uses ``server.randomSample_clients`` with
            ``frac`` forced to 1.0; a callable is invoked as
            ``samp(clients, frac)``; any other value uses
            ``server.randomSample_clients`` with the given ``frac``.
        frac: sampling fraction passed to the sampler (ignored when
            ``samp`` is ``None``).
        summary_writer: optional writer with ``add_scalar(tag, value, step)``.
            Every 5 rounds, all six test metrics of every client are logged
            to it. Logging is skipped when it is ``None``.

    Returns:
        pandas.DataFrame indexed by client name with columns ``test_acc``,
        ``test_loss``, ``f1``, ``precision``, ``recall``, ``auc``.
    """
    metric_cols = ['test_acc', 'test_loss', 'f1', 'precision', 'recall', 'auc']

    for client in clients:
        client.download_from_server(args, server)

    # Resolve the sampler up front. Previously a non-None ``samp`` left
    # ``sampling_fn`` unbound, which raised NameError in round 2.
    if samp is None:
        sampling_fn, frac = server.randomSample_clients, 1.0
    elif callable(samp):
        sampling_fn = samp
    else:
        sampling_fn = server.randomSample_clients

    for c_round in range(1, COMMUNICATION_ROUNDS + 1):
        if c_round % 50 == 0:
            print(f"  > round {c_round}")

        selected_clients = clients if c_round == 1 else sampling_fn(clients, frac)

        for client in selected_clients:
            client.local_train(local_epoch)

        server.aggregate_weights_se(selected_clients)
        for client in selected_clients:
            client.download_from_server(args, server)

        # write to log files (only when a writer was supplied)
        if summary_writer is not None and c_round % 5 == 0:
            for idx, client in enumerate(clients, start=1):
                loss, acc, f1, precision, recall, auc = client.evaluate_()
                for tag, value in (('Acc', acc), ('Loss', loss), ('f1', f1),
                                   ('precision', precision), ('recall', recall), ('auc', auc)):
                    summary_writer.add_scalar(f'Test/{tag}/user{idx}', value, c_round)

    frame = pd.DataFrame()
    for client in clients:
        loss, acc, f1, precision, recall, auc = client.evaluate_()
        for col, value in zip(metric_cols, (acc, loss, f1, precision, recall, auc)):
            frame.loc[client.name, col] = value

    # ``frame.style.apply(...).data`` only returned ``frame`` unchanged and
    # needlessly required Jinja2, so print the frame directly.
    print(frame)
    return frame


def run_fedstar_(args, clients, server, COMMUNICATION_ROUNDS, local_epoch, samp=None, frac=1.0, summary_writer=None):
    """Run multi-round FedStar training; return client and global metrics with ``test_``-prefixed columns.

    Each round, the selected clients train locally. The server aggregates
    them with ``aggregate_weights_se`` and the selected clients download the
    result. Round 1 always uses every client.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``download_from_server``,
            ``local_train`` and ``evaluate_`` (unpacked as
            ``(loss, acc, f1, precision, recall, auc)``).
        server: object providing ``aggregate_weights_se``,
            ``randomSample_clients(clients, frac)`` and ``evaluate_``.
        COMMUNICATION_ROUNDS: number of train/aggregate rounds.
        local_epoch: number of local training epochs per round.
        samp: client-sampling strategy for rounds >= 2. The options are:
            ``None`` (default) uses ``server.randomSample_clients`` with
            ``frac`` forced to 1.0; a callable is invoked as
            ``samp(clients, frac)``; any other value uses
            ``server.randomSample_clients`` with the given ``frac``.
        frac: sampling fraction passed to the sampler (ignored when
            ``samp`` is ``None``).
        summary_writer: optional writer with ``add_scalar(tag, value, step)``.
            Every 5 rounds, all six test metrics of every client are logged
            to it. Logging is skipped when it is ``None``.

    Returns:
        pandas.DataFrame indexed by client name plus a final ``'global'`` row
        (from ``server.evaluate_()``), with columns ``test_acc``,
        ``test_loss``, ``test_f1``, ``test_precision``, ``test_recall``,
        ``test_auc``.
    """
    metric_cols = ['test_acc', 'test_loss', 'f1', 'precision', 'recall', 'auc']

    for client in clients:
        client.download_from_server(args, server)

    # Resolve the sampler up front. Previously a non-None ``samp`` left
    # ``sampling_fn`` unbound, which raised NameError in round 2.
    if samp is None:
        sampling_fn, frac = server.randomSample_clients, 1.0
    elif callable(samp):
        sampling_fn = samp
    else:
        sampling_fn = server.randomSample_clients

    for c_round in range(1, COMMUNICATION_ROUNDS + 1):
        if c_round % 50 == 0:
            print(f"  > round {c_round}")

        selected_clients = clients if c_round == 1 else sampling_fn(clients, frac)

        # A leftover debug timer + exit() used to terminate the process here,
        # before the first aggregation ever ran.
        for client in selected_clients:
            client.local_train(local_epoch)

        server.aggregate_weights_se(selected_clients)
        for client in selected_clients:
            client.download_from_server(args, server)

        # write to log files (only when a writer was supplied)
        if summary_writer is not None and c_round % 5 == 0:
            for idx, client in enumerate(clients, start=1):
                loss, acc, f1, precision, recall, auc = client.evaluate_()
                for tag, value in (('Acc', acc), ('Loss', loss), ('f1', f1),
                                   ('precision', precision), ('recall', recall), ('auc', auc)):
                    summary_writer.add_scalar(f'Test/{tag}/user{idx}', value, c_round)

    frame = pd.DataFrame()
    for client in clients:
        loss, acc, f1, precision, recall, auc = client.evaluate_()
        for col, value in zip(metric_cols, (acc, loss, f1, precision, recall, auc)):
            frame.loc[client.name, col] = value

    loss, acc, f1, precision, recall, auc = server.evaluate_()
    frame_glob = pd.DataFrame(np.array([[acc, loss, f1, precision, recall, auc]]),
                              index=['global'], columns=metric_cols)

    print(frame_glob)
    frame = pd.concat([frame, frame_glob], axis=0)
    print('***', frame)

    frame.columns = ['test_acc', 'test_loss', 'test_f1', 'test_precision', 'test_recall', 'test_auc']
    print('@@@@@', frame)
    return frame


def run_fedstar_loss(args, clients, server, COMMUNICATION_ROUNDS, local_epoch, samp=None, frac=1.0, summary_writer=None):
    """Run multi-round FedStar training and also collect per-round loss histories.

    Each round, the selected clients train locally. The server aggregates
    them with ``aggregate_weights_se`` and the selected clients download the
    result. Then every client (not only the selected ones) is evaluated, and
    its latest losses are recorded. Round 1 always uses every client.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``train_stats``,
            ``download_from_server``, ``local_train`` and ``evaluate_``
            (unpacked as ``(loss, acc, f1, precision, recall, auc)``).
        server: object providing ``aggregate_weights_se`` and
            ``randomSample_clients(clients, frac)``.
        COMMUNICATION_ROUNDS: number of train/aggregate rounds.
        local_epoch: number of local training epochs per round.
        samp: client-sampling strategy for rounds >= 2. The options are:
            ``None`` (default) uses ``server.randomSample_clients`` with
            ``frac`` forced to 1.0; a callable is invoked as
            ``samp(clients, frac)``; any other value uses
            ``server.randomSample_clients`` with the given ``frac``.
        frac: sampling fraction passed to the sampler (ignored when
            ``samp`` is ``None``).
        summary_writer: optional writer with ``add_scalar(tag, value, step)``.
            Every round, all six test metrics of every client are logged to
            it. Logging is skipped when it is ``None``.

    Returns:
        ``(frame, allLosses, losses)``:
        - frame: DataFrame indexed by client name with columns ``test_acc``,
          ``test_loss``, ``f1``, ``precision``, ``recall``, ``auc``;
        - allLosses: client name -> ``[last training loss, last validation
          loss, last test loss]`` as of the final round (empty if
          ``COMMUNICATION_ROUNDS`` is 0);
        - losses: one entry per (round, client) pair, holding that client's
          ``train_stats['trainingLosses']`` object at the time. It is not
          copied, so any later in-place mutation would be visible here.
    """
    metric_cols = ['test_acc', 'test_loss', 'f1', 'precision', 'recall', 'auc']

    for client in clients:
        client.download_from_server(args, server)

    # Resolve the sampler up front. Previously a non-None ``samp`` left
    # ``sampling_fn`` unbound, which raised NameError in round 2.
    if samp is None:
        sampling_fn, frac = server.randomSample_clients, 1.0
    elif callable(samp):
        sampling_fn = samp
    else:
        sampling_fn = server.randomSample_clients

    losses = []
    # Bound before the loop so COMMUNICATION_ROUNDS == 0 no longer raises
    # UnboundLocalError. Each round overwrites every client's entry, so the
    # result still reflects the last round only.
    allLosses = {}
    for c_round in range(1, COMMUNICATION_ROUNDS + 1):
        if c_round % 50 == 0:
            print(f"  > round {c_round}")

        selected_clients = clients if c_round == 1 else sampling_fn(clients, frac)

        for client in selected_clients:
            client.local_train(local_epoch)

        server.aggregate_weights_se(selected_clients)
        for client in selected_clients:
            client.download_from_server(args, server)

        for idx, client in enumerate(clients, start=1):
            loss, acc, f1, precision, recall, auc = client.evaluate_()
            if summary_writer is not None:
                for tag, value in (('Acc', acc), ('Loss', loss), ('f1', f1),
                                   ('precision', precision), ('recall', recall), ('auc', auc)):
                    summary_writer.add_scalar(f'Test/{tag}/user{idx}', value, c_round)

            stats = client.train_stats
            allLosses[client.name] = [stats['trainingLosses'][-1], stats['valLosses'][-1],
                                      stats['testLosses'][-1]]
            losses.append(stats['trainingLosses'])

    frame = pd.DataFrame()
    for client in clients:
        loss, acc, f1, precision, recall, auc = client.evaluate_()
        for col, value in zip(metric_cols, (acc, loss, f1, precision, recall, auc)):
            frame.loc[client.name, col] = value

    # ``frame.style.apply(...).data`` only returned ``frame`` unchanged and
    # needlessly required Jinja2, so print the frame directly.
    print(frame)
    print('@@@@@', frame)
    return frame, allLosses, losses



def run_fedstar_loss_(args, clients, server, COMMUNICATION_ROUNDS, local_epoch, samp=None, frac=1.0, summary_writer=None):
    """Multi-round FedStar loop using the ``*_loss`` client variants; also collects loss histories.

    This is the same flow as ``run_fedstar_loss``, except that clients train
    with ``local_train_loss`` and are evaluated with ``evaluate_loss_``.
    Each round, the selected clients train locally. The server aggregates
    them with ``aggregate_weights_se`` and the selected clients download the
    result. Then every client is evaluated and its latest losses are
    recorded. Round 1 always uses every client.

    Args:
        args: configuration object forwarded to ``download_from_server``.
        clients: client objects exposing ``name``, ``train_stats``,
            ``download_from_server``, ``local_train_loss`` and
            ``evaluate_loss_`` (unpacked as
            ``(loss, acc, f1, precision, recall, auc)``).
        server: object providing ``aggregate_weights_se`` and
            ``randomSample_clients(clients, frac)``.
        COMMUNICATION_ROUNDS: number of train/aggregate rounds.
        local_epoch: number of local training epochs per round.
        samp: client-sampling strategy for rounds >= 2. The options are:
            ``None`` (default) uses ``server.randomSample_clients`` with
            ``frac`` forced to 1.0; a callable is invoked as
            ``samp(clients, frac)``; any other value uses
            ``server.randomSample_clients`` with the given ``frac``.
        frac: sampling fraction passed to the sampler (ignored when
            ``samp`` is ``None``).
        summary_writer: optional writer with ``add_scalar(tag, value, step)``.
            Every round, all six test metrics of every client are logged to
            it. Logging is skipped when it is ``None``.

    Returns:
        ``(frame, allLosses, losses)``:
        - frame: DataFrame indexed by client name with columns ``test_acc``,
          ``test_loss``, ``f1``, ``precision``, ``recall``, ``auc``;
        - allLosses: client name -> ``[last training loss, last validation
          loss, last test loss]`` as of the final round (empty if
          ``COMMUNICATION_ROUNDS`` is 0);
        - losses: one entry per (round, client) pair, holding that client's
          ``train_stats['trainingLosses']`` object at the time (not copied).
    """
    metric_cols = ['test_acc', 'test_loss', 'f1', 'precision', 'recall', 'auc']

    for client in clients:
        client.download_from_server(args, server)

    # Resolve the sampler up front. Previously a non-None ``samp`` left
    # ``sampling_fn`` unbound, which raised NameError in round 2.
    if samp is None:
        sampling_fn, frac = server.randomSample_clients, 1.0
    elif callable(samp):
        sampling_fn = samp
    else:
        sampling_fn = server.randomSample_clients

    losses = []
    # Bound before the loop so COMMUNICATION_ROUNDS == 0 no longer raises
    # UnboundLocalError. Each round overwrites every client's entry.
    allLosses = {}
    for c_round in range(1, COMMUNICATION_ROUNDS + 1):
        if c_round % 50 == 0:
            print(f"  > round {c_round}")

        selected_clients = clients if c_round == 1 else sampling_fn(clients, frac)

        for client in selected_clients:
            client.local_train_loss(local_epoch)

        server.aggregate_weights_se(selected_clients)
        for client in selected_clients:
            client.download_from_server(args, server)

        for idx, client in enumerate(clients, start=1):
            loss, acc, f1, precision, recall, auc = client.evaluate_loss_()
            if summary_writer is not None:
                for tag, value in (('Acc', acc), ('Loss', loss), ('f1', f1),
                                   ('precision', precision), ('recall', recall), ('auc', auc)):
                    summary_writer.add_scalar(f'Test/{tag}/user{idx}', value, c_round)

            stats = client.train_stats
            allLosses[client.name] = [stats['trainingLosses'][-1], stats['valLosses'][-1],
                                      stats['testLosses'][-1]]
            losses.append(stats['trainingLosses'])

    frame = pd.DataFrame()
    for client in clients:
        loss, acc, f1, precision, recall, auc = client.evaluate_loss_()
        for col, value in zip(metric_cols, (acc, loss, f1, precision, recall, auc)):
            frame.loc[client.name, col] = value

    # ``frame.style.apply(...).data`` only returned ``frame`` unchanged and
    # needlessly required Jinja2, so print the frame directly.
    print(frame)
    print('@@@@@', frame)
    return frame, allLosses, losses