from logging import INFO, WARNING
import os
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
)
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate
import numpy as np

import utils

# Built by implicit concatenation; the resulting value is identical to the
# original single-line literal.
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = (
    " Setting `min_available_clients` lower than `min_fit_clients` or"
    " `min_evaluate_clients` can cause the server to fail when there are too"
    " few clients connected to the server. `min_available_clients` must be"
    " set to a value larger than or equal to the values of `min_fit_clients`"
    " and `min_evaluate_clients`. "
)


class CustomFedAvg(FedAvg):
    """FedAvg variant with custom parameter (de)serialization.

    Differences from the stock ``FedAvg``:

    * Parameters are converted with ``utils.custom_parameters_to_ndarrays`` /
      ``utils.ndarrays_to_custom_parameters`` instead of Flower's defaults.
    * On round ``num_rounds`` the aggregated weights are written to
      ``./weights/{exp_id}-weights.npz``.
    """

    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[
            Callable[[int], Dict[str, Scalar]]
        ] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        num_rounds: int = 10,
        exp_id: str = "test",
    ) -> None:
        """Create the strategy.

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 1.0.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 1.0.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]]]]
            Optional server-side evaluation function. It receives the round
            number, the deserialized parameters and an empty config dict.
            Defaults to None.
        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not to accept rounds containing failures.
            Defaults to True.
        initial_parameters : Parameters, optional
            Initial global model parameters. Defaults to None.
        fit_metrics_aggregation_fn : MetricsAggregationFn, optional
            Aggregates the per-client fit metrics. It receives a list of
            ``(num_examples, metrics)`` pairs. Defaults to None.
        evaluate_metrics_aggregation_fn : MetricsAggregationFn, optional
            Passed through to ``FedAvg``. Defaults to None.
        num_rounds : int, optional
            The round on which the aggregated weights are saved to disk.
            Presumably this should equal the server's total number of rounds;
            keep it in sync with the server config. Defaults to 10.
        exp_id : str, optional
            Experiment identifier used in the saved filename
            ``./weights/{exp_id}-weights.npz``. Defaults to "test".
        """
        # NOTE(review): FedAvg.__init__ may emit this same warning itself,
        # which would log it twice. Kept for parity; confirm against the
        # installed flwr version.
        if (
            min_fit_clients > min_available_clients
            or min_evaluate_clients > min_available_clients
        ):
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        self.num_rounds = num_rounds
        self.exp_id = exp_id

        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

    def __repr__(self) -> str:
        return "CustomFedAvg"

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate global parameters with ``evaluate_fn``, if one was given.

        Returns
        -------
        Optional[Tuple[float, Dict[str, Scalar]]]
            ``(loss, metrics)`` from ``evaluate_fn``. Returns None when no
            ``evaluate_fn`` is configured or when it returns None.
        """
        if self.evaluate_fn is None:
            # No evaluation function provided.
            return None

        # Deserialize with the project's custom format rather than Flower's
        # default ``parameters_to_ndarrays``.
        parameters_ndarrays = utils.custom_parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    def _save_weights(self, server_round: int, ndarrays: NDArrays) -> None:
        """Write ``ndarrays`` to ``./weights/{exp_id}-weights.npz``."""
        # exist_ok avoids the race/error of a separate exists() check.
        os.makedirs("./weights", exist_ok=True)
        log(INFO, f"Saving round {server_round} aggregated_ndarrays...")
        # Positional arrays are stored under numpy's default keys
        # (arr_0, arr_1, ...) in list order.
        np.savez(f"./weights/{self.exp_id}-weights.npz", *ndarrays)

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate client fit results with a weighted average.

        Each client's parameters are deserialized with the custom format and
        weighted by ``num_examples`` via Flower's ``aggregate``. On round
        ``self.num_rounds`` the aggregated arrays are also saved to disk.

        Returns
        -------
        Tuple[Optional[Parameters], Dict[str, Scalar]]
            The re-serialized aggregated parameters and the aggregated fit
            metrics. Returns ``(None, {})`` when there are no results, or when
            there are failures and ``accept_failures`` is False.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted.
        if not self.accept_failures and failures:
            return None, {}

        # Deserialize each client's parameters with the custom method.
        weights_results = [
            (
                utils.custom_parameters_to_ndarrays(fit_res.parameters),
                fit_res.num_examples,
            )
            for _, fit_res in results
        ]
        aggregated_weight_results = aggregate(weights_results)

        # Persist the final aggregated model.
        if server_round == self.num_rounds:
            self._save_weights(server_round, aggregated_weight_results)

        parameters_aggregated = utils.ndarrays_to_custom_parameters(
            aggregated_weight_results
        )

        # Aggregate custom metrics if an aggregation fn was provided.
        metrics_aggregated: Dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated