# RoboFireFuseNet-private / models / robo_to_trt.py
# Export RoboFireFuseNet to ONNX, build a dynamic-batch TensorRT engine, and benchmark it.
import torch
import tensorrt as trt
import numpy as np
import time
import os

# Assuming your model class and dependencies are in the same directory or PYTHONPATH
from robofirefusenet import RoboFireFuseNet

def build_engine(model_path, engine_path, input_res=(256, 256)):
    """
    Build a TensorRT engine from an ONNX model with a dynamic batch dimension.

    Args:
        model_path: Path to the ONNX file to parse.
        engine_path: Destination path for the serialized engine.
        input_res: Spatial (H, W) of the network input. The optimization
            profile covers batch sizes 1 (min), 5 (opt), 10 (max).

    Returns:
        engine_path on success, None if parsing or building failed.
    """
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    # EXPLICIT_BATCH is required for ONNX models and dynamic shapes.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()

    # Increase workspace for transformer blocks (tactic selection needs room).
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB

    # Parse the ONNX graph; print every parser error before bailing out.
    with open(model_path, 'rb') as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    # Optimization profile for the dynamic batch dimension.
    # Channel count is fixed at 4 to match the model's 4-channel input.
    profile = builder.create_optimization_profile()
    input_name = network.get_input(0).name
    # Set (Min, Opt, Max) shapes
    profile.set_shape(input_name, (1, 4, *input_res), (5, 4, *input_res), (10, 4, *input_res))
    config.add_optimization_profile(profile)

    # Use FP16 if supported for maximum speed
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print("Using FP16 Precision")

    print(f"Building TensorRT engine: {engine_path}...")
    serialized_engine = builder.build_serialized_network(network, config)
    # BUG FIX: build_serialized_network returns None on failure; writing None
    # would raise a TypeError and leave a corrupt/empty engine file behind.
    if serialized_engine is None:
        print("Engine build failed.")
        return None
    with open(engine_path, 'wb') as f:
        f.write(serialized_engine)
    print("Engine built successfully.")
    return engine_path

def benchmark_trt(engine_path, batch_size=1, input_res=(256, 256)):
    """
    Benchmark the inference latency/throughput of a TensorRT engine.

    Args:
        engine_path: Path to the serialized .trt engine.
        batch_size: Batch size to run; must lie within the engine's
            optimization profile (1..10 as built by build_engine).
        input_res: Spatial (H, W) matching the engine's input resolution.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    # BUG FIX: the Runtime must stay alive as long as the engine/context are
    # in use. The original closed it (via `with`) right after deserialization,
    # which TensorRT documents as invalid — keep a live reference instead.
    runtime = trt.Runtime(logger)
    with open(engine_path, 'rb') as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()

    # Bind the dynamic batch dimension for this specific inference run.
    input_name = engine.get_tensor_name(0)
    output_name = engine.get_tensor_name(1)
    context.set_input_shape(input_name, (batch_size, 4, *input_res))

    # Prepare device buffers.
    # BUG FIX: query the execution context (not the engine) for the output
    # shape so dynamic (-1) dimensions are already resolved — no manual
    # batch-dimension patch-up needed.
    device = torch.device('cuda:0')
    dummy_input = torch.randn(batch_size, 4, *input_res, device=device)
    output_shape = tuple(context.get_tensor_shape(output_name))
    output = torch.empty(output_shape, device=device)

    # Warmup: excludes CUDA context creation / kernel autotuning from timing.
    print(f"Warming up batch_size={batch_size}...")
    for _ in range(10):
        context.execute_v2([dummy_input.data_ptr(), output.data_ptr()])

    # Timed loop; synchronize so we measure GPU completion, not launch time.
    print(f"Benchmarking batch_size={batch_size}...")
    torch.cuda.synchronize()
    start_time = time.time()
    iters = 100
    for _ in range(iters):
        context.execute_v2([dummy_input.data_ptr(), output.data_ptr()])
    torch.cuda.synchronize()
    end_time = time.time()

    avg_latency = (end_time - start_time) / iters
    fps = batch_size / avg_latency
    print(f"Avg Latency: {avg_latency*1000:.2f} ms | FPS: {fps:.2f}")

if __name__ == '__main__':
    # Deployment configuration. 4-channel input at 480x640 is what the ONNX
    # export below is traced with; the TRT profile reuses the same resolution.
    device = 'cuda:0'
    input_res = (480, 640)
    num_classes = 3
    onnx_file = "robofirefusenet_dynamic.onnx"
    engine_file = "robofirefusenet_dynamic.trt"

    # 1. Initialize and Export Model
    # Set augment=False for deployment to output a single segmentation map
    # (model hyperparameters come from the project; not verifiable here).
    model = RoboFireFuseNet(m=2, n=3, num_classes=num_classes, planes=32, 
                            ppm_planes=96, head_planes=128, augment=False, 
                            channels=4, input_resolution=input_res, 
                            window_size=(8, 8), tf_depths=[2, 6]).to(device)
    model.eval()

    # Trace input: batch of 1; the batch axis is marked dynamic below so the
    # engine can serve batch sizes 1-10.
    dummy_input = torch.randn(1, 4, *input_res).to(device)
    
    print("Exporting to ONNX...")
    torch.onnx.export(
        model, dummy_input, onnx_file,
        input_names=['input'], output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        opset_version=14 # High opset recommended for Transformers
    )

    # 2. Build TRT Engine
    # NOTE(review): the engine is only rebuilt when the .trt file is missing,
    # so a stale engine can outlive a changed ONNX export — delete it to force
    # a rebuild.
    if not os.path.exists(engine_file):
        build_engine(onnx_file, engine_file, input_res=input_res)

    # 3. Benchmark across requested batch sizes
    # Matches the (min, opt, max) batch sizes of the optimization profile.
    for b in [1, 5, 10]:
        benchmark_trt(engine_file, batch_size=b, input_res=input_res)