import itertools as it
from enum import Enum
from typing import List, Union

import numpy as np
import torch

# Sticker colours are 8-bit RGB values; the image is written on a [0, 1]
# scale, so channel values are divided by 255 to map 255 -> 1.0 exactly.
_COLOR_MAX = 255


class Color(Enum):
    """RGB colours available for stickers (8-bit channel values)."""

    BLACK = (0, 0, 0)
    WHITE = (255, 255, 255)

    def __repr__(self) -> str:
        return repr(self.value)


def create_multi_sticker_properties(
    sticker_properties: Union[List, np.ndarray],
    sticker_colors: List[Color],
) -> torch.Tensor:
    """Return a 1-D tensor containing the properties of 1 or more stickers.

    Args:
        sticker_properties: flat sequence of ``[x1, y1, x2, y2, color_index]``
            groups, five values per sticker. Corners may be given in either
            order; they are sorted so that ``x1 <= x2`` and ``y1 <= y2``.
        sticker_colors: palette that each ``color_index`` indexes into.

    Returns:
        Tensor laid out as
        ``[area, *[x1_i, y1_i, x2_i, y2_i, r_i, g_i, b_i] * sticker_count]``.
        ``area`` is the sum of every sticker's inclusive pixel area, so pixels
        covered by several stickers are counted once per sticker.

    Raises:
        ValueError: if the length of ``sticker_properties`` is not a multiple
            of 5 (the last group cannot be unpacked).
    """
    flat_stickers = []
    area = 0
    for start in range(0, len(sticker_properties), 5):
        x1, y1, x2, y2, color_index = sticker_properties[start : start + 5]
        color = sticker_colors[color_index].value
        # normalise so the rectangle is described by its min/max corners
        x1, x2 = sorted((x1, x2))
        y1, y2 = sorted((y1, y2))
        # bounds are inclusive, hence the +1 on each side length
        area += (x2 - x1 + 1) * (y2 - y1 + 1)
        flat_stickers.extend([x1, y1, x2, y2, *color])
    return torch.tensor([area, *flat_stickers])


def add_multi_sticker_to_image(
    sticker_props: torch.Tensor, image: torch.Tensor
) -> torch.Tensor:
    """Return a copy of the image with coloured sticker(s) painted on top.

    The input ``image`` is not modified.

    Args:
        sticker_props: tensor laid out as
            ``(area, *[x1, y1, x2, y2, r, g, b] * sticker_count)``, as built by
            ``create_multi_sticker_properties``. The leading ``area`` is ignored.
        image: image tensor indexed as ``image[channel, x, y]`` with channels
            0/1/2 receiving r/g/b. Presumably a CHW image in [0, 1]; confirm
            against callers.

    Returns:
        A clone of ``image`` where each sticker's inclusive rectangle
        ``[x1..x2] x [y1..y2]`` is filled with its colour scaled to [0, 1].
        Later stickers overwrite earlier ones where they overlap.
    """
    stickered = torch.clone(image)
    # skip index 0 (area); each sticker occupies 7 consecutive entries
    for start in range(1, len(sticker_props), 7):
        x1, y1, x2, y2, r, g, b = sticker_props[start : start + 7]
        for channel, value in enumerate((r, g, b)):
            stickered[channel, x1 : x2 + 1, y1 : y2 + 1] = value / _COLOR_MAX
    return stickered


def sticker_batch(
    img: torch.Tensor, batch: torch.Tensor, upper_bound: int
) -> torch.Tensor:
    """Return a batch of copies of ``img`` with coloured stickers added on top.

    Args:
        img: the image tensor each candidate is painted onto.
        batch: iterable of sticker-property tensors (one per candidate), each
            laid out as ``(area, *[x1, y1, x2, y2, r, g, b] * sticker_count)``.
        upper_bound: only candidates whose ``area`` is strictly below this
            value are rendered.

    Returns:
        Tensor of shape ``(k, *img.shape)``, where ``k`` is the length of the
        leading run of candidates in ``batch`` that satisfy the bound. ``k``
        may be 0, in which case an empty tensor is returned.
    """
    # Only add stickers to candidates that could still beat the current
    # upper-bound solution. takewhile stops at the FIRST failing candidate,
    # so the batch size may shrink when candidates are searched from high to
    # low area; this won't happen when they are searched from low to high.
    candidates = it.takewhile(
        lambda properties: upper_bound > properties[0], batch
    )
    stickered = [
        add_multi_sticker_to_image(properties, img) for properties in candidates
    ]
    if not stickered:
        # torch.stack rejects an empty list; return an empty batch instead
        return torch.empty((0, *img.shape), dtype=img.dtype, device=img.device)
    return torch.stack(stickered)