import argparse
import warnings
# from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment
import numpy as np
import flwr as fl
import utils
from flwr.common import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
Status,
)
from typing import List
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--node-id",
        type=int,
        required=True,
        help="Specifies the artificial data partition",
    )
    parser.add_argument(
        "--num-clients",
        type=int,
        required=True,
        help="Specifies the number of clients",
    )
    # FIX: this option previously declared required=True together with a
    # default, which made the default dead code. Keep the default and drop
    # the requirement — backward compatible, existing invocations still work.
    parser.add_argument(
        "--problem-type",
        type=str,
        required=False,
        choices=["classification", "regression"],
        help="Specifies the problem type",
        default="classification",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        required=True,
        help="Specifies the path to the data",
    )
    parser.add_argument(
        "--target",
        required=False,
        help="Specifies the target column",
        default=-1,
    )
    args = parser.parse_args()

    partition_id = args.node_id
    problem_type = args.problem_type
    N_CLIENTS = args.num_clients
    target = args.target

    # Load the dataset; per-client partitioning is currently disabled.
    df = utils.load_data(args.data_path, target_column=target)
    # df = utils.stratified_partition_with_all_values(
    #     df=df, n_partitions=N_CLIENTS, partition_id=partition_id, target=target)

    # Resolve the target column name (the CLI may pass an index such as -1).
    target = utils.get_target_column_name(df, target)

    # Build the PyCaret experiment matching the problem type.
    if problem_type == "classification":
        exp = ClassificationExperiment()
    else:
        exp = RegressionExperiment()
    exp.setup(data=df, target=target, **utils.setup_params)

    # Log Loss is not tracked by default; register it so the client can
    # report it from fit/evaluate (lower is better, needs predicted probas).
    if problem_type == "classification":
        exp.add_metric('logloss', 'Log Loss', log_loss,
                       greater_is_better=False, target="pred_proba")

    # Cache the train/test splits created by exp.setup() for the client below.
    X_train = exp.get_config("X_train")
    y_train = exp.get_config("y_train")
    X_test = exp.get_config("X_test")
    y_test = exp.get_config("y_test")
# Define Flower client
class FlowerClient(fl.client.Client):
    """Flower client that trains/evaluates a PyCaret model for one partition."""

    def __init__(self, cid):
        self.model = None  # created lazily on the first fit() round
        self.cid = cid

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        """Return the current local model weights to the server."""
        print(f"[Client {self.cid}] get_parameters")
        # Extract the weights as ndarrays, then serialize them for transport.
        weights: List[np.ndarray] = utils.get_model_parameters(self.model)
        serialized = utils.ndarrays_to_custom_parameters(weights)
        return GetParametersRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=serialized,
        )

    def fit(self, ins: FitIns) -> FitRes:
        """Train locally, starting from the server-supplied global weights."""
        print(f"[Client {self.cid}] fit, config: {ins.config}")
        incoming = utils.custom_parameters_to_ndarrays(ins.parameters)
        # First round: instantiate the model type requested by the server.
        if self.model is None:
            print(f"Creating model with {len(X_train)} examples")
            self.model = exp.create_model(
                ins.config['model_name'], cross_validation=False, train_model=False)
        utils.set_model_params(self.model, incoming)
        # Convergence warnings are expected with few local epochs; mute them.
        with warnings.catch_warnings():
            print(f"Training model with {len(X_train)} examples")
            warnings.simplefilter("ignore")
            self.model = exp.create_model(self.model, cross_validation=False)
        report = utils.get_metrics(exp.pull(), problem_type=problem_type)
        outgoing = utils.ndarrays_to_custom_parameters(
            utils.get_model_parameters(self.model))
        return FitRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=outgoing,
            num_examples=len(X_train),
            metrics=report,
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        """Score the server-supplied weights on the local hold-out split."""
        print(f"[Client {self.cid}] evaluate, config: {ins.config}")
        incoming = utils.custom_parameters_to_ndarrays(ins.parameters)
        utils.set_model_params(self.model, incoming)
        exp.predict_model(self.model)
        scores = exp.pull()
        return EvaluateRes(
            status=Status(code=Code.OK, message="Success"),
            loss=utils.get_loss(problem_type=problem_type, metrics=scores),
            num_examples=len(X_test),
            metrics=utils.get_metrics(scores, problem_type=problem_type),
        )
# Start Flower client
fl.client.start_client(
server_address="0.0.0.0:8091", client=FlowerClient(cid=partition_id)
)