# auto-fl-fit / pycaret-fl / client.py
import argparse
import warnings

# from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment

import numpy as np

import flwr as fl
import utils

from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
)

from typing import List


if __name__ == "__main__":
    # ---- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--node-id",
        type=int,
        required=True,
        help="Specifies the artificial data partition",
    )
    parser.add_argument(
        "--num-clients",
        type=int,
        required=True,
        help="Specifies the number of clients",
    )
    # NOTE(review): this flag previously had both required=True and a
    # default, which made the default dead code. It is now optional so the
    # default of "classification" actually applies; explicit invocations
    # are unaffected.
    parser.add_argument(
        "--problem-type",
        type=str,
        choices=["classification", "regression"],
        help="Specifies the problem type",
        default="classification",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        required=True,
        help="Specifies the path to the data",
    )
    parser.add_argument(
        "--target",
        required=False,
        help="Specifies the target column",
        # -1 presumably means "use the last column"; resolved below by
        # utils.get_target_column_name — TODO confirm against utils.
        default=-1,
    )
    args = parser.parse_args()
    partition_id = args.node_id
    problem_type = args.problem_type
    N_CLIENTS = args.num_clients
    target = args.target

    # ---- Data loading and PyCaret experiment setup ------------------------
    df = utils.load_data(args.data_path, target_column=target)

    # Partitioning is currently disabled (kept for reference):
    # df = utils.stratified_partition_with_all_values(
    #     df=df, n_partitions=N_CLIENTS, partition_id=partition_id, target=target)
    target = utils.get_target_column_name(df, target)

    # Pick the experiment type matching the task.
    if (problem_type == 'classification'):
        exp = ClassificationExperiment()
    else:
        exp = RegressionExperiment()
    exp.setup(data=df, target=target, **utils.setup_params)
    if (problem_type == 'classification'):
        # Log loss on predicted probabilities; lower is better.
        exp.add_metric('logloss', 'Log Loss', log_loss,
                       greater_is_better=False, target="pred_proba")

    # Cache the train/test splits produced by exp.setup().
    X_train = exp.get_config("X_train")
    y_train = exp.get_config("y_train")
    X_test = exp.get_config("X_test")
    y_test = exp.get_config("y_test")

    # Define Flower client
    class FlowerClient(fl.client.Client):
        """Flower client that trains/evaluates a lazily-created PyCaret model.

        The concrete model is instantiated on the first fit() call, using the
        model name supplied by the server in the fit config.
        """

        def __init__(self, cid):
            self.cid = cid
            # Created lazily in fit(); None until the first round.
            self.model = None

        def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
            """Return the current local model weights, serialized."""
            print(f"[Client {self.cid}] get_parameters")

            # Extract weights as ndarrays, then serialize for transport.
            weights: List[np.ndarray] = utils.get_model_parameters(self.model)
            return GetParametersRes(
                status=Status(code=Code.OK, message="Success"),
                parameters=utils.ndarrays_to_custom_parameters(weights),
            )

        def fit(self, ins: FitIns) -> FitRes:
            """Load the global weights, train locally, return the update."""
            print(f"[Client {self.cid}] fit, config: {ins.config}")

            # Deserialize the server-supplied global parameters.
            incoming = utils.custom_parameters_to_ndarrays(ins.parameters)

            # First round: build the model the server asked for, untrained.
            if (self.model is None):
                print(f"Creating model with {len(X_train)} examples")
                self.model = exp.create_model(
                    ins.config['model_name'], cross_validation=False, train_model=False)
            utils.set_model_params(self.model, incoming)

            # Silence convergence warnings caused by few local epochs.
            with warnings.catch_warnings():
                print(f"Training model with {len(X_train)} examples")
                warnings.simplefilter("ignore")
                self.model = exp.create_model(
                    self.model, cross_validation=False)
                results = utils.get_metrics(exp.pull(), problem_type=problem_type)

            # Serialize the freshly trained weights for the server.
            updated = utils.ndarrays_to_custom_parameters(
                utils.get_model_parameters(self.model))

            return FitRes(
                status=Status(code=Code.OK, message="Success"),
                parameters=updated,
                num_examples=len(X_train),
                metrics=results,
            )

        def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
            """Score the global weights on the local hold-out split."""
            print(f"[Client {self.cid}] evaluate, config: {ins.config}")

            # Deserialize and install the global parameters.
            incoming = utils.custom_parameters_to_ndarrays(ins.parameters)
            utils.set_model_params(self.model, incoming)

            # Predict on the hold-out split; metrics are pulled from PyCaret.
            exp.predict_model(self.model)
            metrics = exp.pull()
            results = utils.get_metrics(metrics, problem_type=problem_type)

            return EvaluateRes(
                status=Status(code=Code.OK, message="Success"),
                loss=utils.get_loss(problem_type=problem_type, metrics=metrics),
                num_examples=len(X_test),
                metrics=results,
            )

    # Connect this client to the Flower server and run the FL loop.
    local_client = FlowerClient(cid=partition_id)
    fl.client.start_client(
        server_address="0.0.0.0:8091",
        client=local_client,
    )