Fading Coder

An Old Coder’s Final Dance


Implementing mean Average Precision (mAP) for Object Detection in PyTorch


mAP (mean Average Precision) is a standard metric for evaluating object detectors. It summarizes the precision–recall tradeoff across categories, optionally averaged over multiple IoU thresholds.

  • Precision = TP / (TP + FP)
  • Recall = TP / (TP + FN)
  • AP (per class) = area under the precision–recall curve
  • mAP = mean of AP over classes (and, in some protocols, also averaged over IoU thresholds)
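As a quick numeric illustration of the precision and recall formulas (the counts below are made up):

```python
# Hypothetical counts for one class at a fixed IoU threshold
tp, fp, fn = 8, 2, 4  # true positives, false positives, false negatives

precision = tp / (tp + fp)  # 8 / 10 = 0.8
recall = tp / (tp + fn)     # 8 / 12 ≈ 0.667

print(f"precision={precision:.3f}, recall={recall:.3f}")
```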

Below is a concise walkthrough of how AP is computed for one class, followed by a complete PyTorch implementation.

Per‑class AP computation

  1. Collect per-class detections and ground truths

    • Assume predictions are post-NMS.
    • Work image-by-image and class-by-class.
  2. Match detections to ground truths (TP/FP assignment)

    • Sort detections by confidence in descending order.
    • For each detection, compute IoU with ground-truth boxes from the same image and class.
    • If the maximum IoU ≥ IoU threshold and that ground-truth box is not yet matched, mark the detection as TP and mark that GT as used. Otherwise, mark it as FP.
  3. Build the precision–recall curve

    • Compute cumulative TP and FP over the sorted detections.
    • Recall = cumTP / total_gt
    • Precision = cumTP / (cumTP + cumFP)
  4. Integrate to get AP

    • Optionally apply the precision envelope (monotonic precision) before integration.
    • Numerically integrate precision w.r.t. recall to obtain AP.
  5. Average AP across classes (and optionally across IoU thresholds) to obtain mAP

    • Skip classes with no ground truths.
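The cumulative-sum, envelope, and integration steps above can be sketched on a toy single-class example (the TP/FP assignment is assumed to have been done already, and all numbers are made up):

```python
import torch

# Toy single-class scenario: three detections already sorted by confidence,
# matched against total_gt = 2 ground-truth boxes at some IoU threshold.
tp = torch.tensor([1.0, 0.0, 1.0])  # det 1 matches a GT, det 2 misses, det 3 matches
fp = 1.0 - tp
total_gt = 2

cum_tp = torch.cumsum(tp, dim=0)        # [1, 1, 2]
cum_fp = torch.cumsum(fp, dim=0)        # [0, 1, 1]
recall = cum_tp / total_gt              # [0.5, 0.5, 1.0]
precision = cum_tp / (cum_tp + cum_fp)  # [1.0, 0.5, 0.667]

# Precision envelope: make precision monotonically non-increasing
for k in range(precision.numel() - 2, -1, -1):
    precision[k] = torch.maximum(precision[k], precision[k + 1])
# after the envelope: [1.0, 0.667, 0.667]

# Integrate precision w.r.t. recall, with a sentinel point at recall 0
mrec = torch.cat([torch.tensor([0.0]), recall])
mpre = torch.cat([torch.tensor([1.0]), precision])
ap = torch.trapz(mpre, mrec).item()
print(f"AP = {ap:.4f}")
```

Working through the trapezoids by hand gives 0.5 × 1.0 + 0.5 × 2/3 ≈ 0.8333, which the code reproduces.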

PyTorch implementation

This implementation expects detections and ground truths in the following format:

  • Each entry: [image_id, class_id, confidence_score, x1, y1, x2, y2]
  • For ground truths, confidence_score can be any placeholder (it is ignored).
  • Coordinates use [x1, y1, x2, y2] (top-left, bottom-right) and can be normalized or absolute, as long as both inputs use the same convention.
import torch
from collections import defaultdict
from typing import List, Sequence, Tuple

# boxes are [x1, y1, x2, y2] in "xyxy" format

def iou_xyxy(box: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
    """Compute IoU between one box (4,) and many boxes (N, 4).
    Args:
        box: Tensor of shape (4,)
        boxes: Tensor of shape (N, 4)
    Returns:
        Tensor of shape (N,) with the IoU between box and each boxes[i].
    """
    # Intersection
    tl = torch.maximum(box[:2], boxes[:, :2])  # top-left
    br = torch.minimum(box[2:], boxes[:, 2:])  # bottom-right
    inter_wh = (br - tl).clamp(min=0)
    inter_area = inter_wh[:, 0] * inter_wh[:, 1]

    # Areas
    box_area = (box[2] - box[0]).clamp(min=0) * (box[3] - box[1]).clamp(min=0)
    boxes_area = ((boxes[:, 2] - boxes[:, 0]).clamp(min=0) *
                  (boxes[:, 3] - boxes[:, 1]).clamp(min=0))

    # Union
    union = box_area + boxes_area - inter_area + 1e-6
    return inter_area / union


def mean_average_precision(
    preds: List[Sequence[float]],
    gts: List[Sequence[float]],
    iou_threshold: float = 0.5,
    num_classes: int = 80,
    use_precision_envelope: bool = True,
) -> Tuple[float, List[float]]:
    """Compute mAP and per-class AP using all-point interpolation (VOC-style integration).

    Args:
        preds: List of detections [img_id, cls_id, score, x1, y1, x2, y2].
        gts:   List of ground truths [img_id, cls_id, _,   x1, y1, x2, y2].
        iou_threshold: IoU threshold for a detection to be considered TP.
        num_classes: Total number of classes.
        use_precision_envelope: If True, apply monotonic precision envelope.

    Returns:
        (mAP, ap_per_class): scalar mAP and list of APs for classes that have GTs.
    """
    device = torch.device("cpu")
    eps = 1e-6
    ap_values: List[float] = []

    # Group ground truths by (class -> image -> boxes)
    gt_by_class: List[defaultdict] = [defaultdict(list) for _ in range(num_classes)]
    for gt in gts:
        img_id, cls_id = int(gt[0]), int(gt[1])
        box = torch.tensor(gt[3:7], dtype=torch.float32, device=device)
        gt_by_class[cls_id][img_id].append(box)

    # Convert lists to tensors and create matched flags
    gt_boxes: List[dict] = [{} for _ in range(num_classes)]
    gt_used: List[dict] = [{} for _ in range(num_classes)]
    for c in range(num_classes):
        for img_id, boxes in gt_by_class[c].items():
            if len(boxes) > 0:
                gt_boxes[c][img_id] = torch.stack(boxes).to(device)  # (M, 4)
                gt_used[c][img_id] = torch.zeros(len(boxes), dtype=torch.bool, device=device)

    # Group detections by class and sort by confidence
    det_by_class: List[List[Tuple[int, float, torch.Tensor]]] = [[] for _ in range(num_classes)]
    for det in preds:
        img_id, cls_id = int(det[0]), int(det[1])
        score = float(det[2])
        box = torch.tensor(det[3:7], dtype=torch.float32, device=device)
        det_by_class[cls_id].append((img_id, score, box))

    for c in range(num_classes):
        detections_c = det_by_class[c]
        total_gt = sum(len(v) for v in gt_boxes[c].values())
        if total_gt == 0:
            # No GT for this class: exclude from mAP computation
            continue

        if not detections_c:
            # GT exists but no detections => AP = 0
            ap_values.append(0.0)
            continue

        # Sort by confidence descending
        detections_c.sort(key=lambda t: t[1], reverse=True)

        tp = torch.zeros(len(detections_c), dtype=torch.float32, device=device)
        fp = torch.zeros(len(detections_c), dtype=torch.float32, device=device)

        # Assign TP/FP
        for i, (img_id, score, box_det) in enumerate(detections_c):
            if img_id not in gt_boxes[c]:
                fp[i] = 1.0
                continue

            boxes_gt_img = gt_boxes[c][img_id]
            used_flags = gt_used[c][img_id]

            ious = iou_xyxy(box_det, boxes_gt_img)
            # gt_boxes[c][img_id] is only created when non-empty, so ious has at least one element
            best_iou = float(ious.max().item())
            best_idx = int(ious.argmax().item())

            if best_iou >= iou_threshold and not used_flags[best_idx]:
                tp[i] = 1.0
                used_flags[best_idx] = True
            else:
                fp[i] = 1.0

        # Precision–Recall
        cum_tp = torch.cumsum(tp, dim=0)
        cum_fp = torch.cumsum(fp, dim=0)
        recall = cum_tp / (total_gt + eps)
        precision = cum_tp / (cum_tp + cum_fp + eps)

        # Precision envelope (monotonic)
        if use_precision_envelope and precision.numel() > 0:
            for k in range(precision.numel() - 2, -1, -1):
                precision[k] = torch.maximum(precision[k], precision[k + 1])

        # Add sentinel points and integrate
        mrec = torch.cat([torch.tensor([0.0], device=device), recall])
        mpre = torch.cat([torch.tensor([1.0], device=device), precision])

        ap = float(torch.trapz(mpre, mrec).item()) if mrec.numel() > 1 else 0.0
        ap_values.append(ap)

    if not ap_values:
        return 0.0, []

    mAP = float(sum(ap_values) / len(ap_values))
    return mAP, ap_values

Usage example

# Example inputs: [img_id, cls_id, score, x1, y1, x2, y2]
pred_bboxes = [
    [0, 1, 0.9, 10, 10, 50, 50],
    [0, 1, 0.6, 12, 12, 48, 48],
    [1, 1, 0.8, 30, 30, 70, 70],
    [1, 0, 0.7, 15, 15, 40, 40],
]

true_bboxes = [
    [0, 1, 1.0, 11, 11, 49, 49],
    [1, 1, 1.0, 32, 32, 68, 68],
    [1, 0, 1.0, 16, 16, 39, 39],
]

mAP, ap_per_class = mean_average_precision(
    preds=pred_bboxes,
    gts=true_bboxes,
    iou_threshold=0.5,
    num_classes=2,
)
print("mAP:", mAP)  # ≈ 1.0 for these inputs: every GT is matched by a high-scoring detection
print("AP per class:", ap_per_class)
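In practice, detector outputs often arrive as one dict per image rather than a flat list. Assuming the torchvision detection convention (each dict holding "boxes", "labels", and "scores" tensors), a small adapter might look like this:

```python
import torch

def flatten_detections(outputs):
    """Convert per-image detection dicts ('boxes' (N, 4), 'labels' (N,),
    'scores' (N,)) into flat [img_id, cls_id, score, x1, y1, x2, y2] rows."""
    rows = []
    for img_id, out in enumerate(outputs):
        for box, label, score in zip(out["boxes"], out["labels"], out["scores"]):
            rows.append([img_id, int(label), float(score), *box.tolist()])
    return rows

# Made-up outputs for two images
outputs = [
    {"boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
     "labels": torch.tensor([1]),
     "scores": torch.tensor([0.9])},
    {"boxes": torch.tensor([[30.0, 30.0, 70.0, 70.0]]),
     "labels": torch.tensor([1]),
     "scores": torch.tensor([0.8])},
]
print(flatten_detections(outputs))
```

The same adapter works for ground truths if you supply a constant placeholder score.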

Notes

  • Ensure predictions are filtered by NMS before evaluation.
  • Classes with no ground truths are excluded from the mAP average.
  • To approximate COCO-style mAP, run the function over IoU thresholds from 0.50 to 0.95 (step 0.05) and average the resulting mAP values.
  • If you prefer the VOC 2007 11-point interpolation, replace the integration step with sampling precision at recall ∈ {0.0, 0.1, ..., 1.0} and averaging.
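The COCO-style averaging mentioned above amounts to a simple threshold sweep. In the sketch below, evaluate_at is a hypothetical placeholder; in practice it would call mean_average_precision with your fixed predictions and ground truths at the given threshold and return the scalar mAP:

```python
# Hypothetical per-threshold evaluator; a dummy value keeps the sweep runnable.
# In practice: return mean_average_precision(preds, gts, iou_threshold=thr, ...)[0]
def evaluate_at(thr: float) -> float:
    return 1.0 - thr

thresholds = [0.50 + 0.05 * i for i in range(10)]  # 0.50, 0.55, ..., 0.95
coco_map = sum(evaluate_at(t) for t in thresholds) / len(thresholds)
print(f"mAP@[.50:.95] = {coco_map:.3f}")
```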
