import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, './datasets/')
from base_dataset import BaseDataset
import re
import pandas as pd
import torch
import torchvision


class WildFire(BaseDataset):
    """Paired RGB / IR segmentation dataset.

    Each non-empty line of the list file is a relative path template that
    contains the placeholder ``XXX``. The placeholder is replaced by ``rgb``,
    ``ir`` and ``gt`` to locate the RGB image, the IR image and the label
    image under ``root``.
    """

    def __init__(self, root, list_path, num_classes=2, multi_scale=True, flip=True,
                 brightness=True, contrast=True, ignore_label=255, base_size=1024,
                 crop_size=(720, 960), scale_factor=16, mean=[0, 0, 0, 0],
                 std=[1, 1, 1, 1], bd_dilate_size=4, seed=200, mode='fusion',
                 blend_images=False, comp_mask=False, single_source=False):
        """Build the dataset from ``root`` and the list file ``list_path``.

        Args:
            root: directory that list entries are relative to.
            list_path: list file (relative to ``root``), one path template per line.
            num_classes: 3 selects a 3-level grey palette; any other value
                selects the 9-entry palette in ``self.color_list``.
            multi_scale, flip, brightness: forwarded to ``gen_sample``.
            contrast: stored as ``self.contrast``.
                NOTE(review): not forwarded to ``gen_sample`` here — confirm intended.
            ignore_label: label value excluded from / used for unmatched pixels.
            base_size, crop_size, scale_factor: forwarded to ``BaseDataset``.
            mean, std: per-channel normalisation stats forwarded to
                ``BaseDataset``. The defaults have 4 entries, one per channel
                of the concatenated RGB+IR input. (Mutable defaults are kept
                for interface compatibility; they are never mutated here.)
            bd_dilate_size: forwarded to ``gen_sample`` as ``edge_size``.
            seed: stored as ``self.seed`` (not used in this file).
            mode: ``'rgb'`` or ``'ir'`` restricts the returned channels;
                anything else returns all channels (fusion).
            blend_images: enable the object-pasting augmentation in ``__getitem__``.
            comp_mask, single_source: forwarded to ``gen_sample``.
        """
        # Set before calling the base initialiser, matching the original ordering.
        self.mean = mean
        self.std = std
        super().__init__(ignore_label, base_size, crop_size, scale_factor, self.mean, self.std)
        self.root = root
        self.list_path = list_path
        self.num_classes = num_classes
        self.mode = mode
        self.multi_scale = multi_scale
        self.flip = flip
        self.brightness = brightness
        self.contrast = contrast
        # Use a context manager so the list file handle is closed.
        # splitlines() also drops '\r' from CRLF list files, which would
        # otherwise end up inside the image paths.
        with open(os.path.join(root, list_path)) as list_file:
            self.files = [line for line in list_file.read().splitlines() if len(line) > 0]
        self.ignore_label = ignore_label
        self.color_list = [[0, 0, 0], [125, 125, 125], [255, 255, 255]] if num_classes == 3 else [
            (0, 0, 0),        # 0: background(unlabeled)
            (64, 0, 128),     # 1: Car
            (64, 64, 0),      # 2: person
            (0, 128, 192),    # 3: bike
            (0, 0, 192),      # 4: curve
            (128, 128, 0),    # 5: car_stop
            (64, 64, 128),    # 6: guardrail
            (192, 128, 128),  # 7: color_cone
            (192, 64, 0),     # 8: bump
        ]
        self.bd_dilate_size = bd_dilate_size
        self.seed = seed
        self.blend_images = blend_images
        self.comp_mask = comp_mask
        self.single_source = single_source

    def __len__(self):
        """Return the number of samples listed in the list file."""
        return len(self.files)

    def color2label(self, color_map):
        """Convert an (H, W, 3) colour-coded label image to a uint8 index map.

        Each pixel whose colour exactly matches ``self.color_list[i]`` gets
        label ``i``. Pixels matching no palette colour get ``self.ignore_label``.
        """
        label = np.ones(color_map.shape[:2]) * self.ignore_label
        for i, v in enumerate(self.color_list):
            label[(color_map == v).sum(2) == 3] = i
        return label.astype(np.uint8)

    def label2color(self, label):
        """Convert an index label map to a uint8 colour image using ``self.color_list``.

        Labels outside the palette stay black (0, 0, 0).
        """
        color_map = np.zeros(label.shape + (3,))
        for i, v in enumerate(self.color_list):
            color_map[label == i] = v
        return color_map.astype(np.uint8)

    def random_transformation_matrix(self, scale_range=(0.9, 1.1), rotation_range=(-8, 8),
                                     translation_range=(-15, 15)):
        """Draw a random 3x3 similarity transform (uniform scale, rotation, translation).

        Args:
            scale_range: (low, high) for the isotropic scale factor.
            rotation_range: (low, high) rotation in degrees.
            translation_range: (low, high) for each of the x / y offsets.

        Returns:
            (matrix, scale, (trans_x, trans_y), theta) with theta in radians.
        """
        scale_x = np.random.uniform(*scale_range)
        scale_y = scale_x  # isotropic scaling
        theta = np.radians(np.random.uniform(*rotation_range))
        cos_theta, sin_theta = np.cos(theta), np.sin(theta)
        trans_x = np.random.uniform(*translation_range)
        trans_y = np.random.uniform(*translation_range)
        transformation_matrix = np.array([
            [scale_x * cos_theta, -scale_y * sin_theta, trans_x],
            [scale_x * sin_theta, scale_y * cos_theta, trans_y],
            [0, 0, 1],
        ])
        return transformation_matrix, scale_x, (trans_x, trans_y), theta

    def blend_objects(self, img1_list, mask1, img2_list, mask2, max_shift=500):
        """Paste one randomly chosen object from sample 1 into sample 2.

        Sample 1 (its mask and every image in ``img1_list``) is first
        circularly shifted by random offsets in ``[0, max_shift)`` along
        axis 1 and axis 0. One non-zero id from the shifted ``mask1`` is
        then picked. Its pixels are copied into ``mask2`` and into the
        matching images of ``img2_list``.

        NOTE: ``img1_list``, ``mask2`` and the arrays inside ``img2_list``
        are modified in place.

        Returns:
            (img2_list, mask2). If ``mask1`` has no non-zero ids, the
            inputs are returned with nothing pasted.
        """
        hor_shift, ver_shift = np.random.randint(0, max_shift), np.random.randint(0, max_shift)
        mask1 = np.roll(np.roll(mask1, hor_shift, axis=1), ver_shift, axis=0)
        for i, img1 in enumerate(img1_list):
            img1_list[i] = np.roll(np.roll(img1, hor_shift, axis=1), ver_shift, axis=0)
        obj_ids = np.unique(mask1)
        obj_ids = obj_ids[obj_ids != 0]
        if len(obj_ids) == 0:
            return img2_list, mask2
        obj_id = np.random.choice(obj_ids, 1)
        obj_mask = (mask1 == obj_id)
        mask2[obj_mask] = mask1[obj_mask]
        for i in range(len(img2_list)):
            img2_list[i][obj_mask] = img1_list[i][obj_mask]
        return img2_list, mask2

    def _read_image(self, index, tag, pil_mode):
        """Open the ``tag`` variant of sample ``index`` and return a copied array in ``pil_mode``."""
        path = os.path.join(self.root, self.files[index].replace('XXX', tag))
        with Image.open(path) as im:
            return np.array(im.convert(pil_mode))

    def load_sample(self, index):
        """Load the RGB image, IR image and label for sample ``index``.

        Returns:
            ([rgb_img, ir_img], label_img). ``rgb_img`` is (H, W, 3),
            ``ir_img`` is (H, W, 1) and ``label_img`` is a uint8 index map.
            If the label file cannot be opened, a message is printed and an
            all-zero label is returned.
        """
        rgb_img = self._read_image(index, 'rgb', 'RGB')
        ir_img = self._read_image(index, 'ir', 'L')
        ir_img = ir_img.reshape(*ir_img.shape[:2], -1)
        try:
            label_img = self._read_image(index, 'gt', 'RGB').astype('uint8')
        except OSError:  # missing / unreadable label (PIL's UnidentifiedImageError is an OSError)
            print('no label found')
            label_img = np.zeros((rgb_img.shape[0], rgb_img.shape[1])).astype('uint8')
        else:
            unique_labels = np.unique(label_img)
            unique_labels = unique_labels[unique_labels != self.ignore_label]
            # Heuristic: small values mean the label is stored as indices
            # (read channel 0); large values mean it is colour-coded and must
            # be mapped through the palette. An image that is entirely
            # ignore_label has no other values and is kept as indices, instead
            # of crashing on max() of an empty array.
            if unique_labels.size == 0 or unique_labels.max() < 100:
                label_img = label_img[:, :, 0]
            else:
                label_img = self.color2label(label_img)
        return [rgb_img, ir_img], label_img

    def __getitem__(self, index):
        """Return (images, label, edge, [file_name]) for sample ``index``.

        ``images`` is the concatenation (along axis 0) of the arrays produced
        by ``gen_sample``. It is restricted to channels ``[:3]`` in ``'rgb'``
        mode and to channel ``3`` in ``'ir'`` mode.
        """
        loaded_images, label = self.load_sample(index)
        # Blend augmentation: with probability 0.5, paste objects from three random samples.
        if self.blend_images and (np.random.rand() < 0.5):
            for _ in range(3):
                donor_images, donor_label = self.load_sample(np.random.randint(0, len(self)))
                loaded_images, label = self.blend_objects(donor_images, donor_label,
                                                          loaded_images, label)
        images, label, edge = self.gen_sample(loaded_images, label, self.multi_scale, self.flip,
                                              edge_pad=False, edge_size=self.bd_dilate_size,
                                              brightness=self.brightness, comp_mask=self.comp_mask,
                                              single_source=self.single_source)
        images = np.concatenate(images, axis=0)
        # Return rgb only, ir only or fusion (default) depending on mode.
        if self.mode == 'rgb':
            images = images[:3]
        elif self.mode == 'ir':
            # NOTE(review): integer indexing drops the channel axis, unlike the
            # 'rgb' branch's slice. images[3:4] would keep it, but callers may
            # depend on the current shape — confirm before changing.
            images = images[3]
        return images.copy(), label.copy(), edge.copy(), [self.files[index].split('/')[-1]]

    def single_scale_inference(self, config, model, image):
        """Run ``self.inference`` (provided by the base class) on ``image``."""
        pred = self.inference(config, model, image)
        return pred

    @staticmethod
    def _channel_stats(stats, channels, n_channels):
        """Select normalisation stats for ``channels``.

        Falls back to the full stats vector (which then broadcasts) when the
        selected slice does not have ``n_channels`` entries.
        """
        stats = np.asarray(stats, dtype=np.float32).reshape(-1)
        picked = stats[channels]
        return picked if picked.size == n_channels else stats

    def save_pred(self, images, labels, preds, name, path):
        """Save an Image / Ground Truth / Prediction figure per sample to ``{path}/{name[i]}.png``.

        Args:
            images: batch of input images, channels-first (B, C, H, W), with
                RGB in channels 0:3 and IR from channel 3 on.
            labels: (B, H, W) index labels.
            preds: (B, K, H, W) scores (argmax over axis 1) or (B, H, W) indices.
            name: per-sample output file stems.
            path: output directory.
        """
        if len(preds.shape) > 3:
            preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
        else:
            preds = np.asarray(preds.cpu(), dtype=np.uint8)
        labels = np.asarray(labels.cpu(), dtype=np.uint8)
        images = np.asarray(images.cpu(), dtype=np.float32).transpose(0, 2, 3, 1)
        titles = ['Image', 'Ground Truth', 'Prediction']
        for i in range(preds.shape[0]):
            pred = self.label2color(preds[i])
            label = self.label2color(labels[i])
            sample = images[i]
            # If channels 1:3 are all zero (RGB blanked), visualise the IR channel(s) instead.
            if sample[:, :, 1:3].sum() == 0:
                channels = slice(3, None)
            else:
                channels = slice(0, 3)
            img = sample[:, :, channels]
            n_channels = img.shape[-1]
            # De-normalise with the stats of the displayed channels only; the
            # full 4-entry stats would not broadcast against a 3- or 1-channel slice.
            std = self._channel_stats(self.std, channels, n_channels)
            mean = self._channel_stats(self.mean, channels, n_channels)
            image = np.clip(((img * std) + mean) * 255, 0, 255).astype(np.uint8)
            if image.shape[-1] == 1:
                image = image[:, :, 0]  # imshow rejects (H, W, 1)
            f, ax = plt.subplots(1, 3, figsize=(10, 7))
            for axis, pic, title in zip(ax, [image, label, pred], titles):
                axis.imshow(pic)
                axis.set_title(title)
            plt.tight_layout()
            plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05, wspace=0.3, hspace=0.3)
            plt.savefig(f'{path}/{name[i]}.png', dpi=300)
            plt.close(f)


if __name__ == '__main__':
    dataset = WildFire(root='Datasets/', list_path='lists/train_mfnet.txt', num_classes=3,
                       multi_scale=False, flip=True, brightness=True, contrast=True,
                       ignore_label=26, scale_factor=9, crop_size=[272, 336], base_size=336,
                       bd_dilate_size=4, mode='fusion', comp_mask=True, single_source=True)
    for idx in np.random.choice(len(dataset), 3):
        images, label, edge, name = dataset[idx]
        images = [images[:3], images[3:]]
        # std_rgb / std_ir / mean_rgb / mean_ir are presumably set by BaseDataset — confirm.
        stds, means = [dataset.std_rgb, dataset.std_ir], [dataset.mean_rgb, dataset.mean_ir]
        for i, img in enumerate(images):
            img = np.transpose(img, (1, 2, 0))
            img = (img * stds[i]) + means[i]
            # Clip before the uint8 cast so out-of-range values do not wrap around.
            images[i] = (np.clip(img, 0, 1) * 255).astype('uint8')
        label = dataset.label2color(label).astype('uint8')
        edge = edge.astype('uint8')
        plt.imsave(f'rgb_image{idx}.png', images[0])
        plt.imsave(f'ir_image{idx}.png', images[1][:, :, 0])
        plt.imsave(f'label{idx}.png', label)
        plt.imsave(f'edge{idx}.png', edge)