"""Flower federated-learning client backed by PyCaret experiments.

Run as a script: parses CLI arguments, loads the client's data, sets up
a PyCaret classification/regression experiment, and exposes the train/
test splits used by the Flower client defined below.
"""
import argparse
import warnings
from typing import List

import numpy as np
import flwr as fl
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
)
from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment
from sklearn.metrics import log_loss

import utils

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--node-id",
        type=int,
        required=True,
        help="Specifies the artificial data partition",
    )
    parser.add_argument(
        "--num-clients",
        type=int,
        required=True,
        help="Specifies the number of clients",
    )
    # BUGFIX: this option was declared both required=True and with a
    # default; argparse never applies a default to a required option, so
    # the default was dead. Making it optional lets the documented default
    # ("classification") actually apply; callers that pass --problem-type
    # explicitly are unaffected.
    parser.add_argument(
        "--problem-type",
        type=str,
        choices=["classification", "regression"],
        default="classification",
        help="Specifies the problem type",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        required=True,
        help="Specifies the path to the data",
    )
    parser.add_argument(
        "--target",
        required=False,
        default=-1,  # -1 means "use the last column" (resolved by utils)
        help="Specifies the target column",
    )
    args = parser.parse_args()

    partition_id = args.node_id
    problem_type = args.problem_type
    N_CLIENTS = args.num_clients
    target = args.target

    # Load this client's data. NOTE(review): per-client stratified
    # partitioning by N_CLIENTS/partition_id was disabled here, so every
    # client currently sees the full dataset — confirm this is intended.
    df = utils.load_data(args.data_path, target_column=target)
    target = utils.get_target_column_name(df, target)

    # Set up the PyCaret experiment matching the problem type.
    if problem_type == "classification":
        exp = ClassificationExperiment()
    else:
        exp = RegressionExperiment()
    exp.setup(data=df, target=target, **utils.setup_params)
    if problem_type == "classification":
        # Track log-loss on predicted probabilities; lower is better.
        exp.add_metric(
            "logloss",
            "Log Loss",
            log_loss,
            greater_is_better=False,
            target="pred_proba",
        )

    # Train/test splits produced by exp.setup(), used by the client below.
    X_train = exp.get_config("X_train")
    y_train = exp.get_config("y_train")
    X_test = exp.get_config("X_test")
    y_test = exp.get_config("y_test")
# Define the Flower client.
class FlowerClient(fl.client.Client):
    """Flower client that trains and evaluates a PyCaret model.

    Relies on module-level state prepared at startup: the PyCaret
    experiment ``exp``, the splits ``X_train``/``X_test``, and
    ``problem_type``. The model is created lazily on the first ``fit``
    call using the estimator name supplied in the server's fit config.
    """

    def __init__(self, cid):
        # Model is created lazily in fit(); None until the first round.
        self.model = None
        self.cid = cid

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        """Return the current local model parameters, serialized for Flower."""
        print(f"[Client {self.cid}] get_parameters")
        # Extract parameters as a list of NumPy ndarrays.
        ndarrays: List[np.ndarray] = utils.get_model_parameters(self.model)
        # Serialize the ndarrays into a Parameters object.
        parameters = utils.ndarrays_to_custom_parameters(ndarrays)
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(status=status, parameters=parameters)

    def fit(self, ins: FitIns) -> FitRes:
        """Train the local model starting from the server-provided parameters."""
        print(f"[Client {self.cid}] fit, config: {ins.config}")
        # Deserialize the incoming global parameters to NumPy ndarrays.
        ndarrays_original = utils.custom_parameters_to_ndarrays(ins.parameters)

        # Lazily create the model on the first round; the server names the
        # estimator to build via the fit config ("model_name").
        if self.model is None:
            print(f"Creating model with {len(X_train)} examples")
            self.model = exp.create_model(
                ins.config["model_name"],
                cross_validation=False,
                train_model=False,
            )
        utils.set_model_params(self.model, ndarrays_original)

        # Ignore convergence warnings caused by a low number of local epochs.
        with warnings.catch_warnings():
            print(f"Training model with {len(X_train)} examples")
            warnings.simplefilter("ignore")
            # Re-fitting via create_model trains the (now parameterized)
            # estimator on this client's training split.
            self.model = exp.create_model(self.model, cross_validation=False)

        # Pull PyCaret's metric table for the training run.
        metrics = exp.pull()
        results = utils.get_metrics(metrics, problem_type=problem_type)

        # Serialize the updated parameters back to the server.
        ndarrays_updated = utils.get_model_parameters(self.model)
        parameters_updated = utils.ndarrays_to_custom_parameters(ndarrays_updated)

        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            num_examples=len(X_train),
            metrics=results,
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        """Evaluate the server-provided parameters on the local test split."""
        print(f"[Client {self.cid}] evaluate, config: {ins.config}")
        ndarrays_original = utils.custom_parameters_to_ndarrays(ins.parameters)

        # Guard: the model only exists after the first fit() round. Without
        # this check, set_model_params(None, ...) fails with an opaque
        # AttributeError when the server requests evaluation first.
        if self.model is None:
            raise RuntimeError(
                f"[Client {self.cid}] evaluate called before fit: "
                "no local model has been created yet"
            )
        utils.set_model_params(self.model, ndarrays_original)

        # Predict on the held-out split and pull PyCaret's metric table.
        exp.predict_model(self.model)
        metrics = exp.pull()
        results = utils.get_metrics(metrics, problem_type=problem_type)
        loss = utils.get_loss(problem_type=problem_type, metrics=metrics)

        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=loss,
            num_examples=len(X_test),
            metrics=results,
        )


# Start the Flower client only when executed as a script (connecting at
# import time would be a surprising side effect).
if __name__ == "__main__":
    fl.client.start_client(
        server_address="0.0.0.0:8091",
        client=FlowerClient(cid=partition_id),
    )