# FedSecurity / defense / robust_learning_rate_defense.py
from collections import OrderedDict
import torch
from ..common.utils import get_total_sample_num
from .defense_base import BaseDefenseMethod
from typing import Callable, List, Tuple, Dict, Any


class RobustLearningRateDefense(BaseDefenseMethod):
    """Robust Learning Rate (RLR) defense for federated aggregation.

    For every parameter tensor, each coordinate's aggregated update is kept
    (multiplied by +server_learning_rate) only when at least
    ``robust_threshold`` clients agree on the update's sign; otherwise the
    sign of the aggregated update is flipped (multiplied by
    -server_learning_rate), dampening coordinated poisoning attempts.
    """

    def __init__(self, config):
        # Minimum number of clients whose update signs must agree for a
        # coordinate to keep its direction; 0 disables the defense.
        self.robust_threshold = config.robust_threshold  # e.g., robust threshold = 4
        self.server_learning_rate = 1

    def run(
            self,
            raw_client_grad_list: List[Tuple[float, OrderedDict]],
            base_aggregation_func: Callable = None,
            extra_auxiliary_info: Any = None,
    ) -> OrderedDict:
        """Aggregate client updates with a per-coordinate robust learning rate.

        Args:
            raw_client_grad_list: list of (sample_count, params) pairs, one
                per client; every params OrderedDict must share the same keys.
            base_aggregation_func: fallback aggregator used when the defense
                is disabled (robust_threshold == 0).
            extra_auxiliary_info: unused; kept for interface compatibility.

        Returns:
            A new OrderedDict of aggregated parameters. Client inputs are
            never modified (the previous implementation overwrote the first
            client's OrderedDict in place).
        """
        if self.robust_threshold == 0:
            return base_aggregation_func(raw_client_grad_list)  # avg_params
        total_sample_num = get_total_sample_num(raw_client_grad_list)
        (_, first_params) = raw_client_grad_list[0]
        # Accumulate into a fresh dict so no client's submitted parameters
        # are mutated as a side effect of aggregation.
        avg_params = OrderedDict()
        for k in first_params.keys():
            client_update_sign = []
            weighted_sum = None
            for local_sample_number, local_model_params in raw_client_grad_list:
                client_update_sign.append(torch.sign(local_model_params[k]))
                w = local_sample_number / total_sample_num
                contribution = local_model_params[k] * w
                # Seed the accumulator with the first contribution; this also
                # keeps the result's dtype/device identical to the inputs.
                weighted_sum = (
                    contribution if weighted_sum is None else weighted_sum + contribution
                )
            client_lr = self._compute_robust_learning_rates(client_update_sign)
            avg_params[k] = client_lr * weighted_sum
        return avg_params

    def _compute_robust_learning_rates(self, client_update_sign):
        """Return a per-coordinate learning-rate tensor of +/- server_learning_rate.

        ``sum(client_update_sign)`` has magnitude equal to the net sign
        agreement among clients; coordinates with agreement below the
        threshold get their learning rate negated. The two in-place masked
        writes are order-safe for positive thresholds: values already set to
        -server_learning_rate can never satisfy ``>= robust_threshold``.
        """
        client_lr = torch.abs(sum(client_update_sign))
        client_lr[client_lr < self.robust_threshold] = -self.server_learning_rate
        client_lr[client_lr >= self.robust_threshold] = self.server_learning_rate
        return client_lr