Faster R-CNN Implementation Code Walkthrough with MobileNet Backbone
RPN Layer Overview
The Region Proposal Network (RPN) generates candidate bounding boxes (proposals) from the feature maps extracted by the backbone network. This section breaks down the core RPN components: anchor generation, proposal prediction, and proposal filtering.
Anchor Generation Module
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.ops.boxes as box_ops
import det_utils  # project-local helpers (Matcher, sampler, BoxCoder; the BoxCoder is shown at the end of this section)
class AnchorTemplateGenerator(nn.Module):
def __init__(self, base_sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1.0, 2.0)):
super().__init__()
if not isinstance(base_sizes[0], (list, tuple)):
base_sizes = tuple((s,) for s in base_sizes)
if not isinstance(aspect_ratios[0], (list, tuple)):
aspect_ratios = (aspect_ratios,) * len(base_sizes)
assert len(base_sizes) == len(aspect_ratios)
self.base_sizes = base_sizes
self.aspect_ratios = aspect_ratios
self.template_cache = None
self.grid_cache = {}
def _compute_template_anchors(self, sizes, ratios, dtype, device):
scales = torch.tensor(sizes, dtype=dtype, device=device)
ars = torch.tensor(ratios, dtype=dtype, device=device)
        # For aspect ratio r = h / w: scale h by sqrt(r) and w by 1 / sqrt(r),
        # so the anchor area stays constant across ratios at a given scale.
        h_factors = torch.sqrt(ars)
        w_factors = 1.0 / h_factors
widths = (w_factors[:, None] * scales[None, :]).flatten()
heights = (h_factors[:, None] * scales[None, :]).flatten()
templates = torch.stack([-widths, -heights, widths, heights], dim=1) / 2
return templates.round()
def _set_template_anchors(self, dtype, device):
if self.template_cache is not None and self.template_cache[0].device == device:
return
self.template_cache = [
self._compute_template_anchors(sizes, ratios, dtype, device)
for sizes, ratios in zip(self.base_sizes, self.aspect_ratios)
]
def _get_grid_anchors(self, grid_dims, strides):
anchors = []
templates = self.template_cache
assert templates is not None
for (h, w), (sh, sw), base in zip(grid_dims, strides, templates):
shift_x = torch.arange(0, w, dtype=torch.float32, device=base.device) * sw
shift_y = torch.arange(0, h, dtype=torch.float32, device=base.device) * sh
            mesh_y, mesh_x = torch.meshgrid(shift_y, shift_x, indexing="ij")  # explicit indexing avoids the PyTorch >= 1.10 warning
flat_x = mesh_x.flatten()
flat_y = mesh_y.flatten()
offsets = torch.stack([flat_x, flat_y, flat_x, flat_y], dim=1)
shifted_anchors = offsets.view(-1, 1, 4) + base.view(1, -1, 4)
anchors.append(shifted_anchors.flatten(0, 1))
return anchors
def forward(self, img_list, feat_maps):
grid_dims = [f.shape[-2:] for f in feat_maps]
img_size = img_list.tensors.shape[-2:]
dtype, device = feat_maps[0].dtype, feat_maps[0].device
strides = [
[
torch.tensor(img_size[0] // h, dtype=torch.int64, device=device),
torch.tensor(img_size[1] // w, dtype=torch.int64, device=device)
]
for (h, w) in grid_dims
]
self._set_template_anchors(dtype, device)
cache_key = str(grid_dims) + str(strides)
if cache_key not in self.grid_cache:
self.grid_cache[cache_key] = self._get_grid_anchors(grid_dims, strides)
        anchors_per_img = []
        for _ in img_list.image_sizes:  # every image in the batch shares the same anchor grid
            anchors_per_img.append(torch.cat(self.grid_cache[cache_key]))
        # Clear the cache so grids for other input sizes do not accumulate in memory.
        self.grid_cache.clear()
        return anchors_per_img
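As a quick sanity check, the generator can be exercised with dummy inputs. The snippet below is a minimal sketch assuming torchvision's ImageList (any object with .tensors and .image_sizes attributes would do); the feature-map strides and channel count are illustrative, not tied to a particular backbone.

import torch
from torchvision.models.detection.image_list import ImageList

images = torch.rand(2, 3, 256, 256)
img_list = ImageList(images, image_sizes=[(256, 256), (256, 256)])
# Five feature maps at strides 4..64, one per entry in the default base_sizes.
feat_maps = [torch.rand(2, 96, 256 // s, 256 // s) for s in (4, 8, 16, 32, 64)]

gen = AnchorTemplateGenerator()
anchors = gen(img_list, feat_maps)
print(len(anchors))      # 2 -- one tensor per image in the batch
print(anchors[0].shape)  # (N, 4), where N = 3 * sum of H*W over the five levels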
RPN Head Module
class RPNPredictionHead(nn.Module):
def __init__(self, input_channels, anchor_count_per_pixel):
super().__init__()
        # 3x3 "sliding window" conv over the feature map, then two 1x1 prediction branches.
        self.conv_slider = nn.Conv2d(input_channels, input_channels, 3, 1, 1)
        self.cls_logits = nn.Conv2d(input_channels, anchor_count_per_pixel, 1, 1)      # objectness per anchor
        self.reg_deltas = nn.Conv2d(input_channels, anchor_count_per_pixel * 4, 1, 1)  # (dx, dy, dw, dh) per anchor
for layer in self.children():
if isinstance(layer, nn.Conv2d):
nn.init.normal_(layer.weight, std=0.01)
nn.init.constant_(layer.bias, 0)
def forward(self, features):
objectness_logits = []
box_reg_deltas = []
for feat in features:
x = F.relu(self.conv_slider(feat))
objectness_logits.append(self.cls_logits(x))
box_reg_deltas.append(self.reg_deltas(x))
return objectness_logits, box_reg_deltas
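A quick shape check (channel counts are illustrative): with anchor_count_per_pixel=3, the head emits one objectness logit and four box deltas per anchor at every spatial location, preserving the spatial dimensions of each feature map.

import torch

head = RPNPredictionHead(input_channels=96, anchor_count_per_pixel=3)
feats = [torch.rand(2, 96, 32, 32), torch.rand(2, 96, 16, 16)]
logits, deltas = head(feats)
print(logits[0].shape)  # torch.Size([2, 3, 32, 32])  -- one score per anchor
print(deltas[0].shape)  # torch.Size([2, 12, 32, 32]) -- four deltas per anchor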
Region Proposal Network
class RegionProposalNetwork(nn.Module):
def __init__(self, anchor_gen, rpn_head, fg_iou=0.7, bg_iou=0.3,
batch_size=256, pos_fraction=0.5, pre_nms_top=2000, post_nms_top=1000,
nms_thresh=0.7, score_thresh=0.0):
super().__init__()
self.anchor_gen = anchor_gen
self.rpn_head = rpn_head
self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
self.iou_calc = box_ops.box_iou
self.matcher = det_utils.Matcher(fg_iou, bg_iou, allow_low_quality=True)
self.sampler = det_utils.BalancedPositiveNegativeSampler(batch_size, pos_fraction)
self.pre_nms_top_k = pre_nms_top
self.post_nms_top_k = post_nms_top
self.nms_thresh = nms_thresh
self.score_thresh = score_thresh
self.min_box_size = 1.0
    def _decode_proposals(self, rel_deltas, anchors):
        # Mirrors BoxCoder.decode below; forward() calls self.box_coder.decode directly,
        # so this helper is effectively unused as written.
boxes_per_img = [b.size(0) for b in anchors]
concat_anchors = torch.cat(anchors, dim=0)
pred_boxes = self.box_coder.decode_single(rel_deltas, concat_anchors)
if sum(boxes_per_img) > 0:
pred_boxes = pred_boxes.reshape(sum(boxes_per_img), -1, 4)
return pred_boxes
def _filter_proposals(self, proposals, objectness, img_shapes, anchors_per_feat):
num_imgs = proposals.shape[0]
device = proposals.device
objectness = objectness.detach().reshape(num_imgs, -1)
feat_levels = [
torch.full((n,), idx, dtype=torch.int64, device=device)
for idx, n in enumerate(anchors_per_feat)
]
feat_levels = torch.cat(feat_levels, 0).reshape(1, -1).expand(num_imgs, -1)
final_boxes = []
final_scores = []
        for i in range(num_imgs):
            img_proposals = proposals[i]
            img_logits = objectness[i]
            img_levels = feat_levels[i]
            img_h, img_w = img_shapes[i]
            # Keep only the highest-scoring pre_nms_top_k candidates before NMS
            # (torchvision selects top-k per feature level; a single per-image top-k is used here).
            top_k = min(self.pre_nms_top_k, img_logits.shape[0])
            _, top_idx = img_logits.topk(top_k)
            img_proposals = img_proposals[top_idx]
            img_logits = img_logits[top_idx]
            img_levels = img_levels[top_idx]
            # Clip to the image, drop degenerate boxes, and apply the score threshold.
            img_proposals = box_ops.clip_boxes_to_image(img_proposals, (img_h, img_w))
            valid_idx = box_ops.remove_small_boxes(img_proposals, self.min_box_size)
            img_proposals, img_logits, img_levels = img_proposals[valid_idx], img_logits[valid_idx], img_levels[valid_idx]
            img_scores = img_logits.sigmoid()
            score_mask = img_scores >= self.score_thresh
            img_proposals, img_scores, img_levels = img_proposals[score_mask], img_scores[score_mask], img_levels[score_mask]
            # NMS is run per feature level (batched by img_levels), then capped at post_nms_top_k.
            keep = box_ops.batched_nms(img_proposals, img_scores, img_levels, self.nms_thresh)
            keep = keep[:self.post_nms_top_k]
            final_boxes.append(img_proposals[keep])
            final_scores.append(img_scores[keep])
        return final_boxes, final_scores
def forward(self, img_list, features, targets=None):
features = list(features.values())
logits, deltas = self.rpn_head(features)
anchors = self.anchor_gen(img_list, features)
num_imgs = len(anchors)
        # Anchors per feature level = A * H * W (each per-image logit map has shape (A, H, W)).
        anchors_per_feat = [o[0].shape[0] * o[0].shape[1] * o[0].shape[2] for o in logits]
logits, deltas = det_utils.concat_box_prediction_layers(logits, deltas)
proposals = self.box_coder.decode(deltas.detach(), anchors)
proposals = proposals.view(num_imgs, -1, 4)
boxes, scores = self._filter_proposals(proposals, logits, img_list.image_sizes, anchors_per_feat)
losses = {}
if self.training:
assert targets is not None
            # _assign_anchors and _compute_loss (anchor-to-GT matching, sampling,
            # and the loss terms) are not shown in this section.
            labels, matched_boxes = self._assign_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_boxes, anchors)
            loss_obj, loss_reg = self._compute_loss(logits, deltas, labels, regression_targets)
losses = {"loss_rpn_obj": loss_obj, "loss_rpn_reg": loss_reg}
return boxes, losses
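To tie the pieces together, here is a minimal inference-time wiring sketch. It assumes det_utils exposes the helpers referenced above and that the backbone is a MobileNetV3 trunk producing a single 960-channel feature map at stride 32; the channel count and image size are illustrative, and the training branch (_assign_anchors / _compute_loss) is not exercised.

import torch
from collections import OrderedDict
from torchvision.models.detection.image_list import ImageList

anchor_gen = AnchorTemplateGenerator(base_sizes=((32, 64, 128),),
                                     aspect_ratios=((0.5, 1.0, 2.0),))
head = RPNPredictionHead(input_channels=960, anchor_count_per_pixel=9)  # 3 sizes x 3 ratios
rpn = RegionProposalNetwork(anchor_gen, head).eval()

images = torch.rand(1, 3, 320, 320)
img_list = ImageList(images, image_sizes=[(320, 320)])
features = OrderedDict([("0", torch.rand(1, 960, 10, 10))])  # single-level backbone output

with torch.no_grad():
    proposals, losses = rpn(img_list, features)
print(proposals[0].shape)  # (K, 4) with K <= post_nms_top
print(losses)              # {} in eval mode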
Box Coder Helper (det_utils)
import math
import torch

class BoxCoder:
    def __init__(self, weights=(10., 10., 5., 5.), bbox_clip=math.log(1000. / 16)):
self.weights = weights
self.bbox_clip = bbox_clip
def encode_single(self, gt_boxes, anchors):
widths = anchors[:, 2] - anchors[:, 0]
heights = anchors[:, 3] - anchors[:, 1]
ctr_x = anchors[:, 0] + 0.5 * widths
ctr_y = anchors[:, 1] + 0.5 * heights
gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0]
gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1]
gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_widths
gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_heights
wx, wy, ww, wh = self.weights
        # Multiply by the weights here; decode_single divides by the same weights to invert this.
        dx = wx * (gt_ctr_x - ctr_x) / widths
        dy = wy * (gt_ctr_y - ctr_y) / heights
        dw = ww * torch.log(gt_widths / widths)
        dh = wh * torch.log(gt_heights / heights)
return torch.stack([dx, dy, dw, dh], dim=1)
def decode_single(self, deltas, anchors):
anchors = anchors.to(deltas.dtype)
widths = anchors[:, 2] - anchors[:, 0]
heights = anchors[:, 3] - anchors[:, 1]
ctr_x = anchors[:, 0] + 0.5 * widths
ctr_y = anchors[:, 1] + 0.5 * heights
wx, wy, ww, wh = self.weights
dx = deltas[:, 0::4] / wx
dy = deltas[:, 1::4] / wy
dw = deltas[:, 2::4] / ww
dh = deltas[:, 3::4] / wh
dw = torch.clamp(dw, max=self.bbox_clip)
dh = torch.clamp(dh, max=self.bbox_clip)
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
pred_w = torch.exp(dw) * widths[:, None]
pred_h = torch.exp(dh) * heights[:, None]
pred_x1 = pred_ctr_x - 0.5 * pred_w
pred_y1 = pred_ctr_y - 0.5 * pred_h
pred_x2 = pred_ctr_x + 0.5 * pred_w
pred_y2 = pred_ctr_y + 0.5 * pred_h
return torch.stack([pred_x1, pred_y1, pred_x2, pred_y2], dim=2).flatten(1)
def encode(self, gt_boxes_list, anchors_list):
all_targets = []
for gt_boxes, anchors in zip(gt_boxes_list, anchors_list):
targets = self.encode_single(gt_boxes, anchors)
all_targets.append(targets)
return all_targets
def decode(self, deltas, anchors_list):
boxes_per_img = [a.size(0) for a in anchors_list]
concat_anchors = torch.cat(anchors_list, dim=0)
pred_boxes = self.decode_single(deltas, concat_anchors)
if sum(boxes_per_img) > 0:
pred_boxes = pred_boxes.reshape(sum(boxes_per_img), -1, 4)
return pred_boxes
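A round-trip check makes the parameterization concrete: encoding a ground-truth box against an anchor and then decoding the result should recover the original box, since the weights cancel out. The box values below are arbitrary.

import torch

coder = BoxCoder(weights=(10., 10., 5., 5.))
anchors = torch.tensor([[0., 0., 100., 100.], [50., 50., 150., 150.]])
gt = torch.tensor([[10., 20., 90., 110.], [40., 60., 160., 140.]])
deltas = coder.encode_single(gt, anchors)
recovered = coder.decode_single(deltas, anchors)
print(torch.allclose(recovered, gt, atol=1e-4))  # True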