import torch
import tensorrt as trt
import numpy as np
import time
import os
# Assuming your model class and dependencies are in the same directory or PYTHONPATH
from robofirefusenet import RoboFireFuseNet
def build_engine(model_path, engine_path, input_res=(256, 256), max_batch=10):
    """Build and serialize a TensorRT engine from an ONNX model with a dynamic batch axis.

    Args:
        model_path: Path to the ONNX file (input layout NCHW, 4 channels).
        engine_path: Destination path for the serialized ``.trt`` engine.
        input_res: Spatial resolution (H, W) baked into the optimization profile.
        max_batch: Largest batch size the engine must support. The profile is
            (min=1, opt=max_batch // 2, max=max_batch); the default reproduces
            the original (1, 5, 10) profile.

    Returns:
        None. On failure, parser/builder errors are printed and no (or no valid)
        engine file is written.
    """
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    # EXPLICIT_BATCH is required for ONNX-parsed networks with dynamic shapes.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)

    config = builder.create_builder_config()
    # Generous workspace: transformer blocks need room for tactic selection.
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GB

    # Parse the ONNX graph; surface every parser error before bailing out.
    with open(model_path, 'rb') as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    # Optimization profile describing the allowed dynamic batch range.
    profile = builder.create_optimization_profile()
    input_name = network.get_input(0).name
    opt_batch = max(1, max_batch // 2)
    profile.set_shape(
        input_name,
        (1, 4, *input_res),          # min
        (opt_batch, 4, *input_res),  # opt (kernels tuned for this shape)
        (max_batch, 4, *input_res),  # max
    )
    config.add_optimization_profile(profile)

    # FP16 roughly doubles throughput on GPUs with fast half-precision math.
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print("Using FP16 Precision")

    print(f"Building TensorRT engine: {engine_path}...")
    serialized_engine = builder.build_serialized_network(network, config)
    # build_serialized_network returns None on failure; writing it unchecked
    # would raise a confusing TypeError and leave a truncated engine file.
    if serialized_engine is None:
        print("ERROR: TensorRT engine build failed.")
        return None

    with open(engine_path, 'wb') as f:
        f.write(serialized_engine)
    print("Engine built successfully.")
def benchmark_trt(engine_path, batch_size=1, input_res=(256, 256), iters=100, warmup=10):
    """Benchmark inference latency/throughput of a serialized TensorRT engine.

    Args:
        engine_path: Path to the serialized ``.trt`` engine.
        batch_size: Batch size for this run (must lie within the engine's
            optimization profile).
        input_res: Spatial resolution (H, W) of the dummy input.
        iters: Number of timed iterations (default matches original: 100).
        warmup: Number of untimed warmup iterations (default: 10).

    Returns:
        Tuple ``(avg_latency_seconds, fps)``.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, 'rb') as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())

    context = engine.create_execution_context()

    # Bind the concrete input shape for this run so dynamic dims get resolved.
    input_name = engine.get_tensor_name(0)
    output_name = engine.get_tensor_name(1)
    context.set_input_shape(input_name, (batch_size, 4, *input_res))

    device = torch.device('cuda:0')
    dummy_input = torch.randn(batch_size, 4, *input_res, device=device)

    # Query the *context* (not the engine) for the output shape: after
    # set_input_shape it reports fully-resolved dims, whereas the engine
    # reports -1 for every dynamic axis (only dim 0 was being patched before,
    # which would mis-size the buffer if other axes were dynamic too).
    output_shape = tuple(context.get_tensor_shape(output_name))
    output = torch.empty(output_shape, device=device)

    bindings = [dummy_input.data_ptr(), output.data_ptr()]

    # Warmup: lets CUDA contexts/kernels initialize outside the timed region.
    print(f"Warming up batch_size={batch_size}...")
    for _ in range(warmup):
        context.execute_v2(bindings)

    print(f"Benchmarking batch_size={batch_size}...")
    torch.cuda.synchronize()  # drain warmup work before starting the clock
    start_time = time.time()
    for _ in range(iters):
        context.execute_v2(bindings)
    torch.cuda.synchronize()  # ensure all timed kernels have finished
    end_time = time.time()

    avg_latency = (end_time - start_time) / iters
    fps = batch_size / avg_latency
    print(f"Avg Latency: {avg_latency*1000:.2f} ms | FPS: {fps:.2f}")
    return avg_latency, fps
if __name__ == '__main__':
    device = 'cuda:0'
    input_res = (480, 640)
    num_classes = 3
    onnx_file = "robofirefusenet_dynamic.onnx"
    engine_file = "robofirefusenet_dynamic.trt"

    # 1. Export to ONNX (skipped if the file already exists, mirroring the
    #    guard used for the engine build below).
    if not os.path.exists(onnx_file):
        # augment=False for deployment: single segmentation-map output.
        model = RoboFireFuseNet(m=2, n=3, num_classes=num_classes, planes=32,
                                ppm_planes=96, head_planes=128, augment=False,
                                channels=4, input_resolution=input_res,
                                window_size=(8, 8), tf_depths=[2, 6]).to(device)
        model.eval()
        dummy_input = torch.randn(1, 4, *input_res).to(device)

        print("Exporting to ONNX...")
        # no_grad: export only traces the forward pass; autograd bookkeeping
        # would just waste memory here.
        with torch.no_grad():
            torch.onnx.export(
                model, dummy_input, onnx_file,
                input_names=['input'], output_names=['output'],
                dynamic_axes={'input': {0: 'batch_size'},
                              'output': {0: 'batch_size'}},
                opset_version=14,  # high opset recommended for Transformers
            )

    # 2. Build the TensorRT engine (once).
    if not os.path.exists(engine_file):
        build_engine(onnx_file, engine_file, input_res=input_res)

    # 3. Benchmark across the requested batch sizes, but only if the engine
    #    actually exists (the build may have failed above).
    if os.path.exists(engine_file):
        for b in [1, 5, 10]:
            benchmark_trt(engine_file, batch_size=b, input_res=input_res)
    else:
        print(f"ERROR: engine file '{engine_file}' was not created; "
              "skipping benchmark.")