# Source: auto-fl-fit / pycaret-fl / strategy / fedensemble.py
import numpy as np
from flwr.server.strategy import FedAvg

from flwr.server.client_proxy import ClientProxy
from typing import Callable, Dict, List, Optional, Tuple, Union
from flwr.common import (FitRes, MetricsAggregationFn,
                         NDArrays, Parameters, Scalar,
                         MetricsAggregationFn)
from flwr.common.logger import log
from logging import WARNING, INFO

import utils

# Message logged at WARNING level by FedEnsemble.__init__ when
# `min_available_clients` is smaller than `min_fit_clients` or
# `min_evaluate_clients`.
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""

# Flower strategy that picks the best model based on accuracy and loss function


def flower_strategy(models, accuracy_scores, loss_scores):
    """Return the model with the best combination of accuracy and loss.

    If the model with the highest accuracy is also the one with the lowest
    loss, that model is returned directly.  Otherwise every model gets a
    combined score ``(accuracy - mean(accuracy)) + (mean(loss) - loss)`` and
    the model with the highest combined score wins.  Ties are resolved in
    favour of the earliest model, as with ``np.argmax`` / ``np.argmin``.

    Args:
        models: Indexable collection of candidate models; ``models[i]``
            corresponds to ``accuracy_scores[i]`` and ``loss_scores[i]``.
        accuracy_scores: One accuracy value per model (higher is better).
        loss_scores: One loss value per model (lower is better).

    Returns:
        The selected element of ``models``.

    Raises:
        ValueError: If no scores are given, or if ``models``,
            ``accuracy_scores`` and ``loss_scores`` differ in length.
    """
    accuracy = np.asarray(accuracy_scores, dtype=float)
    loss = np.asarray(loss_scores, dtype=float)

    # Fail early with a clear message instead of an obscure NumPy error (or,
    # worse, a silently wrong pick) when the inputs are inconsistent.
    if accuracy.size == 0:
        raise ValueError("flower_strategy requires at least one model score")
    if accuracy.shape != loss.shape:
        raise ValueError(
            "accuracy_scores and loss_scores must have the same shape, got "
            f"{accuracy.shape} and {loss.shape}"
        )
    if len(models) != accuracy.size:
        raise ValueError(
            f"expected one score per model, got {len(models)} models and "
            f"{accuracy.size} scores"
        )

    best_accuracy_index = int(np.argmax(accuracy))
    best_loss_index = int(np.argmin(loss))

    # Unambiguous winner: best on both criteria.
    if best_accuracy_index == best_loss_index:
        return models[best_accuracy_index]

    # The criteria disagree.  Centering each metric on its mean only shifts
    # every score by the same constant, so this ranks models by
    # (accuracy - loss).
    combined_scores = (accuracy - accuracy.mean()) + (loss.mean() - loss)
    return models[int(np.argmax(combined_scores))]


class FedEnsemble(FedAvg):
    """FedAvg subclass that keeps a single client's model instead of averaging.

    In every round ``aggregate_fit`` selects the one client result for which
    ``utils.get_loss`` returns the smallest value and forwards that client's
    parameters as the new global parameters.  Parameters are converted with
    the project's ``utils.custom_parameters_to_ndarrays`` and
    ``utils.ndarrays_to_custom_parameters`` helpers.
    """

    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[
            int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        num_rounds: int = 10,
        exp_id: str = "test",
        problem_type: str
    ) -> None:
        """Configure the strategy.

        All arguments are keyword-only.  Every argument not listed below is
        passed unchanged to ``FedAvg.__init__``.

        Args:
            num_rounds: Round number on which the selected weights are written
                to disk (``aggregate_fit`` saves when
                ``server_round == num_rounds``).
            exp_id: Prefix of the saved weights file,
                ``./weights/<exp_id>-weights.npz``.
            problem_type: Forwarded to ``utils.get_loss`` when ranking client
                results.  Required; it has no default.
        """
        # NOTE(review): FedAvg.__init__ may emit the same warning itself, in
        # which case it is logged twice -- confirm before removing.
        if (
            min_fit_clients > min_available_clients
            or min_evaluate_clients > min_available_clients
        ):
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        self.num_rounds = num_rounds
        self.exp_id = exp_id
        self.problem_type = problem_type

        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

    def __repr__(self) -> str:
        """Return the strategy name."""
        return "FedEnsemble"

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Select the parameters of the client result with the lowest loss.

        No averaging takes place: the result whose ``utils.get_loss`` value is
        smallest is chosen and its parameters become the round's output.  On
        the round equal to ``self.num_rounds`` the selected arrays are also
        saved to ``./weights/<exp_id>-weights.npz``.

        Args:
            server_round: Current round number.
            results: ``(client, FitRes)`` pairs from clients that finished.
            failures: Failed results or exceptions raised during fitting.

        Returns:
            ``(parameters, metrics)``.  ``parameters`` is ``None`` and
            ``metrics`` is empty when there are no results, or when there are
            failures and ``accept_failures`` is false.  ``metrics`` is the
            output of ``fit_metrics_aggregation_fn`` if one was provided,
            otherwise an empty dict.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        def get_best(result: FitRes):
            # Ranking key: the loss that utils.get_loss derives from the
            # client's reported metrics for the configured problem type.
            return utils.get_loss(problem_type=self.problem_type, metrics=result.metrics)

        # min() keeps the first result among equal losses.
        best_fit_res = min((fit_res for _, fit_res in results), key=get_best)
        aggregated_weight_results = utils.custom_parameters_to_ndarrays(
            best_fit_res.parameters)

        # Save aggregated_ndarrays on the last round
        if server_round == self.num_rounds:
            import os

            # exist_ok=True creates the folder only when missing and avoids
            # the race between a separate existence check and the creation.
            os.makedirs('./weights', exist_ok=True)

            log(INFO, f"Saving round {server_round} aggregated_ndarrays...")
            np.savez(f"./weights/{self.exp_id}-weights.npz",
                     *aggregated_weight_results)

        parameters_aggregated = utils.ndarrays_to_custom_parameters(
            aggregated_weight_results)

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: Dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics)
                           for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate the current model parameters with ``evaluate_fn``.

        Args:
            server_round: Current round number, passed to ``evaluate_fn``.
            parameters: Parameters to evaluate; converted with
                ``utils.custom_parameters_to_ndarrays`` before the call.

        Returns:
            The ``(loss, metrics)`` pair produced by ``evaluate_fn``, or
            ``None`` when no evaluation function was configured or it
            returned ``None``.
        """
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None

        # We deserialize using our custom method
        parameters_ndarrays = utils.custom_parameters_to_ndarrays(parameters)

        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res

        return loss, metrics