Fading Coder

One Final Commit for the Last Sprint

Home > Tech > Content

Faster R-CNN Implementation Code Walkthrough with MobileNet Backbone

Tech 1

RPN Layer Overview

The Region Proposal Network (RPN) generates candidate bounding boxes (proposals) from feature maps extracted by the backbone network. This section breaks down the core RPN components, including anchor generation, proposal prediction, and filtering.

Anchor Generation Module

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.ops.boxes as box_ops

import det_utils

class AnchorTemplateGenerator(nn.Module):
    """Generate anchor boxes for every grid location of every feature map.

    For each feature map, one set of anchor templates (every size x aspect-ratio
    combination, centred on the origin) is built and then translated to every
    cell of that feature map's grid.

    Args:
        base_sizes: anchor side lengths. A flat sequence such as
            ``(32, 64, 128)`` is interpreted as one size per feature map; a
            nested sequence gives several sizes per feature map.
        aspect_ratios: height / width ratios. A flat sequence is shared by all
            feature maps; a nested one must have one entry per feature map.
    """

    def __init__(self, base_sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1.0, 2.0)):
        super().__init__()
        # Flat sizes -> one single-size group per feature map.
        if not isinstance(base_sizes[0], (list, tuple)):
            base_sizes = tuple((s,) for s in base_sizes)
        # Flat ratios -> the same ratios reused for every feature map.
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(base_sizes)
        assert len(base_sizes) == len(aspect_ratios)

        self.base_sizes = base_sizes
        self.aspect_ratios = aspect_ratios
        self.template_cache = None  # list of per-feature-map template tensors
        self.grid_cache = {}        # grid anchors keyed by str(grid dims) + str(strides)

    def _compute_template_anchors(self, sizes, ratios, dtype, device):
        """Return rounded (len(ratios) * len(sizes), 4) x1y1x2y2 templates centred at 0.

        Each ratio r yields height factor sqrt(r) and width factor 1/sqrt(r),
        so height / width == r while the area stays size**2.
        """
        scales = torch.tensor(sizes, dtype=dtype, device=device)
        ars = torch.tensor(ratios, dtype=dtype, device=device)
        h_factors = torch.sqrt(ars)
        w_factors = 1.0 / h_factors

        widths = (w_factors[:, None] * scales[None, :]).flatten()
        heights = (h_factors[:, None] * scales[None, :]).flatten()

        templates = torch.stack([-widths, -heights, widths, heights], dim=1) / 2
        return templates.round()

    def _set_template_anchors(self, dtype, device):
        """(Re)build ``template_cache`` unless it already matches ``dtype`` and ``device``."""
        cache = self.template_cache
        # Compare dtype as well as device: otherwise templates built for one
        # dtype would be silently reused when the feature maps change dtype.
        if cache is not None and cache[0].device == device and cache[0].dtype == dtype:
            return

        self.template_cache = [
            self._compute_template_anchors(sizes, ratios, dtype, device)
            for sizes, ratios in zip(self.base_sizes, self.aspect_ratios)
        ]

    def _get_grid_anchors(self, grid_dims, strides):
        """Translate each feature map's templates to all of its grid cells.

        Returns one (H * W * num_templates, 4) tensor per feature map, ordered
        by grid location (row-major) and then by template.
        """
        anchors = []
        templates = self.template_cache
        assert templates is not None

        for (h, w), (sh, sw), base in zip(grid_dims, strides, templates):
            shift_x = torch.arange(0, w, dtype=torch.float32, device=base.device) * sw
            shift_y = torch.arange(0, h, dtype=torch.float32, device=base.device) * sh

            # "ij" indexing: rows follow y, columns follow x (explicit to avoid
            # the deprecation warning for the implicit default).
            mesh_y, mesh_x = torch.meshgrid(shift_y, shift_x, indexing="ij")
            flat_x = mesh_x.flatten()
            flat_y = mesh_y.flatten()

            offsets = torch.stack([flat_x, flat_y, flat_x, flat_y], dim=1)
            shifted_anchors = offsets.view(-1, 1, 4) + base.view(1, -1, 4)
            anchors.append(shifted_anchors.flatten(0, 1))

        return anchors

    def forward(self, img_list, feat_maps):
        """Return a list with one (total_anchors, 4) anchor tensor per image.

        Args:
            img_list: object exposing ``tensors`` (the batched input, whose last
                two dims give the padded image size) and ``image_sizes`` (one
                entry per image).
            feat_maps: non-empty list of feature maps; each contributes a grid
                of anchors and its dtype/device is taken from the first one.
        """
        grid_dims = [f.shape[-2:] for f in feat_maps]
        img_size = img_list.tensors.shape[-2:]
        dtype, device = feat_maps[0].dtype, feat_maps[0].device

        # Stride of each feature map relative to the padded input size.
        strides = [
            [
                torch.tensor(img_size[0] // h, dtype=torch.int64, device=device),
                torch.tensor(img_size[1] // w, dtype=torch.int64, device=device),
            ]
            for (h, w) in grid_dims
        ]

        self._set_template_anchors(dtype, device)

        cache_key = str(grid_dims) + str(strides)
        if cache_key not in self.grid_cache:
            self.grid_cache[cache_key] = self._get_grid_anchors(grid_dims, strides)
        grid_anchors = self.grid_cache[cache_key]

        # Anchors depend only on the padded batch size, so every image gets the
        # same anchors; torch.cat produces a separate tensor for each image.
        anchors_per_img = [torch.cat(grid_anchors) for _ in img_list.image_sizes]

        # Cleared after every call so the cache cannot grow without bound.
        self.grid_cache.clear()
        return anchors_per_img

RPN Head Module

class RPNPredictionHead(nn.Module):
    """Shared RPN head applied to every feature map.

    A 3x3 convolution followed by ReLU feeds two sibling 1x1 convolutions:
    one emits ``anchor_count_per_pixel`` objectness logits per location, the
    other emits ``4 * anchor_count_per_pixel`` box-regression deltas.
    """

    def __init__(self, input_channels, anchor_count_per_pixel):
        super().__init__()
        self.conv_slider = nn.Conv2d(input_channels, input_channels, kernel_size=3, stride=1, padding=1)
        self.cls_logits = nn.Conv2d(input_channels, anchor_count_per_pixel, kernel_size=1, stride=1)
        self.reg_deltas = nn.Conv2d(input_channels, 4 * anchor_count_per_pixel, kernel_size=1, stride=1)

        # Gaussian (std 0.01) weights and zero biases, initialised in the same
        # order the convolutions are declared above.
        for conv in (self.conv_slider, self.cls_logits, self.reg_deltas):
            nn.init.normal_(conv.weight, std=0.01)
            nn.init.constant_(conv.bias, 0)

    def forward(self, features):
        """Return ``(objectness_logits, box_reg_deltas)``, each a list with one entry per feature map."""
        hidden = [F.relu(self.conv_slider(fmap)) for fmap in features]
        objectness_logits = [self.cls_logits(h) for h in hidden]
        box_reg_deltas = [self.reg_deltas(h) for h in hidden]
        return objectness_logits, box_reg_deltas

Region Proposal Network

class RegionProposalNetwork(nn.Module):
    """Score anchors, decode them into proposals, and filter the proposals.

    Args:
        anchor_gen: module called as ``anchor_gen(img_list, features)``; it must
            return one anchor tensor per image.
        rpn_head: module called on the list of feature maps; it must return
            per-level objectness logits and box deltas.
        fg_iou, bg_iou: IoU thresholds passed to ``det_utils.Matcher``.
        batch_size, pos_fraction: passed to
            ``det_utils.BalancedPositiveNegativeSampler``.
        pre_nms_top: number of highest-scoring anchors kept per feature level
            (per image) before NMS.
        post_nms_top: number of proposals kept per image after NMS.
        nms_thresh: IoU threshold for (level-aware) NMS.
        score_thresh: minimum sigmoid objectness for a proposal to be kept.
    """

    def __init__(self, anchor_gen, rpn_head, fg_iou=0.7, bg_iou=0.3,
                 batch_size=256, pos_fraction=0.5, pre_nms_top=2000, post_nms_top=1000,
                 nms_thresh=0.7, score_thresh=0.0):
        super().__init__()
        self.anchor_gen = anchor_gen
        self.rpn_head = rpn_head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        self.iou_calc = box_ops.box_iou
        self.matcher = det_utils.Matcher(fg_iou, bg_iou, allow_low_quality=True)
        self.sampler = det_utils.BalancedPositiveNegativeSampler(batch_size, pos_fraction)

        self.pre_nms_top_k = pre_nms_top
        self.post_nms_top_k = post_nms_top
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        self.min_box_size = 1.0  # size threshold passed to remove_small_boxes

    def _decode_proposals(self, rel_deltas, anchors):
        """Decode ``rel_deltas`` against the concatenated per-image ``anchors``.

        Not called by ``forward`` (which uses ``self.box_coder.decode``); kept
        for backward compatibility.
        """
        boxes_per_img = [b.size(0) for b in anchors]
        concat_anchors = torch.cat(anchors, dim=0)
        pred_boxes = self.box_coder.decode_single(rel_deltas, concat_anchors)

        if sum(boxes_per_img) > 0:
            pred_boxes = pred_boxes.reshape(sum(boxes_per_img), -1, 4)
        return pred_boxes

    def _get_top_n_idx(self, objectness, anchors_per_feat):
        """Return (num_imgs, K) indices of the top ``pre_nms_top_k`` anchors of each level.

        ``objectness`` is (num_imgs, total_anchors) with the levels laid out
        contiguously in the order given by ``anchors_per_feat``.
        """
        top_idx = []
        offset = 0
        for level_logits in objectness.split(list(anchors_per_feat), 1):
            num_anchors = level_logits.shape[1]
            k = min(self.pre_nms_top_k, num_anchors)
            _, idx = level_logits.topk(k, dim=1)
            # topk indices are level-local; shift them to global positions.
            top_idx.append(idx + offset)
            offset += num_anchors
        return torch.cat(top_idx, dim=1)

    def _filter_proposals(self, proposals, objectness, img_shapes, anchors_per_feat):
        """Reduce decoded proposals to at most ``post_nms_top_k`` boxes per image.

        Steps: per-level pre-NMS top-k, clip to the image, drop small boxes,
        drop low-score boxes, level-aware NMS, keep the top ``post_nms_top_k``.

        Returns:
            ``(final_boxes, final_scores)``: lists with one tensor per image;
            scores are sigmoid objectness.
        """
        num_imgs = proposals.shape[0]
        device = proposals.device

        # Scores only rank and filter proposals here; keep gradients out.
        objectness = objectness.detach().reshape(num_imgs, -1)

        feat_levels = torch.cat([
            torch.full((n,), idx, dtype=torch.int64, device=device)
            for idx, n in enumerate(anchors_per_feat)
        ], 0).reshape(1, -1).expand(num_imgs, -1)

        # Keep only the best pre_nms_top_k anchors of each feature level.
        top_idx = self._get_top_n_idx(objectness, anchors_per_feat)
        batch_idx = torch.arange(num_imgs, device=device)[:, None]
        objectness = objectness[batch_idx, top_idx]
        feat_levels = feat_levels[batch_idx, top_idx]
        proposals = proposals[batch_idx, top_idx]

        scores = objectness.sigmoid()

        final_boxes = []
        final_scores = []
        for img_boxes, img_scores, img_levels, img_shape in zip(proposals, scores, feat_levels, img_shapes):
            img_h, img_w = img_shape

            img_boxes = box_ops.clip_boxes_to_image(img_boxes, (img_h, img_w))
            valid = box_ops.remove_small_boxes(img_boxes, self.min_box_size)
            img_boxes, img_scores, img_levels = img_boxes[valid], img_scores[valid], img_levels[valid]

            confident = img_scores >= self.score_thresh
            img_boxes, img_scores, img_levels = img_boxes[confident], img_scores[confident], img_levels[confident]

            # NMS is applied independently within each feature level.
            keep = box_ops.batched_nms(img_boxes, img_scores, img_levels, self.nms_thresh)
            keep = keep[:self.post_nms_top_k]

            final_boxes.append(img_boxes[keep])
            final_scores.append(img_scores[keep])

        return final_boxes, final_scores

    def forward(self, img_list, features, targets=None):
        """Compute proposals (and RPN losses in training mode).

        Args:
            img_list: batched images exposing ``image_sizes``.
            features: mapping of feature maps (only its values are used).
            targets: required when ``self.training`` is True.

        Returns:
            ``(boxes, losses)``: one proposal tensor per image and a dict of
            losses (empty when not training).
        """
        features = list(features.values())
        logits, deltas = self.rpn_head(features)
        anchors = self.anchor_gen(img_list, features)

        num_imgs = len(anchors)
        # Each level's logits are (N, A, H, W); A * H * W anchors per image at that level.
        anchors_per_feat = [o[0].shape[0] * o[0].shape[1] * o[0].shape[2] for o in logits]
        logits, deltas = det_utils.concat_box_prediction_layers(logits, deltas)

        # Proposals are not back-propagated through.
        proposals = self.box_coder.decode(deltas.detach(), anchors)
        proposals = proposals.view(num_imgs, -1, 4)

        boxes, scores = self._filter_proposals(proposals, logits, img_list.image_sizes, anchors_per_feat)

        losses = {}
        if self.training:
            assert targets is not None
            # NOTE(review): `_assign_anchors` and `_compute_loss` are not defined in
            # this class as shown; training raises AttributeError unless they are
            # provided elsewhere (e.g. a subclass) — TODO confirm.
            labels, matched_boxes = self._assign_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_boxes, anchors)
            loss_obj, loss_reg = self._compute_loss(logits, deltas, labels, regression_targets)
            losses = {"loss_rpn_obj": loss_obj, "loss_rpn_reg": loss_reg}

        return boxes, losses

Box Coder Helper (det_utils)

class BoxCoder:
    def __init__(self, weights=(10., 10., 5., 5.), bbox_clip=math.log(1000. / 16)):
        self.weights = weights
        self.bbox_clip = bbox_clip

    def encode_single(self, gt_boxes, anchors):
        widths = anchors[:, 2] - anchors[:, 0]
        heights = anchors[:, 3] - anchors[:, 1]
        ctr_x = anchors[:, 0] + 0.5 * widths
        ctr_y = anchors[:, 1] + 0.5 * heights
        
        gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0]
        gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1]
        gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_widths
        gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_heights
        
        wx, wy, ww, wh = self.weights
        dx = (gt_ctr_x - ctr_x) / widths / wx
        dy = (gt_ctr_y - ctr_y) / heights / wy
        dw = torch.log(gt_widths / widths) / ww
        dh = torch.log(gt_heights / heights) / wh
        
        return torch.stack([dx, dy, dw, dh], dim=1)

    def decode_single(self, deltas, anchors):
        anchors = anchors.to(deltas.dtype)
        widths = anchors[:, 2] - anchors[:, 0]
        heights = anchors[:, 3] - anchors[:, 1]
        ctr_x = anchors[:, 0] + 0.5 * widths
        ctr_y = anchors[:, 1] + 0.5 * heights
        
        wx, wy, ww, wh = self.weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh
        
        dw = torch.clamp(dw, max=self.bbox_clip)
        dh = torch.clamp(dh, max=self.bbox_clip)
        
        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]
        
        pred_x1 = pred_ctr_x - 0.5 * pred_w
        pred_y1 = pred_ctr_y - 0.5 * pred_h
        pred_x2 = pred_ctr_x + 0.5 * pred_w
        pred_y2 = pred_ctr_y + 0.5 * pred_h
        
        return torch.stack([pred_x1, pred_y1, pred_x2, pred_y2], dim=2).flatten(1)

    def encode(self, gt_boxes_list, anchors_list):
        all_targets = []
        for gt_boxes, anchors in zip(gt_boxes_list, anchors_list):
            targets = self.encode_single(gt_boxes, anchors)
            all_targets.append(targets)
        return all_targets

    def decode(self, deltas, anchors_list):
        boxes_per_img = [a.size(0) for a in anchors_list]
        concat_anchors = torch.cat(anchors_list, dim=0)
        pred_boxes = self.decode_single(deltas, concat_anchors)
        if sum(boxes_per_img) > 0:
            pred_boxes = pred_boxes.reshape(sum(boxes_per_img), -1, 4)
        return pred_boxes

Related Articles

Understanding Strong and Weak References in Java

Strong References Strong references are the most prevalent type of object referencing in Java. When an object has a strong reference pointing to it, the garbage collector will not reclaim its memory. F...

Comprehensive Guide to SSTI Explained with Payload Bypass Techniques

Introduction Server-Side Template Injection (SSTI) is a vulnerability in web applications where user input is improperly handled within the template engine and executed on the server. This exploit can r...

Implement Image Upload Functionality for Django Integrated TinyMCE Editor

Django’s Admin panel is highly user-friendly, and pairing it with TinyMCE, an effective rich text editor, simplifies content management significantly. Combining the two is particularly useful for bloggi...

Leave a Comment

Anonymous

◎Feel free to join the discussion and share your thoughts.