import os
import random
import sys
import xml.etree.ElementTree as ET

import numpy as np
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
# from skimage.transform import resize
from scipy.sparse import csr_matrix
from PIL import Image
import cv2
import matplotlib.pyplot as plt

VOC_CLASSES = (
    '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
    'tvmonitor',
)


class VocDataset(data.Dataset):
    """Dataset over a Pascal-VOC-style directory tree.

    The following files are read relative to ``data_path``:

    * ``ImageSets/Main/<dataset_split>.txt`` -- one image id per line,
    * ``Annotations/<id>.xml``               -- object annotations,
    * ``JPEGImages/<id>.jpg``                -- the image itself.

    Args:
        data_path: Root directory of the dataset.
        dataset_split: Name of the split file (without ``.txt``).
        transform: Callable applied to the PIL image; when ``random_crops``
            is non-zero its results are combined with ``torch.stack``, so it
            must return tensors of a common shape in that case.
        random_crops: If 0, ``transform`` is applied once. Otherwise it is
            applied ``random_crops`` times and the results are stacked.

    Attributes set by the constructor: ``names``, ``labels``,
    ``box_indices`` and ``label_order`` (see ``__dataset_info``), plus
    ``classes``, ``num_classes`` and ``class_to_ind``.
    """

    # Only lines of exactly this length in the split file are treated as
    # image ids; everything else (e.g. "<id> <flag>" lines) is skipped.
    _IMAGE_ID_LENGTH = 6

    def __init__(self, data_path, dataset_split, transform, random_crops=0):
        self.data_path = data_path
        self.transform = transform
        self.random_crops = random_crops
        self.dataset_split = dataset_split
        self.__init_classes()
        (self.names,
         self.labels,
         self.box_indices,
         self.label_order) = self.__dataset_info()

    def __getitem__(self, index):
        """Return ``(x, y, z)`` for the image at ``index``.

        ``x`` is the transformed image (or a stack of ``random_crops``
        transformed copies), ``y`` the multi-hot label vector and ``z`` the
        list of bounding boxes of that image.
        """
        image_path = os.path.join(
            self.data_path, 'JPEGImages', self.names[index] + '.jpg')
        # Load eagerly inside a context manager so the file handle is closed,
        # and force RGB so the transform always sees three channels.
        with Image.open(image_path) as img:
            x = img.convert('RGB')

        if self.random_crops == 0:
            x = self.transform(x)
        else:
            x = torch.stack(
                [self.transform(x) for _ in range(self.random_crops)])

        y = self.labels[index]
        z = self.box_indices[index]
        return x, y, z

    def __len__(self):
        """Return the number of images kept from the split file."""
        return len(self.names)

    def __init_classes(self):
        """Set up the class tuple and the name -> index lookup table."""
        self.classes = VOC_CLASSES
        self.num_classes = len(self.classes)
        self.class_to_ind = dict(zip(self.classes, range(self.num_classes)))

    def _parse_annotation(self, image_id):
        """Parse ``Annotations/<image_id>.xml``.

        Returns:
            A tuple ``(label, boxes, class_order)`` where ``label`` is a
            multi-hot vector of length ``num_classes``, ``boxes`` is a list
            of int32 arrays ``[x1, y1, x2, y2]`` (one per object) and
            ``class_order`` lists the class index of each object in file
            order.

        Raises:
            KeyError: If an object name is not one of ``self.classes``.
        """
        xml_path = os.path.join(
            self.data_path, 'Annotations', image_id + '.xml')
        objs = ET.parse(xml_path).findall('object')

        num_objs = len(objs)
        boxes = np.zeros((num_objs, 4), dtype=np.int32)
        box_classes = np.zeros(num_objs, dtype=np.int32)
        class_order = []
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based; the float coordinates are
            # truncated when stored in the int32 array.
            boxes[ix, :] = [
                float(bbox.find(tag).text) - 1
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
            cls = self.class_to_ind[obj.find('name').text.lower().strip()]
            box_classes[ix] = cls
            class_order.append(cls)

        label = np.zeros(self.num_classes)
        label[box_classes] = 1
        return label, list(boxes), class_order

    def __dataset_info(self):
        """Read the split file and every referenced annotation.

        Returns:
            ``(names, labels, box_indices, label_order)``:

            * ``names`` -- array of image ids,
            * ``labels`` -- float32 array of multi-hot label vectors,
            * ``box_indices`` -- 1-D object array; element ``i`` is the list
              of ``[x1, y1, x2, y2]`` int32 arrays for image ``i``,
            * ``label_order`` -- list of per-image class-index lists.
        """
        split_file = os.path.join(
            self.data_path, 'ImageSets', 'Main', self.dataset_split + '.txt')
        with open(split_file) as f:
            # strip() rather than [:-1]: a final line without a trailing
            # newline (or a CRLF line) must not lose/keep a stray character.
            image_ids = [line.strip() for line in f]

        names = []
        labels = []
        per_image_boxes = []
        label_order = []
        for image_id in image_ids:
            if len(image_id) != self._IMAGE_ID_LENGTH:
                continue
            label, boxes, class_order = self._parse_annotation(image_id)
            names.append(image_id)
            labels.append(label)
            per_image_boxes.append(boxes)
            label_order.append(class_order)

        # The number of boxes differs per image, so np.array() on the nested
        # list is ragged (a ValueError on recent NumPy). Fill an object array
        # element by element to get one list of boxes per image.
        box_indices = np.empty(len(per_image_boxes), dtype=object)
        for i, boxes in enumerate(per_image_boxes):
            box_indices[i] = boxes

        return (np.array(names),
                np.array(labels).astype(np.float32),
                box_indices,
                label_order)