"""Flower strategy that keeps the best single client model instead of averaging.

``FedEnsemble`` subclasses ``FedAvg``. Its ``aggregate_fit`` does not average:
it selects the parameters of the client whose fit metrics give the lowest
loss, as computed by ``utils.get_loss``. Parameters are (de)serialized with
the project's custom helpers in ``utils``.
"""
import os
from logging import INFO, WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
)
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg

import utils

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = (
    " Setting `min_available_clients` lower than `min_fit_clients` or"
    " `min_evaluate_clients` can cause the server to fail when there are"
    " too few clients connected to the server. `min_available_clients`"
    " must be set to a value larger than or equal to the values of"
    " `min_fit_clients` and `min_evaluate_clients`. "
)


def flower_strategy(models, accuracy_scores, loss_scores):
    """Pick the best model based on accuracy and loss scores.

    If the model with the highest accuracy also has the lowest loss, that
    model is returned directly. Otherwise each model gets the combined score
    ``(accuracy - mean_accuracy) + (mean_loss - loss)``, and the model with
    the highest combined score is returned. Ties resolve to the lowest index,
    following ``np.argmax``/``np.argmin`` semantics.

    Args:
        models: Indexable sequence of candidate models.
        accuracy_scores: One accuracy value per model (higher is better).
            Any array-like is accepted, including plain lists.
        loss_scores: One loss value per model (lower is better).

    Returns:
        The selected element of ``models``.

    Raises:
        ValueError: If no scores are given, or if ``accuracy_scores`` and
            ``loss_scores`` do not have the same shape.
    """
    # Convert up front: plain Python lists do not support the element-wise
    # subtraction used for the weighted score below.
    accuracy = np.asarray(accuracy_scores, dtype=float)
    loss = np.asarray(loss_scores, dtype=float)

    if accuracy.size == 0 or loss.size == 0:
        raise ValueError("flower_strategy() needs at least one model score")
    if accuracy.shape != loss.shape:
        # Without this check NumPy would silently broadcast mismatched inputs.
        raise ValueError(
            "accuracy_scores and loss_scores must have the same shape, got "
            f"{accuracy.shape} and {loss.shape}"
        )

    best_accuracy_index = np.argmax(accuracy)
    best_loss_index = np.argmin(loss)

    # Both criteria agree: no need for a combined score.
    if best_accuracy_index == best_loss_index:
        return models[best_accuracy_index]

    # NOTE(review): the accuracy and loss deviations are summed unscaled, so
    # whichever metric has the larger spread dominates. Confirm this is intended.
    weighted_scores = (accuracy - accuracy.mean()) + (loss.mean() - loss)
    return models[np.argmax(weighted_scores)]


class FedEnsemble(FedAvg):
    """``FedAvg`` variant that keeps the lowest-loss client model each round.

    Instead of averaging client updates, ``aggregate_fit`` picks the
    parameters of the single client whose fit metrics give the smallest
    ``utils.get_loss`` value. On round ``num_rounds`` those weights are also
    written to ``./weights/{exp_id}-weights.npz``.

    All ``FedAvg`` arguments are forwarded unchanged. Extra keyword arguments:
        num_rounds: Round number on which the selected weights are saved.
        exp_id: Experiment id used in the saved weights filename.
        problem_type: Passed to ``utils.get_loss`` to choose how the loss is
            read from client metrics. Presumably something like
            classification vs. regression; confirm against ``utils``.
    """

    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        num_rounds: int = 10,
        exp_id: str = "test",
        problem_type: str
    ) -> None:
        if (
            min_fit_clients > min_available_clients
            or min_evaluate_clients > min_available_clients
        ):
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        self.num_rounds = num_rounds
        self.exp_id = exp_id
        self.problem_type = problem_type

        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

    def __repr__(self) -> str:
        return "FedEnsemble"

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Select the client model with the lowest loss as the round's result.

        No averaging happens. The parameters of the client whose
        ``FitRes.metrics`` give the smallest ``utils.get_loss`` value are
        deserialized and then re-serialized with the custom ``utils``
        helpers. On the final round (``server_round == self.num_rounds``)
        they are also saved to ``./weights/{exp_id}-weights.npz``.

        Returns:
            ``(None, {})`` when there are no results, or when there are
            failures and ``accept_failures`` is False. Otherwise, the
            selected parameters plus the fit metrics aggregated by
            ``fit_metrics_aggregation_fn``, or ``{}`` if none was provided.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        def _client_loss(fit_res: FitRes):
            # Loss this client reported; utils.get_loss presumably reads the
            # metric that matches self.problem_type (see utils).
            return utils.get_loss(
                problem_type=self.problem_type, metrics=fit_res.metrics
            )

        # min() keeps the first client on ties, matching list order.
        best_fit_res = min((fit_res for _, fit_res in results), key=_client_loss)
        aggregated_weight_results = utils.custom_parameters_to_ndarrays(
            best_fit_res.parameters
        )

        # Save the selected weights on the last round only.
        if server_round == self.num_rounds:
            # exist_ok=True avoids the race of a separate exists()/makedirs().
            os.makedirs("./weights", exist_ok=True)
            log(INFO, f"Saving round {server_round} aggregated_ndarrays...")
            np.savez(
                f"./weights/{self.exp_id}-weights.npz", *aggregated_weight_results
            )

        parameters_aggregated = utils.ndarrays_to_custom_parameters(
            aggregated_weight_results
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate the current model parameters with ``evaluate_fn``.

        Overrides ``FedAvg.evaluate`` so that parameters are deserialized
        with ``utils.custom_parameters_to_ndarrays``. The config passed to
        ``evaluate_fn`` is always ``{}``.

        Returns:
            ``None`` if no ``evaluate_fn`` was provided or it returned
            ``None``; otherwise a ``(loss, metrics)`` tuple.
        """
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None

        # We deserialize using our custom method
        parameters_ndarrays = utils.custom_parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics