# FedSecurity / attack / label_flipping_attack.py
import logging
import math
import random
import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from .attack_base import BaseAttackMethod
import numpy as np
from ..common.utils import (
    replace_original_class_with_target_class,
    get_client_data_stat,
)

"""
Reference: Tolpegin, Truex, Gursoy, Liu, "Data Poisoning Attacks Against Federated Learning Systems" (ESORICS 2020).
Client-side attack, added by Yuhui, 07/08/2022.
"""


class LabelFlippingAttack(BaseAttackMethod):
    """Client-side data poisoning that rewrites labels of selected classes.

    Labels are rewritten by ``replace_original_class_with_target_class`` using
    ``original_class_list`` and ``target_class_list``. Whether a given client
    is poisoned is decided per call by :meth:`is_to_poison_data`, restricted
    to the round window ``[poison_start_round_id, poison_end_round_id]``.
    """

    def __init__(self, args):
        """Read the attack configuration from ``args``.

        Required attributes: ``original_class_list``, ``target_class_list``,
        ``batch_size``, ``client_num_per_round`` and
        ``ratio_of_poisoned_client`` (a number in ``[0, 1]``).
        Optional attributes: ``poison_start_round_id`` (default ``0``) and
        ``poison_end_round_id`` (default ``args.comm_round - 1``); a value
        that is not an ``int`` is ignored in favour of the default.

        Raises:
            ValueError: if ``ratio_of_poisoned_client`` is missing, is not a
                real number, or lies outside ``[0, 1]``.
        """
        self.original_class_list = args.original_class_list
        self.target_class_list = args.target_class_list
        self.batch_size = args.batch_size
        self.poisoned_client_list = []
        self.ite_counter = 0

        start_round = getattr(args, "poison_start_round_id", None)
        self.poison_start_round_id = start_round if isinstance(start_round, int) else 0

        end_round = getattr(args, "poison_end_round_id", None)
        # args.comm_round is only consulted when no usable end round is given.
        self.poison_end_round_id = end_round if isinstance(end_round, int) else args.comm_round - 1

        ratio = getattr(args, "ratio_of_poisoned_client", None)
        # Accept ints as well as floats so that a configured 0 or 1 is valid;
        # bool is excluded because it is an int subclass but never a ratio.
        if isinstance(ratio, bool) or not isinstance(ratio, (int, float)):
            raise ValueError(
                f"ratio_of_poisoned_client must be a number in [0, 1], got {ratio!r}"
            )
        # The chained comparison also rejects NaN.
        if not 0 <= ratio <= 1:
            raise ValueError(
                f"ratio_of_poisoned_client must be within [0, 1], got {ratio!r}"
            )
        self.ratio_of_poisoned_client = ratio

        self.client_num_per_round = args.client_num_per_round
        self.counter = 0

    def get_ite_num(self):
        """Return the zero-based round index implied by the calls counted so far.

        With ``counter`` decisions already made, this is the round that the
        next call to :meth:`is_to_poison_data` belongs to.
        """
        return math.floor(self.counter / self.client_num_per_round)  # ite num starts from 0

    def is_to_poison_data(self):
        """Decide whether the client behind the current call is poisoned.

        Every call advances the internal counter by one. The round index is
        derived from that counter and ``client_num_per_round``, so the window
        check is only meaningful when this method is called once per client
        per round.

        Returns:
            bool: ``False`` outside the poisoning window; otherwise ``True``
            with probability ``ratio_of_poisoned_client``, drawn
            deterministically from the call counter.
        """
        # Take the round index *before* counting this call: counting first
        # would push the last client of every round into the following round.
        round_id = self.get_ite_num()
        self.counter += 1
        if round_id < self.poison_start_round_id or round_id > self.poison_end_round_id:
            return False
        # A private generator seeded with the counter yields the same draw as
        # seeding the global NumPy RNG would, without resetting the random
        # state that the rest of the process relies on.
        rand = np.random.RandomState(self.counter).random_sample()
        return rand < self.ratio_of_poisoned_client

    def print_dataset(self, dataset):
        """Print the targets of every batch in ``dataset`` (debug helper)."""
        print("---------------print dataset------------")
        for batch_idx, (data, target) in enumerate(dataset):
            print(f"{batch_idx} ----- {target}")

    def poison_data(self, local_dataset):
        """Build a poisoned copy of ``local_dataset`` with rewritten labels.

        Args:
            local_dataset: iterable yielding ``(data, targets)`` tensor
                batches, e.g. a ``DataLoader``.

        Returns:
            DataLoader: batches of size ``self.batch_size`` over all samples
            of ``local_dataset`` with the labels returned by
            ``replace_original_class_with_target_class``.
        """
        get_client_data_stat(local_dataset)

        feature_batches = []
        label_batches = []
        for data, targets in local_dataset:
            feature_batches.append(data)
            label_batches.append(targets)

        if feature_batches:
            # One concatenation keeps the batches' own dtypes (integer labels
            # stay integer) and avoids re-copying the growing tensor per batch.
            all_x = torch.cat(feature_batches)
            all_y = torch.cat(label_batches)
        else:
            # Empty input: fall back to empty tensors so the pipeline below
            # still produces an (empty) loader.
            all_x = torch.Tensor([])
            all_y = torch.Tensor([])

        flipped_y = replace_original_class_with_target_class(
            data_labels=all_y,
            original_class_list=self.original_class_list,
            target_class_list=self.target_class_list,
        )
        dataset = TensorDataset(all_x, flipped_y)
        poisoned_data = DataLoader(dataset, batch_size=self.batch_size)
        get_client_data_stat(poisoned_data)

        return poisoned_data