# traffic-sign-classifier-robustness-testing/rt_search_based/transformations/stickers.py
import itertools as it
from enum import Enum
from typing import List, Union

import numpy as np
import torch


class Color(Enum):
    """Colours a sticker can take, each stored as an RGB tuple in the 0-255 range."""

    BLACK = (0, 0, 0)
    WHITE = (255, 255, 255)

    def __repr__(self) -> str:
        # Show only the RGB tuple, e.g. "(0, 0, 0)", instead of the default
        # "<Color.BLACK: (0, 0, 0)>" form.
        rgb = self.value
        return f"{rgb!r}"


def create_multi_sticker_properties(
    sticker_properties: Union[List, np.ndarray],
    sticker_colors: List[Color],
) -> torch.Tensor:
    """Pack the properties of one or more stickers into a flat tensor.

    Args:
        sticker_properties: flat sequence of 5 values per sticker,
            ``[x1, y1, x2, y2, color_index] * sticker_count``. The two corners
            may be given in either order; they are normalised so that
            ``x1 <= x2`` and ``y1 <= y2``. ``color_index`` selects an entry
            of ``sticker_colors``.
        sticker_colors: the colours a sticker may take; each entry's
            ``.value`` is unpacked as an ``(r, g, b)`` triple.

    Returns:
        A 1-dimensional tensor with the properties stored as follows:
        ``[area, *[x1_i, y1_i, x2_i, y2_i, r_i, g_i, b_i] * sticker_count]``,
        where ``area`` is the sum of each sticker's inclusive pixel area
        (overlapping regions are counted once per sticker).

    Raises:
        ValueError: if ``len(sticker_properties)`` is not a multiple of 5.
    """
    if len(sticker_properties) % 5 != 0:
        raise ValueError(
            "sticker_properties must contain 5 values per sticker "
            f"(x1, y1, x2, y2, color_index), got {len(sticker_properties)}"
        )

    packed = []
    total_area = 0
    for start in range(0, len(sticker_properties), 5):
        x1, y1, x2, y2, color_index = sticker_properties[start : start + 5]

        color = sticker_colors[color_index].value
        # Corners may arrive in either order; normalise to top-left/bottom-right.
        x1, x2 = sorted((x1, x2))
        y1, y2 = sorted((y1, y2))
        # Bounds are inclusive on both ends, hence the +1.
        total_area += (x2 - x1 + 1) * (y2 - y1 + 1)
        packed.extend([x1, y1, x2, y2, *color])
    return torch.tensor([total_area, *packed])


def add_multi_sticker_to_image(
    sticker_props: torch.Tensor, image: torch.Tensor
) -> torch.Tensor:
    """Return a copy of the image with colored sticker(s) painted on top of it.

    Args:
        sticker_props (torch.Tensor): 1-D tensor laid out as
            ``[area, *[x1, y1, x2, y2, r, g, b] * sticker_count]`` (the layout
            produced by ``create_multi_sticker_properties``). Entry 0 (area)
            is not used here. Coordinates are inclusive; ``x`` indexes
            dimension 1 and ``y`` dimension 2 of ``image``. ``r, g, b`` are
            0-255 channel values.
        image (torch.Tensor): the image to draw on. Channels 0-2 are written,
            so it is assumed to be channel-first RGB with values in [0, 1]
            (TODO confirm against the data pipeline).

    Returns:
        (torch.Tensor):
            copy of ``image`` with the sticker(s) placed on it. ``image``
            itself is not modified. Where stickers overlap, later ones
            overwrite earlier ones.
    """

    copy = torch.clone(image)

    for i in range(1, len(sticker_props), 7):
        x1, y1, x2, y2, r, g, b = sticker_props[i : i + 7]

        # 8-bit channel values span 0..255, so divide by 255 (not 256) to map
        # them onto [0, 1]; dividing by 256 left WHITE stickers at ~0.996.
        copy[0, x1 : x2 + 1, y1 : y2 + 1] = r / 255
        copy[1, x1 : x2 + 1, y1 : y2 + 1] = g / 255
        copy[2, x1 : x2 + 1, y1 : y2 + 1] = b / 255

    return copy


def sticker_batch(
    img: torch.Tensor, batch: torch.Tensor, upper_bound: int
) -> torch.Tensor:
    """Return a batch of copies of the img, each with one set of stickers on it.

    Args:
        img (torch.Tensor): the Tensor of the image to draw on.
        batch (torch.Tensor): rows of sticker properties, each in the
            ``[area, *[x1, y1, x2, y2, r, g, b] * sticker_count]`` layout
            consumed by ``add_multi_sticker_to_image``.
        upper_bound (int): only rows whose entry 0 (area) is strictly below
            this value are rendered. Iteration stops at the first row that
            fails this test, so later rows are skipped even if they would pass.

    Returns:
        (torch.Tensor): the stickered images stacked along a new leading
        dimension. It may hold fewer images than ``batch`` has rows. When no
        row qualifies, an empty tensor of shape ``(0, *img.shape)`` is
        returned.
    """
    # Only add stickers to rows that are still possible solutions, i.e. whose
    # area is below the upper bound. Because takewhile stops at the first
    # failing row, the batch size may vary when the search runs from high to
    # low heuristic (per the original note, not when it runs from low to high).
    stickered = [
        add_multi_sticker_to_image(properties, img)
        for properties in it.takewhile(
            lambda properties: upper_bound > properties[0], batch
        )
    ]
    if not stickered:
        # torch.stack raises on an empty list; hand back an empty batch instead.
        return img.new_empty((0, *img.shape))
    return torch.stack(stickered)