# Academic Software License: Copyright © 2024.
import os
import json
import time
import torch
import argparse
import traceback
import numpy as np
from utils import is_inside_bbox
from pyquaternion import Quaternion as Q
# Radar sensor names; each is used as a key into the split metadata and the
# per-frame info entries.
RAD_NAME = ['RAD_LEFT', 'RAD_FRONT', 'RAD_RIGHT', 'RAD_BACK']

# Run tensor work on the first GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# Command-line interface: a single option pointing at the dataset root.
argparser = argparse.ArgumentParser(description='SimBEV post-processing tool.')
argparser.add_argument(
    '--path',
    default='/dataset',
    help='path for saving the dataset (default: /dataset)',
)
# NOTE(review): arguments are parsed at import time, so importing this module
# consumes the current process's sys.argv.
args = argparser.parse_args()
def _pose_to_matrix(rotation, translation):
    """Build a 4x4 homogeneous transform from a rotation and a translation.

    Args:
        rotation: rotation passed straight to ``pyquaternion.Quaternion``.
        translation: 3-element translation vector.

    Returns:
        A float64 ``np.ndarray`` of shape (4, 4). It is deliberately kept in
        float64 so that the double-precision path (large maps) is not undermined
        by truncating global poses to float32 before they reach torch.
    """
    matrix = np.eye(4, dtype=np.float64)
    matrix[:3, :3] = Q(rotation).rotation_matrix
    matrix[:3, 3] = translation
    return matrix


def _load_point_cloud(path, dtype):
    """Load a point cloud saved as ``.npz`` (array under key ``'data'``) or ``.npy``.

    Returns a tensor on ``DEVICE`` with the requested ``dtype``.
    """
    if path.endswith('.npz'):
        # Close the archive deterministically rather than relying on GC.
        with np.load(path) as archive:
            array = archive['data']
    else:
        array = np.load(path)
    return torch.from_numpy(array).to(DEVICE, dtype)


def _transform_points(points, transform, dtype):
    """Apply a 4x4 NumPy homogeneous ``transform`` to ``points`` of shape (N, 3).

    Returns an (N, 3) tensor on ``DEVICE`` with the requested ``dtype``.
    """
    matrix = torch.from_numpy(transform).to(DEVICE, dtype)
    # Allocate directly on the target device instead of creating on CPU and copying.
    ones = torch.ones(points.shape[0], 1, device=DEVICE, dtype=dtype)
    homogeneous = torch.cat((points, ones), dim=1)
    return (matrix @ homogeneous.T)[:3].T


def _radar_to_cartesian(radar_points):
    """Convert radar rows of (depth, altitude, azimuth, ...) into (N, 3) x, y, z.

    Only the first three columns are used; any further columns are ignored.
    """
    depth = radar_points[:, 0]
    altitude = radar_points[:, 1]
    azimuth = radar_points[:, 2]
    horizontal = depth * torch.cos(altitude)
    x = horizontal * torch.cos(azimuth)
    y = horizontal * torch.sin(azimuth)
    z = depth * torch.sin(altitude)
    return torch.stack((x, y, z), dim=1)


def _process_frame(info, lidar2ego, radar2ego, dtype):
    """Annotate one frame's ground-truth objects with lidar/radar point counts.

    Args:
        info: per-frame entry from the info file (pose, sensor file paths, GT_DET path).
        lidar2ego: 4x4 lidar-to-ego transform.
        radar2ego: mapping from radar name (``RAD_NAME``) to its 4x4 radar-to-ego transform.
        dtype: torch dtype used for all point and transform tensors.

    Returns:
        List of the frame's GT objects, each updated in place with
        ``num_lidar_pts``, ``num_radar_pts`` and ``valid_flag``.
    """
    ego2global = _pose_to_matrix(info['ego2global_rotation'], info['ego2global_translation'])

    # Lidar points in the global frame (the file is expected to hold (N, 3) x, y, z rows).
    lidar_points = _load_point_cloud(info['LIDAR'], dtype)
    lidar_points_global = _transform_points(lidar_points, ego2global @ lidar2ego, dtype)

    # Merge all radars into one global point set; the empty seed keeps torch.cat
    # valid even when no radar contributes points.
    radar_chunks = [torch.empty((0, 3), device=DEVICE, dtype=dtype)]
    for radar in RAD_NAME:
        points = _radar_to_cartesian(_load_point_cloud(info[radar], dtype))
        radar_chunks.append(_transform_points(points, ego2global @ radar2ego[radar], dtype))
    radar_points_global = torch.cat(radar_chunks, dim=0)

    # NOTE: allow_pickle deserializes arbitrary objects; only run on trusted dataset files.
    det_objects = np.load(info['GT_DET'], allow_pickle=True)
    new_det_objects = []
    for obj in det_objects:
        lidar_mask = is_inside_bbox(lidar_points_global, obj['bounding_box'], DEVICE, dtype)
        radar_mask = is_inside_bbox(radar_points_global, obj['bounding_box'], DEVICE, dtype)
        num_lidar_points = lidar_mask.sum().item()
        num_radar_points = radar_mask.sum().item()
        obj['num_lidar_pts'] = num_lidar_points
        obj['num_radar_pts'] = num_radar_points
        # An object is valid if at least one lidar or radar point falls inside its box.
        obj['valid_flag'] = (num_lidar_points > 0) or (num_radar_points > 0)
        new_det_objects.append(obj)
    return new_det_objects


def main():
    """Recount lidar/radar points inside every ground-truth box and save new GT_DET files.

    For each split (train/val/test) whose info file exists at
    ``{args.path}/simbev/infos/simbev_infos_{split}.json``, every frame's lidar
    and radar point clouds are transformed to the global frame. Each object in the
    frame's GT_DET file then receives ``num_lidar_pts``, ``num_radar_pts`` and
    ``valid_flag``. Results are written to ``{args.path}/simbev/ground-truth/new_det``.

    Any exception is printed with its traceback and stops processing. It is not
    re-raised.
    """
    try:
        out_dir = f'{args.path}/simbev/ground-truth/new_det'
        os.makedirs(out_dir, exist_ok=True)

        for split in ['train', 'val', 'test']:
            info_path = f'{args.path}/simbev/infos/simbev_infos_{split}.json'
            if not os.path.exists(info_path):
                continue

            with open(info_path, 'r') as f:
                infos = json.load(f)
            metadata = infos['metadata']

            # Sensor extrinsics come from split-level metadata: compute them once
            # per split instead of once per frame.
            lidar2ego = _pose_to_matrix(
                metadata['LIDAR']['sensor2ego_rotation'],
                metadata['LIDAR']['sensor2ego_translation'],
            )
            radar2ego = {
                radar: _pose_to_matrix(
                    metadata[radar]['sensor2ego_rotation'],
                    metadata[radar]['sensor2ego_translation'],
                )
                for radar in RAD_NAME
            }

            for scene, scene_entry in infos['data'].items():
                # Assumes scene keys end in a 4-digit scene number.
                scene_number = int(scene[-4:])
                print(f'Processing scene {scene_number}...')

                # Double precision for Town12/13/15, presumably because their
                # global coordinates are large enough for float32 to lose accuracy.
                if scene_entry['scene_info']['map'] in ['Town12', 'Town13', 'Town15']:
                    dtype = torch.double
                else:
                    dtype = torch.float32

                for i, info in enumerate(scene_entry['scene_data']):
                    new_det_objects = _process_frame(info, lidar2ego, radar2ego, dtype)
                    out_path = (
                        f'{out_dir}/SimBEV-scene-{scene_number:04d}-frame-{i:04d}-GT_DET.bin'
                    )
                    with open(out_path, 'wb') as f:
                        # dtype=object keeps the saved array an object array even
                        # when a frame has no objects.
                        np.save(f, np.array(new_det_objects, dtype=object), allow_pickle=True)
    except Exception:
        print(traceback.format_exc())
        print('Killing the process...')
        time.sleep(3.0)
if __name__ == '__main__':
    # Script entry point.
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C raises KeyboardInterrupt, which derives from BaseException rather
        # than Exception, so main()'s own `except Exception` does not catch it.
        # Report it here instead of letting a traceback escape.
        print('Killing the process...')
        time.sleep(3.0)
    finally:
        # Printed on every exit path: normal completion, an error already
        # reported by main(), or Ctrl+C.
        print('Done.')