# Academic Software License: Copyright © 2024.

import os
import json
import time
import torch
import argparse
import traceback

import numpy as np

from utils import is_inside_bbox
from pyquaternion import Quaternion as Q


RAD_NAME = ['RAD_LEFT', 'RAD_FRONT', 'RAD_RIGHT', 'RAD_BACK']

# Scenes recorded on these maps are processed with torch.double instead of
# torch.float32 (presumably because their global coordinates are large enough
# for float32 to lose precision -- TODO confirm).
DOUBLE_PRECISION_MAPS = ('Town12', 'Town13', 'Town15')

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

argparser = argparse.ArgumentParser(description='SimBEV post-processing tool.')

argparser.add_argument(
    '--path',
    default='/dataset',
    help='path for saving the dataset (default: /dataset)')

args = argparser.parse_args()


def _make_transform(rotation, translation):
    """Build a 4x4 homogeneous transform from a quaternion and a translation.

    The matrix is kept in float64. Callers cast it to the working dtype only
    at the point of use, so the double-precision path does not inherit
    float32 rounding of the rotation or translation.

    Args:
        rotation: quaternion accepted by ``pyquaternion.Quaternion``.
        translation: 3-element translation vector.

    Returns:
        A (4, 4) float64 NumPy array.
    """
    transform = np.eye(4)
    transform[:3, :3] = Q(rotation).rotation_matrix
    transform[:3, 3] = translation
    return transform


def _load_points(path, dType):
    """Load a point array from a ``.npz`` (key ``'data'``) or ``.npy`` file.

    Returns:
        A tensor on DEVICE with dtype ``dType``.
    """
    if path.endswith('.npz'):
        # Use the NpzFile as a context manager so its file handle is closed.
        with np.load(path) as archive:
            array = archive['data']
    else:
        array = np.load(path)
    return torch.from_numpy(array).to(DEVICE, dType)


def _to_global(points, sensor2global, dType):
    """Map sensor-frame points to the global frame.

    Args:
        points: tensor of shape (N, 3) on DEVICE. Exactly 3 columns are
            required, because a homogeneous 1 is appended before multiplying
            by the 4x4 matrix.
        sensor2global: (4, 4) NumPy homogeneous transform.
        dType: torch dtype used for the computation.

    Returns:
        Tensor of shape (N, 3) holding the transformed points.
    """
    ones = torch.ones(points.shape[0], 1, device=DEVICE, dtype=dType)
    homogeneous = torch.cat((points, ones), dim=1)
    matrix = torch.from_numpy(sensor2global).to(DEVICE, dType)
    return (matrix @ homogeneous.T)[:3].T


def _radar_to_cartesian(radar_points):
    """Convert radar returns from depth, altitude and azimuth to x, y and z.

    The last column of ``radar_points`` is dropped first. Columns 0, 1 and 2
    of what remains are read as depth, altitude and azimuth. The angles are
    fed to ``torch.cos``/``torch.sin``, so they must be in radians.

    Returns:
        Tensor of shape (N, 3).
    """
    spherical = radar_points[:, :-1]
    depth = spherical[:, 0]
    altitude = spherical[:, 1]
    azimuth = spherical[:, 2]

    # Projection of the range onto the horizontal plane.
    horizontal = depth * torch.cos(altitude)
    x = horizontal * torch.cos(azimuth)
    y = horizontal * torch.sin(azimuth)
    z = depth * torch.sin(altitude)
    return torch.stack((x, y, z), dim=1)


def _count_points_in_boxes(det_objects, lidar_points_global, radar_points_global, dType):
    """Annotate each object with lidar/radar point counts and a validity flag.

    Each object is updated in place with:

    - ``'num_lidar_pts'``
    - ``'num_radar_pts'``
    - ``'valid_flag'``: True when at least one lidar or radar point falls
      inside the object's ``'bounding_box'``.

    Returns:
        The annotated objects as a list.
    """
    annotated = []
    for obj in det_objects:
        lidar_mask = is_inside_bbox(lidar_points_global, obj['bounding_box'], DEVICE, dType)
        radar_mask = is_inside_bbox(radar_points_global, obj['bounding_box'], DEVICE, dType)

        num_lidar_points = lidar_mask.sum().item()
        num_radar_points = radar_mask.sum().item()

        obj['num_lidar_pts'] = num_lidar_points
        obj['num_radar_pts'] = num_radar_points
        obj['valid_flag'] = (num_lidar_points > 0) or (num_radar_points > 0)

        annotated.append(obj)
    return annotated


def _process_frame(info, scene_number, frame_index, lidar2ego, radar2ego, dType):
    """Recompute point counts for one frame and write the new GT_DET file.

    Args:
        info: per-frame entry from the infos JSON.
        scene_number: scene index used in the output file name.
        frame_index: frame index used in the output file name.
        lidar2ego: (4, 4) lidar-to-ego transform.
        radar2ego: mapping from each name in RAD_NAME to its (4, 4)
            radar-to-ego transform.
        dType: torch dtype used for all point computations.
    """
    ego2global = _make_transform(info['ego2global_rotation'], info['ego2global_translation'])

    # Lidar point cloud in the global frame.
    lidar_points = _load_points(info['LIDAR'], dType)
    lidar_points_global = _to_global(lidar_points, ego2global @ lidar2ego, dType)

    # All radar point clouds, merged into one global-frame tensor.
    radar_chunks = [torch.empty((0, 3), device=DEVICE, dtype=dType)]
    for radar in RAD_NAME:
        radar_points = _radar_to_cartesian(_load_points(info[radar], dType))
        radar_chunks.append(_to_global(radar_points, ego2global @ radar2ego[radar], dType))
    radar_points_global = torch.cat(radar_chunks, dim=0)

    # Load object bounding boxes.
    # NOTE(review): allow_pickle=True unpickles the file contents; only run
    # this tool on trusted dataset files.
    det_objects = np.load(info['GT_DET'], allow_pickle=True)

    new_det_objects = _count_points_in_boxes(
        det_objects, lidar_points_global, radar_points_global, dType
    )

    output_path = (
        f'{args.path}/simbev/ground-truth/new_det/'
        f'SimBEV-scene-{scene_number:04d}-frame-{frame_index:04d}-GT_DET.bin'
    )
    with open(output_path, 'wb') as f:
        np.save(f, np.array(new_det_objects), allow_pickle=True)


def main():
    """Recompute per-object lidar/radar point counts for every dataset split.

    Splits are processed in the order train, val, test; a split is skipped
    when ``simbev_infos_{split}.json`` does not exist. For each frame, the
    GT_DET objects are annotated and written to
    ``{args.path}/simbev/ground-truth/new_det``.

    Any exception is printed with its traceback, after which processing
    stops. The exception is not re-raised.
    """
    try:
        os.makedirs(f'{args.path}/simbev/ground-truth/new_det', exist_ok=True)

        for split in ['train', 'val', 'test']:
            info_path = f'{args.path}/simbev/infos/simbev_infos_{split}.json'
            if not os.path.exists(info_path):
                continue

            with open(info_path, 'r') as f:
                infos = json.load(f)

            metadata = infos['metadata']

            # Sensor extrinsics come from split-level metadata, so build them
            # once per split instead of once per frame.
            lidar2ego = _make_transform(
                metadata['LIDAR']['sensor2ego_rotation'],
                metadata['LIDAR']['sensor2ego_translation'],
            )
            radar2ego = {
                radar: _make_transform(
                    metadata[radar]['sensor2ego_rotation'],
                    metadata[radar]['sensor2ego_translation'],
                )
                for radar in RAD_NAME
            }

            for scene, scene_entry in infos['data'].items():
                scene_number = int(scene[-4:])

                print(f'Processing scene {scene_number}...')

                if scene_entry['scene_info']['map'] in DOUBLE_PRECISION_MAPS:
                    dType = torch.double
                else:
                    dType = torch.float32

                for i, info in enumerate(scene_entry['scene_data']):
                    _process_frame(info, scene_number, i, lidar2ego, radar2ego, dType)

    except Exception:
        print(traceback.format_exc())

        print('Killing the process...')

        time.sleep(3.0)


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        print('Killing the process...')

        time.sleep(3.0)
    finally:
        print('Done.')