Fading Coder

Processing YOLOv8 ONNX Model Output with NMS for Object Detection

Tech May 13 3

Export a trained YOLOv8 model from the .pt format to ONNX.

from ultralytics import YOLO

# Load the trained weights and export them to ONNX; export() returns the path of the exported file
model_instance = YOLO('path/to/your/best.pt')
onnx_file = model_instance.export(format="onnx", simplify=True)
assert onnx_file

The exported ONNX graph returns raw candidate boxes without NMS, so post-processing is done in NumPy with three core functions: output standardization, Non-Maximum Suppression (NMS), and Intersection over Union (IoU) calculation.

import numpy as np
import onnxruntime
from PIL import Image
import torchvision.transforms as transforms

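def standardize_output(prediction_tensor):
    """
    Standardizes the raw model output.
    Input shape: (1, 4 + num_classes, 8400), e.g. (1, 7, 8400) for 3 classes
    Output shape: (8400, 5 + num_classes) where columns are: [x_center, y_center, width, height, confidence, ...class_scores]
    """
    # Drop the batch dimension: (4 + num_classes, 8400)
    prediction_tensor = np.squeeze(prediction_tensor, axis=0)
    # One row per candidate box: (8400, 4 + num_classes)
    prediction_tensor = np.transpose(prediction_tensor, (1, 0))
    # YOLOv8 has no separate objectness score, so the best class score serves as the confidence
    class_predictions = prediction_tensor[..., 4:]
    max_confidence = np.max(class_predictions, axis=-1)
    # Insert the confidence as column 4, just before the class scores
    standardized_output = np.insert(prediction_tensor, 4, max_confidence, axis=-1)
    return standardized_output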

def non_max_suppression(detections, iou_threshold):
    """
    Applies greedy, class-agnostic Non-Maximum Suppression to detection boxes.
    Args:
        detections: Array of shape (n, 5 + num_classes) with columns [x_center, y_center, w, h, conf, ...].
        iou_threshold: Boxes whose IoU with an already-kept box is >= this value are discarded.
    Returns:
        Array of kept detections, sorted by descending confidence.
    """
    if len(detections) == 0:
        return detections[:0]  # empty, but keeps the column dimension
    # Sort rows by confidence (column 4), highest first
    detections = detections[detections[:, 4].argsort()[::-1]]
    selected_boxes = []
    while len(detections) > 0:
        # Keep the most confident remaining box
        top_box = detections[0]
        selected_boxes.append(top_box)
        if len(detections) == 1:
            break
        # Discard every remaining box that overlaps the kept box too much
        remaining_boxes = detections[1:]
        iou_values = calculate_iou(top_box, remaining_boxes)
        detections = remaining_boxes[iou_values < iou_threshold]
    return np.array(selected_boxes)

def calculate_iou(reference_box, comparison_boxes):
    """
    Calculates IoU between a reference box and an array of comparison boxes.
    Boxes are in [x_center, y_center, width, height] format.
    """
    ref_x1 = reference_box[0] - reference_box[2] / 2
    ref_y1 = reference_box[1] - reference_box[3] / 2
    ref_x2 = reference_box[0] + reference_box[2] / 2
    ref_y2 = reference_box[1] + reference_box[3] / 2
    comp_x1 = comparison_boxes[:, 0] - comparison_boxes[:, 2] / 2
    comp_y1 = comparison_boxes[:, 1] - comparison_boxes[:, 3] / 2
    comp_x2 = comparison_boxes[:, 0] + comparison_boxes[:, 2] / 2
    comp_y2 = comparison_boxes[:, 1] + comparison_boxes[:, 3] / 2
    intersect_x1 = np.maximum(ref_x1, comp_x1)
    intersect_y1 = np.maximum(ref_y1, comp_y1)
    intersect_x2 = np.minimum(ref_x2, comp_x2)
    intersect_y2 = np.minimum(ref_y2, comp_y2)
    intersection_area = np.maximum(0, intersect_x2 - intersect_x1) * np.maximum(0, intersect_y2 - intersect_y1)
    area_ref = reference_box[2] * reference_box[3]
    area_comp = comparison_boxes[:, 2] * comparison_boxes[:, 3]
    union_area = area_ref + area_comp - intersection_area
    # Guard against division by zero for degenerate (zero-area) boxes
    iou_result = intersection_area / np.maximum(union_area, 1e-9)
    return iou_result
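To check the IoU arithmetic by hand, here is a worked example on boxes small enough to verify mentally. The helpers `to_corners` and `iou_one_to_many` are illustrative names, not part of any library; they use the same centre-to-corner formula as `calculate_iou`, with a small epsilon in the denominator (an added assumption) so that two zero-area boxes do not divide by zero.

```python
import numpy as np

def to_corners(boxes):
    """Convert [cx, cy, w, h] on the last axis to [x1, y1, x2, y2]."""
    boxes = np.asarray(boxes, dtype=np.float64)
    cx, cy, w, h = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)

def iou_one_to_many(ref, others):
    """IoU of one centre-format box against an (n, 4) array of centre-format boxes."""
    r = to_corners(ref)
    o = to_corners(others)
    # Width and height of the overlap rectangle, clipped at zero when boxes are disjoint
    iw = np.maximum(0, np.minimum(r[2], o[:, 2]) - np.maximum(r[0], o[:, 0]))
    ih = np.maximum(0, np.minimum(r[3], o[:, 3]) - np.maximum(r[1], o[:, 1]))
    inter = iw * ih
    area_r = (r[2] - r[0]) * (r[3] - r[1])
    area_o = (o[:, 2] - o[:, 0]) * (o[:, 3] - o[:, 1])
    return inter / np.maximum(area_r + area_o - inter, 1e-9)

ref = [5, 5, 10, 10]        # corners (0, 0)-(10, 10), area 100
others = np.array([
    [10, 5, 10, 10],        # (5, 0)-(15, 10): overlap 5 x 10 = 50, union 150 -> 1/3
    [30, 30, 10, 10],       # disjoint -> 0
    [5, 5, 10, 10],         # identical -> 1
])
print(iou_one_to_many(ref, others))  # approximately [0.333, 0.0, 1.0]
```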

For a 640×640 input, the YOLOv8 ONNX detection model outputs a tensor of shape (1, 4 + num_classes, 8400). For a model trained on 3 classes this is (1, 7, 8400): 8400 candidate boxes, each described by 7 values, four for the bounding box (x_center, y_center, width, height, in pixels of the 640×640 input) and three class scores. YOLOv8 has no separate objectness score, so standardize_output transposes the data to one row per box and inserts the maximum class score as the fifth column, producing an (8400, 8) array.
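The 8400 figure follows from the detection head's grid sizes. Assuming the default YOLOv8 head with three output scales at strides 8, 16 and 32, a 640×640 input yields grids of 80×80, 40×40 and 20×20 cells with one candidate per cell: 6400 + 1600 + 400 = 8400. The helper below (a hypothetical name, not an Ultralytics function) makes the arithmetic explicit and predicts the raw output shape for other class counts or square input sizes.

```python
def expected_yolov8_output_shape(img_size=640, num_classes=3, strides=(8, 16, 32)):
    """Predict the raw output shape of a YOLOv8 detection model.

    Assumes a square input and the default three-scale head (strides 8/16/32),
    with one candidate box per grid cell at each scale.
    """
    num_candidates = sum((img_size // s) ** 2 for s in strides)
    # 4 box values (cx, cy, w, h) plus one score per class
    return (1, 4 + num_classes, num_candidates)

print(expected_yolov8_output_shape())          # (1, 7, 8400)
print(expected_yolov8_output_shape(640, 80))   # (1, 84, 8400), e.g. an 80-class COCO model
print(expected_yolov8_output_shape(320, 3))    # (1, 7, 2100)
```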

Non-Maximum Suppression (NMS) removes redundant, overlapping boxes that describe the same object. The implementation above sorts all detections by confidence, then repeatedly keeps the highest-confidence box and discards every remaining box whose IoU with it is at or above the threshold. This continues until no boxes remain, leaving only distinct, high-confidence predictions. Note that this version is class-agnostic: it compares boxes regardless of their predicted class.
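Because `non_max_suppression` ignores classes, a confident box of one class can suppress an overlapping box of another class (for example, a person standing in front of a car). A common alternative, which Ultralytics' own predictor also uses by default, is to run NMS separately per class. The sketch below assumes the standardized layout `[cx, cy, w, h, conf, class scores...]`, takes each box's class as the argmax of its class scores, and runs the same greedy procedure within each class; all function names are hypothetical.

```python
import numpy as np

def iou_center_format(ref, others):
    """IoU of one [cx, cy, w, h, ...] row against an (n, 4+) array of rows."""
    def corners(b):
        return (b[..., 0] - b[..., 2] / 2, b[..., 1] - b[..., 3] / 2,
                b[..., 0] + b[..., 2] / 2, b[..., 1] + b[..., 3] / 2)
    rx1, ry1, rx2, ry2 = corners(ref)
    ox1, oy1, ox2, oy2 = corners(others)
    inter = (np.maximum(0, np.minimum(rx2, ox2) - np.maximum(rx1, ox1))
             * np.maximum(0, np.minimum(ry2, oy2) - np.maximum(ry1, oy1)))
    union = ref[2] * ref[3] + others[:, 2] * others[:, 3] - inter
    return inter / np.maximum(union, 1e-9)

def greedy_nms_indices(rows, iou_threshold):
    """Indices of rows kept by greedy NMS, most confident (column 4) first."""
    order = np.argsort(-rows[:, 4], kind="stable")
    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(best)
        if rest.size == 0:
            break
        ious = iou_center_format(rows[best], rows[rest])
        order = rest[ious < iou_threshold]  # drop boxes overlapping the kept one
    return np.array(keep, dtype=np.int64)

def class_aware_nms(rows, iou_threshold=0.5):
    """Greedy NMS run independently for each predicted class.

    rows: (n, 5 + nc) array of [cx, cy, w, h, conf, class_0, ..., class_{nc-1}].
    Returns (kept_rows, class_ids), sorted by descending confidence.
    """
    if len(rows) == 0:
        return rows, np.empty(0, dtype=np.int64)
    class_ids = np.argmax(rows[:, 5:], axis=1)
    kept = []
    for c in np.unique(class_ids):
        idx = np.flatnonzero(class_ids == c)  # rows predicted as class c
        kept.extend(idx[greedy_nms_indices(rows[idx], iou_threshold)])
    kept = np.array(kept, dtype=np.int64)
    kept = kept[np.argsort(-rows[kept, 4], kind="stable")]
    return rows[kept], class_ids[kept]

rows = np.array([
    [50, 50, 20, 20, 0.9, 0.9, 0.1],  # class 0
    [52, 50, 20, 20, 0.8, 0.8, 0.2],  # class 0, IoU ~0.82 with row 0
    [51, 50, 20, 20, 0.7, 0.1, 0.7],  # class 1, IoU ~0.90 with row 0
])
kept, cls = class_aware_nms(rows, iou_threshold=0.5)
print(cls)                                          # [0 1]
print(greedy_nms_indices(rows, iou_threshold=0.5))  # [0] (class-agnostic keeps only row 0)
```

In the script, `class_aware_nms` could stand in for the `non_max_suppression` call, with the added benefit of returning each box's class id. If PyTorch is already installed, `torchvision.ops.batched_nms` provides the same per-class behaviour on corner-format (x1, y1, x2, y2) tensors.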

Below is a complete implementation for loading the ONNX model, running inference, applying post-processing, and visualizing the results.

# Load the exported ONNX model
model_path = 'path/to/your/best.onnx'
session = onnxruntime.InferenceSession(model_path)

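# Preprocess the input image
img_path = 'path/to/test/image.jpg'
# Convert to RGB so PNGs with an alpha channel or grayscale images still yield 3 channels
original_img = Image.open(img_path).convert("RGB")
# Plain resize to 640x640 (stretches the image, no aspect-ratio-preserving padding);
# ToTensor produces a float32 CHW tensor scaled to [0, 1]
preprocessor = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
])
# Add the batch dimension: (1, 3, 640, 640)
input_tensor = preprocessor(original_img).unsqueeze(0).numpy()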

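# Perform inference; look up the input name rather than hard-coding it ('images' for Ultralytics exports)
input_name = session.get_inputs()[0].name
model_output = session.run(None, {input_name: input_tensor})
raw_predictions = model_output[0]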

# Post-process predictions
processed_predictions = standardize_output(raw_predictions)

# Apply confidence thresholding
confidence_cutoff = 0.1
high_confidence_detections = processed_predictions[processed_predictions[:, 4] > confidence_cutoff]

# Apply Non-Maximum Suppression
final_detections = non_max_suppression(high_confidence_detections, iou_threshold=0.2)

# Scale bounding boxes from model input size (640x640) back to original image size
orig_w, orig_h = original_img.size
scale_factor_x = orig_w / 640
scale_factor_y = orig_h / 640

# Draw final bounding boxes on the original image
from PIL import ImageDraw
draw_obj = ImageDraw.Draw(original_img)
for det in final_detections:
    x_cen, y_cen, w, h = det[:4]
    x0 = (x_cen - w / 2) * scale_factor_x
    y0 = (y_cen - h / 2) * scale_factor_y
    x1 = (x_cen + w / 2) * scale_factor_x
    y1 = (y_cen + h / 2) * scale_factor_y
    draw_obj.rectangle([(x0, y0), (x1, y1)], outline="red")

original_img.show()
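The script above stretches the image to 640×640, whereas Ultralytics' own predictor letterboxes it (an aspect-preserving resize followed by gray padding). Stretching still maps back correctly through `scale_factor_x` and `scale_factor_y`, but the distortion may cost some accuracy on very wide or tall images. A letterbox variant is sketched below; `letterbox` and `unletterbox_boxes` are hypothetical helpers, and their rounding, padding and resampling only approximate the Ultralytics implementation.

```python
import numpy as np
from PIL import Image

def letterbox(img, size=640, fill=(114, 114, 114)):
    """Aspect-preserving resize to fit size x size, then centre it on a gray canvas.

    Returns (canvas, scale, (pad_x, pad_y)).
    """
    w, h = img.size
    scale = min(size / w, size / h)
    new_w, new_h = round(w * scale), round(h * scale)
    resized = img.convert("RGB").resize((new_w, new_h))
    canvas = Image.new("RGB", (size, size), fill)
    pad_x, pad_y = (size - new_w) // 2, (size - new_h) // 2
    canvas.paste(resized, (pad_x, pad_y))
    return canvas, scale, (pad_x, pad_y)

def unletterbox_boxes(rows, scale, pad):
    """Map [cx, cy, w, h, ...] rows in letterboxed pixels to [x1, y1, x2, y2] in original pixels."""
    pad_x, pad_y = pad
    cx = (rows[:, 0] - pad_x) / scale  # remove padding, then undo the resize
    cy = (rows[:, 1] - pad_y) / scale
    w = rows[:, 2] / scale
    h = rows[:, 3] / scale
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)

# A 1280x720 image is scaled by 0.5 to 640x360 and padded by 140 px top and bottom
img = Image.new("RGB", (1280, 720))
canvas, scale, pad = letterbox(img)
print(canvas.size, scale, pad)  # (640, 640) 0.5 (0, 140)
boxes = np.array([[320.0, 320.0, 100.0, 50.0, 0.9]])
print(unletterbox_boxes(boxes, scale, pad))  # [[540. 310. 740. 410.]]
```

To use it, pass the returned canvas to `transforms.ToTensor()` in place of the `Resize` step, and map the NMS output with `unletterbox_boxes` instead of the two scale factors.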
Tags: pytorch
