Instance Segmentation with Mask R-CNN Using MindSpore Framework
Mask R-CNN Overview
Mask R-CNN serves as an elegant and versatile framework for object instance segmentation. Beyond detecting objects within images, it simultaneously generates high-quality segmentation masks for each detected instance. The architecture extends Faster R-CNN by adding a mask prediction branch parallel to the existing bounding box detection branch.
This approach achieves training simplicity with inference speeds reaching 5fps. The computational overhead compared to Faster R-CNN remains minimal. The framework generalizes well to other tasks, including human pose estimation within the same unified architecture. Mask R-CNN demonstrates strong performance across COCO challenge benchmarks, particularly in instance segmentation, bounding box object detection, and person keypoint detection categories.
Network Architecture
Mask R-CNN employs a two-stage detection pipeline as an extension to Faster R-CNN. The architecture adds a dedicated mask prediction branch alongside the existing bounding box detection pathway. The Region Proposal Network (RPN) shares convolutional features across the entire image, enabling efficient candidate region computation without additional computational cost. The backbone network can utilize lightweight alternatives such as MobileNet for reduced computational requirements.
Development Environment Setup
Cloud Platform Configuration
Access ModelArts and initialize a Notebook instance for experimentation. Select the GPU environment for accelerated training, ensuring the appropriate runtime kernel is selected.
Install the required dependencies:
conda update -n base -c defaults conda
conda install mindspore=2.0.0a0 -c mindspore -c conda-forge
pip install mindvision
pip install download
Dependency Imports
import time
import os
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.nn import layer as L
from mindspore.common.initializer import initializer
from mindspore import context, Tensor, Parameter
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Momentum
from mindspore.common import set_seed
from src.utils.config import config
Dataset Preparation
COCO2017 Dataset
The COCO2017 dataset provides bounding box and pixel-level annotations suitable for scene understanding tasks including semantic segmentation, object detection, and image captioning.
| Split | Size | Image Count |
|---|---|---|
| Training | 18G | 118,000 |
| Validation | 1G | 5,000 |
Dataset download: https://cocodataset.org/#download
Data Preprocessing
Standardize image dimensions for consistent processing. Parse annotation JSON files to associate labels with corresponding images.
MindRecord Conversion
Transform raw images into MindRecord format for optimized data loading performance.
from dataset.dataset import create_coco_dataset, data_to_mindrecord_byte_image
def setup_mindrecord_storage(prefix, storage_dir):
    """Create MindRecord files for the configured dataset if absent.

    Args:
        prefix: File-name prefix for the generated MindRecord shards.
        storage_dir: Directory in which the MindRecord files are written.

    Raises:
        RuntimeError: If the source images/annotations cannot be located.
    """
    if not os.path.isdir(storage_dir):
        os.makedirs(storage_dir)
    if config.dataset == "coco":
        if os.path.isdir(config.data_root):
            print("Creating MindRecord files...")
            data_to_mindrecord_byte_image("coco", True, prefix)
            print(f"MindRecord created at {storage_dir}")
        else:
            raise RuntimeError("Dataset root directory not found.")
    else:
        if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
            print("Creating MindRecord files...")
            data_to_mindrecord_byte_image("other", True, prefix)
            print(f"MindRecord created at {storage_dir}")
        else:
            raise RuntimeError("Image or annotation files not found.")
    # Fix: the original loop waited on an undefined name `mindrecord_file`
    # (guaranteed NameError). Wait for the first shard's index (.db) file,
    # matching the "<prefix>0" naming that the callers look up.
    mindrecord_file = os.path.join(storage_dir, prefix + "0")
    while not os.path.exists(mindrecord_file + ".db"):
        time.sleep(5)
Dataset Loading
# Single-device setup: rank 0 of 1 device, graph-mode compilation.
device_target = config.device_target
rank_id = 0
num_devices = 1
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
print("Initializing dataset...")
record_prefix = "MaskRcnn.mindrecord"
record_dir = config.mindrecord_dir
# The converter shards its output; "0" selects the first shard file.
record_path = os.path.join(record_dir, record_prefix + "0")
# Only rank 0 creates the MindRecord files, to avoid concurrent writers
# in a multi-device launch.
if rank_id == 0 and not os.path.exists(record_path):
    setup_mindrecord_storage(record_prefix, record_dir)
data_loader = create_coco_dataset(record_path,
                                  batch_size=config.batch_size,
                                  device_num=num_devices,
                                  rank_id=rank_id)
dataset_size = data_loader.get_dataset_size()
print(f"Total images: {dataset_size}")
print("Dataset initialization complete.")
Dataset Visualization
import numpy as np
import matplotlib.pyplot as plt

# Pull one pre-processed batch from the pipeline as a visual sanity check.
sample_batch = next(data_loader.create_dict_iterator())
sample_images = sample_batch["image"].asnumpy()
print(f'Image dimensions: {sample_images.shape}')
plt.figure()
# Show the first two images of the batch side by side.
for idx in range(1, 3):
    plt.subplot(1, 2, idx)
    # CHW -> HWC layout for matplotlib.
    img_transposed = np.transpose(sample_images[idx - 1], (1, 2, 0))
    # Normalized pixel values can fall outside [0, 1]; clip for display.
    img_clipped = np.clip(img_transposed, 0, 1)
    plt.imshow(img_clipped[:, :])
    plt.xticks(rotation=180)
    plt.axis("off")
Model Training
Training Hyperparameters
| Parameter | Default | Description |
|---|---|---|
| workers | 1 | Parallel worker threads |
| device_target | GPU | Target compute device |
| learning_rate | 0.002 | Initial learning rate |
| weight_decay | 1e-4 | L2 regularization factor |
| total_epoch | 13 | Training duration |
| batch_size | 2 | Samples per batch |
| checkpoint_path | ./ckpt_0 | Model save directory |
Checkpoint Loading Utility
def initialize_from_checkpoint(network, checkpoint_path, device_target):
    """
    Load pretrained weights into the network.

    When no pretraining epochs were run, only the backbone and mask-head
    ('rcnn_mask') parameters are restored. On GPU targets every value is
    cast to float32 before loading.

    Args:
        network: Target network architecture
        checkpoint_path: Path to checkpoint file
        device_target: Compute device specification

    Returns:
        The network with checkpoint parameters loaded.
    """
    param_dict = load_checkpoint(checkpoint_path)
    if config.pretrain_epoch_size == 0:
        wanted_prefixes = ('backbone', 'rcnn_mask')
        unwanted = [name for name in param_dict if not name.startswith(wanted_prefixes)]
        for name in unwanted:
            del param_dict[name]
    if device_target == 'GPU':
        # Rebuild each entry as a float32 Parameter for GPU execution.
        param_dict = {
            name: Parameter(Tensor(value, mstype.float32), name)
            for name, value in param_dict.items()
        }
    load_param_into_net(network, param_dict)
    return network
Training Implementation
from src.utils.lr_schedule import dynamic_lr
set_seed(1)
def train_maskrcnn_model():
    """Execute the Mask R-CNN training pipeline end to end.

    Builds (or reuses) the MindRecord dataset, constructs the network,
    optionally restores pretrained weights, and trains with Momentum SGD
    under a fixed loss scale.

    NOTE(review): MaskRcnnResnet50, LossNet, WithLossCell, TrainOneStepCell
    and LossCallBack are not imported in this file's visible code — they are
    presumably provided by the project's src package; confirm the imports.
    """
    # Single-device run: one rank, graph-mode compilation.
    device_target = config.device_target
    rank_id = 0
    num_devices = 1
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    print("Setting up dataset...")
    record_prefix = "MaskRcnn.mindrecord"
    record_dir = config.mindrecord_dir
    # "0" selects the first MindRecord shard produced by the converter.
    record_path = os.path.join(record_dir, record_prefix + "0")
    # Only rank 0 writes the MindRecord files.
    if rank_id == 0 and not os.path.exists(record_path):
        setup_mindrecord_storage(record_prefix, record_dir)
    data_loader = create_coco_dataset(record_path,
                                      batch_size=config.batch_size,
                                      device_num=num_devices,
                                      rank_id=rank_id)
    dataset_size = data_loader.get_dataset_size()
    print(f"Dataset contains {dataset_size} images")
    network = MaskRcnnResnet50(config=config)
    network.set_train(True)
    pretrained_path = config.pre_trained
    if pretrained_path:
        print("Loading pretrained ResNet50 weights")
        network = initialize_from_checkpoint(network, pretrained_path, device_target)
    loss_fn = LossNet()
    # The LR schedule resumes after any pretraining steps already completed.
    learning_rate = Tensor(
        dynamic_lr(config, rank_size=num_devices,
                   start_steps=config.pretrain_epoch_size * dataset_size),
        mstype.float32
    )
    optimizer = Momentum(
        params=network.trainable_params(),
        learning_rate=learning_rate,
        momentum=config.momentum,
        weight_decay=config.weight_decay,
        loss_scale=config.loss_scale
    )
    loss_network = WithLossCell(network, loss_fn)
    # `sens` carries the loss scale into the backward pass.
    training_cell = TrainOneStepCell(loss_network, optimizer, sens=config.loss_scale)
    callbacks = [TimeMonitor(data_size=dataset_size), LossCallBack(rank_id=rank_id)]
    if config.save_checkpoint:
        # Checkpoint every N epochs, expressed in steps.
        save_steps = config.save_checkpoint_epochs * dataset_size
        save_config = CheckpointConfig(
            save_checkpoint_steps=save_steps,
            keep_checkpoint_max=config.keep_checkpoint_max
        )
        save_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank_id) + '/')
        callbacks.append(ModelCheckpoint(prefix='mask_rcnn', directory=save_path, config=save_config))
    model = Model(training_cell)
    model.train(config.epoch_size, data_loader, callbacks=callbacks, dataset_sink_mode=False)


if __name__ == '__main__':
    train_maskrcnn_model()
After loading pretrained weights, the loss rapidly stabilizes around 1.0, indicating successful model convergence. Checkpoint files are saved for subsequent fine-tuning and inference operations.
Model Evaluation
from pycocotools.coco import COCO
from src.utils.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
set_seed(1)
def evaluate_model(data_path, checkpoint_path, annotation_path):
    """
    Evaluate trained Mask R-CNN on COCO validation set.

    Runs inference over the evaluation MindRecord set, collects per-image
    bbox and mask results, and reports COCO bbox/segm metrics.

    Args:
        data_path: Path to evaluation dataset
        checkpoint_path: Trained model checkpoint
        annotation_path: Ground truth annotations
    """
    eval_loader = create_coco_dataset(
        data_path,
        batch_size=config.test_batch_size,
        is_training=False
    )
    network = MaskRcnnResnet50(config)
    params = load_checkpoint(checkpoint_path)
    load_param_into_net(network, params)
    network.set_train(False)
    total_batches = eval_loader.get_dataset_size()
    predictions = []
    coco_evaluator = COCO(annotation_path)
    print(f"Evaluating {total_batches} images...")
    # Cap on detections kept per image before COCO formatting.
    max_detections = 128
    inference_start = time.time()
    for batch_idx, batch_data in enumerate(eval_loader.create_dict_iterator(output_numpy=True, num_epochs=1)):
        img_batch = batch_data['image']
        img_shapes = batch_data['image_shape']
        gt_boxes = batch_data['box']
        gt_labels = batch_data['label']
        gt_counts = batch_data['valid_num']
        gt_masks = batch_data["mask"]
        # Network outputs, per the indexing below: bboxes, labels,
        # valid-detection flags, and mask feature maps.
        model_output = network(
            Tensor(img_batch),
            Tensor(img_shapes),
            Tensor(gt_boxes),
            Tensor(gt_labels),
            Tensor(gt_counts),
            Tensor(gt_masks)
        )
        bbox_out = model_output[0]
        label_out = model_output[1]
        mask_out = model_output[2]
        mask_features = model_output[3]
        for j in range(config.test_batch_size):
            bbox_squeezed = np.squeeze(bbox_out.asnumpy()[j, :, :])
            label_squeezed = np.squeeze(label_out.asnumpy()[j, :, :])
            mask_squeezed = np.squeeze(mask_out.asnumpy()[j, :, :])
            feat_squeezed = np.squeeze(mask_features.asnumpy()[j, :, :, :])
            # Boolean-mask out padded/invalid detection slots.
            valid_boxes = bbox_squeezed[mask_squeezed, :]
            valid_labels = label_squeezed[mask_squeezed]
            valid_features = feat_squeezed[mask_squeezed, :, :]
            # Keep only the top-scoring detections (score is the last
            # bbox column) when over the cap.
            if valid_boxes.shape[0] > max_detections:
                top_indices = np.argsort(-valid_boxes[:, -1])[:max_detections]
                valid_boxes = valid_boxes[top_indices]
                valid_labels = valid_labels[top_indices]
                valid_features = valid_features[top_indices]
            detection_results = bbox2result_1image(valid_boxes, valid_labels, config.num_classes)
            segmentation_results = get_seg_masks(
                valid_features, valid_boxes, valid_labels,
                img_shapes[j], True, config.num_classes
            )
            predictions.append((detection_results, segmentation_results))
    inference_time = time.time() - inference_start
    print(f"Evaluation completed in {inference_time:.2f} seconds")
    eval_types = ["bbox", "segm"]
    result_files = results2json(coco_evaluator, predictions, "./predictions.pkl")
    coco_eval(result_files, eval_types, coco_evaluator, single_result=False)
def run_evaluation():
    """Execute the evaluation workflow.

    Locates (or generates) the evaluation MindRecord file, then runs
    evaluate_model against the configured checkpoint and annotations.

    Raises:
        RuntimeError: If the evaluation dataset cannot be found or created.
    """
    device_target = config.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    # Use a local path instead of mutating config.mindrecord_dir in place:
    # the original in-place join made a second call (or a later
    # run_inference, which joins the data root itself) double-join the root.
    record_dir = os.path.join(config.data_root, config.mindrecord_dir)
    record_prefix = "MaskRcnn_eval.mindrecord"
    record_path = os.path.join(record_dir, record_prefix)
    if not os.path.exists(record_path):
        if not os.path.isdir(record_dir):
            os.makedirs(record_dir)
        if config.dataset == "coco" and os.path.isdir(config.data_root):
            print("Generating evaluation MindRecord...")
            data_to_mindrecord_byte_image("coco", False, record_prefix, file_num=1)
            print(f"MindRecord saved to {record_dir}")
        elif config.dataset == "other" and os.path.isdir(config.IMAGE_DIR) \
                and os.path.exists(config.ANNO_PATH):
            # Mirror setup_mindrecord_storage: the annotation file must
            # exist too, not just the image directory.
            print("Generating evaluation MindRecord...")
            data_to_mindrecord_byte_image("other", False, record_prefix, file_num=1)
            print(f"MindRecord saved to {record_dir}")
        else:
            # Previously this fell through and evaluated a missing file.
            raise RuntimeError("Evaluation dataset not found; cannot create MindRecord.")
    print("Starting evaluation...")
    evaluate_model(record_path, config.checkpoint_path, config.ann_file)
    print(f"Checkpoint: {config.checkpoint_path}")


if __name__ == '__main__':
    run_evaluation()
Inference Pipeline
import random
import colorsys
import matplotlib.pyplot as plt
import matplotlib.patches as patches
set_seed(1)
def create_axes(rows=1, cols=1, figsize=16):
    """Create a matplotlib figure and return its axes.

    The figure is sized proportionally to the grid: `figsize` inches per
    row/column.
    """
    fig_width = figsize * cols
    fig_height = figsize * rows
    _, axis = plt.subplots(rows, cols, figsize=(fig_width, fig_height))
    return axis
def convert_to_rgb(img_batch):
    """Convert the first image of a normalized NCHW batch to an HWC RGB array.

    The image is min-max rescaled to [0, 255], truncated to uint8 values,
    and returned as a float array for display.

    Args:
        img_batch: Image batch with layout (N, C, H, W); element 0 is used.

    Returns:
        numpy.ndarray of shape (H, W, 3), dtype float64, values in [0, 255].
    """
    img = img_batch[0, :, :, :]
    # Hoist min/max: the original recomputed them for every use.
    lo = np.min(img)
    hi = np.max(img)
    # NOTE(review): a constant image (hi == lo) divides by zero, as before.
    uint_img = ((img - lo) * 255 / (hi - lo)).astype(np.uint8)
    # Derive the output size from the image itself instead of config, so the
    # helper also works when image dimensions differ from the config values.
    return uint_img.transpose(1, 2, 0).astype(np.float64)
def generate_distinct_colors(count, bright=True):
    """Generate `count` visually distinct RGB colors.

    Hues are sampled evenly around the HSV color wheel at full saturation,
    then the resulting RGB tuples are shuffled so adjacent instances get
    unrelated colors.
    """
    value = 1.0 if bright else 0.7
    colors = [colorsys.hsv_to_rgb(i / count, 1, value) for i in range(count)]
    random.shuffle(colors)
    return colors
def run_inference():
    """Perform inference on one randomly chosen test image.

    Returns:
        Tuple of (network predictions, display RGB image, image metadata).
    """
    device_target = config.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    record_dir = os.path.join(config.data_root, config.mindrecord_dir)
    record_prefix = "MaskRcnn_eval.mindrecord"
    record_path = os.path.join(record_dir, record_prefix)
    test_loader = create_coco_dataset(
        record_path,
        batch_size=config.test_batch_size,
        is_training=False
    )
    total_samples = test_loader.get_dataset_size()
    # Pick one batch index at random for the demo.
    selected_idx = np.random.choice(total_samples, 1)
    model_path = config.checkpoint_path
    network = MaskRcnnResnet50(config)
    checkpoint_params = load_checkpoint(model_path)
    load_param_into_net(network, checkpoint_params)
    network.set_train(False)
    # NOTE(review): materializing the whole iterator to index one element
    # is memory-hungry for large datasets; acceptable for a demo.
    test_batch = list(test_loader.create_dict_iterator(output_numpy=True, num_epochs=1))[selected_idx[0]]
    print(f"Processing image index: {selected_idx[0]}")
    img_data = test_batch['image']
    img_metas = test_batch['image_shape']
    boxes = test_batch['box']
    labels = test_batch['label']
    valid_counts = test_batch['valid_num']
    masks = test_batch["mask"]
    # Convert the normalized network input back to a displayable image.
    rgb_image = convert_to_rgb(img_data)
    inference_start = time.time()
    predictions = network(
        Tensor(img_data),
        Tensor(img_metas),
        Tensor(boxes),
        Tensor(labels),
        Tensor(valid_counts),
        Tensor(masks)
    )
    inference_time = time.time() - inference_start
    print(f"Inference completed in {inference_time:.2f} seconds")
    return predictions, rgb_image, img_metas
def visualize_detections(model_output, rgb_image, image_metadata):
    """Render detection boxes and class captions over the input image.

    Args:
        model_output: Network outputs indexed as (bboxes, labels,
            valid-detection flags, ...).
        rgb_image: HWC RGB image as returned by convert_to_rgb.
        image_metadata: Per-image metadata; column 2 is used as the box
            scale factor (assumption — confirm against the dataset layout).
    """
    scale_factor = image_metadata[0, 2]
    bboxes = model_output[0][0].asnumpy()
    pred_labels = model_output[1][0].asnumpy()
    pred_masks = model_output[2][0].asnumpy()
    detection_indices = []
    detection_count = 0
    # Keep detections flagged valid whose score (bbox column 4) exceeds 0.8.
    for idx, mask_flag in enumerate(pred_masks):
        if np.equal(mask_flag, True) and bboxes[idx, 4] > 0.8:
            detection_indices.append(idx)
            detection_count += 1
    print(f"Detected {detection_count} instances")
    color_palette = generate_distinct_colors(detection_count)
    height = config.img_height
    width = config.img_width
    axes = create_axes(1)
    axes.set_ylim(height + 10, -10)
    axes.set_xlim(-10, width + 10)
    axes.axis('off')
    axes.set_title("Detection Results")
    display_image = rgb_image.astype(np.uint32).copy()
    for inst_idx in range(detection_count):
        color = color_palette[inst_idx]
        box_idx = detection_indices[inst_idx]
        # NOTE(review): the scale factor multiplies the score column too;
        # `conf` is unused below (the unscaled score is re-read), so this is
        # harmless but misleading.
        x1, y1, x2, y2, conf = bboxes[box_idx] * scale_factor
        bbox_rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, alpha=0.7,
            linestyle="dashed",
            edgecolor=color,
            facecolor='none'
        )
        axes.add_patch(bbox_rect)
        class_name = config.data_classes
        # NOTE(review): the +1 offset assumes data_classes[0] is background;
        # confirm against config.data_classes ordering.
        class_id = pred_labels[box_idx, 0].astype(np.uint8) + 1
        confidence = bboxes[box_idx, 4]
        label_text = class_name[class_id]
        caption = f"{label_text} {confidence:.3f}"
        axes.text(x1, y1 + 8, caption, color='w', size=11, backgroundcolor="none")
    axes.imshow(display_image.astype(np.uint8))
    plt.show()


if __name__ == '__main__':
    predictions, image_rgb, metadata = run_inference()
    visualize_detections(predictions, image_rgb, metadata)