
Instance Segmentation with Mask R-CNN Using MindSpore Framework


Mask R-CNN Overview

Mask R-CNN serves as an elegant and versatile framework for object instance segmentation. Beyond detecting objects within images, it simultaneously generates high-quality segmentation masks for each detected instance. The architecture extends Faster R-CNN by adding a mask prediction branch parallel to the existing bounding box detection branch.

The approach is simple to train, adds only minimal computational overhead compared to Faster R-CNN, and reaches inference speeds of roughly 5 fps. The framework also generalizes well to other tasks, including human pose estimation, within the same unified architecture. Mask R-CNN demonstrates strong performance across COCO challenge benchmarks, particularly in instance segmentation, bounding box object detection, and person keypoint detection.

Network Architecture

Mask R-CNN employs a two-stage detection pipeline as an extension to Faster R-CNN. The architecture adds a dedicated mask prediction branch alongside the existing bounding box detection pathway. The Region Proposal Network (RPN) shares convolutional features across the entire image, enabling efficient candidate region computation without additional computational cost. The backbone network can utilize lightweight alternatives such as MobileNet for reduced computational requirements.
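To make the extra branch concrete, the sketch below shows a simplified mask head in MindSpore: a small stack of convolutions over each RoI feature map, a transposed convolution to upsample to mask resolution, and a 1x1 convolution that predicts one binary mask per class. The channel counts, class count, and layer depth here are illustrative only, not the exact head used by the network built later in this article.

import mindspore.nn as nn

class SimpleMaskHead(nn.Cell):
    """Illustrative mask branch: RoI features -> per-class binary masks."""
    def __init__(self, in_channels=256, num_classes=81):
        super().__init__()
        self.convs = nn.SequentialCell([
            nn.Conv2d(in_channels, 256, kernel_size=3, pad_mode="same"),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, pad_mode="same"),
            nn.ReLU(),
        ])
        # Upsample 14x14 RoI features to 28x28 mask resolution
        self.deconv = nn.Conv2dTranspose(256, 256, kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        # One mask map per class, predicted independently (sigmoid applied at loss time)
        self.mask_pred = nn.Conv2d(256, num_classes, kernel_size=1)

    def construct(self, roi_features):
        x = self.convs(roi_features)
        x = self.relu(self.deconv(x))
        return self.mask_pred(x)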

Development Environment Setup

Cloud Platform Configuration

Access ModelArts and create a Notebook instance for experimentation. Choose a GPU environment for accelerated training and make sure the matching runtime kernel is active.

Install the required dependencies:

conda update -n base -c defaults conda
conda install mindspore=2.0.0a0 -c mindspore -c conda-forge
pip install mindvision
pip install download
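
Optionally, verify the installation before continuing. The short check below prints the installed version and runs a trivial op in graph mode on the GPU; if it completes without errors, the environment is ready.

import mindspore as ms
from mindspore import context, Tensor

print(ms.__version__)  # expect the installed 2.0.0a0 build
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# Run a trivial op on the selected device to confirm the runtime works
x = Tensor([1.0, 2.0, 3.0])
print((x + x).asnumpy())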

Dependency Imports

import time
import os

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.nn import layer as L
from mindspore.common.initializer import initializer
from mindspore import context, Tensor, Parameter
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Momentum
from mindspore.common import set_seed

from src.utils.config import config

Dataset Preparation

COCO2017 Dataset

The COCO2017 dataset provides bounding box and pixel-level annotations suitable for scene understanding tasks including semantic segmentation, object detection, and image captioning.

Split        Size     Image Count
Training     18 GB    118,000
Validation   1 GB     5,000

Dataset download: https://cocodataset.org/#download
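
If the dataset is not already on the instance, it can be fetched with the download package installed above. The archive URLs below are the standard COCO 2017 mirrors; the local target directory ./coco2017 is only a suggested layout, so adjust it to match config.data_root.

from download import download

coco_urls = [
    "http://images.cocodataset.org/zips/train2017.zip",
    "http://images.cocodataset.org/zips/val2017.zip",
    "http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
]

# Download and unpack each archive into ./coco2017 (suggested layout)
for url in coco_urls:
    download(url, "./coco2017", kind="zip", replace=False)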

Data Preprocessing

Standardize image dimensions for consistent processing. Parse annotation JSON files to associate labels with corresponding images.
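
The association between images and labels can be read directly from the COCO annotation JSON with pycocotools. The sketch below (annotation path assumed) shows how image ids map to their box and category annotations.

from pycocotools.coco import COCO

# Path assumed; adjust to where the annotations were extracted
coco = COCO("./coco2017/annotations/instances_train2017.json")

img_ids = coco.getImgIds()
first_img = coco.loadImgs(img_ids[0])[0]
print(first_img["file_name"], first_img["height"], first_img["width"])

# All annotations (boxes, segmentation polygons, category ids) for that image
ann_ids = coco.getAnnIds(imgIds=first_img["id"], iscrowd=None)
for ann in coco.loadAnns(ann_ids):
    print(coco.loadCats(ann["category_id"])[0]["name"], ann["bbox"])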

MindRecord Conversion

Transform raw images into MindRecord format for optimized data loading performance.

from dataset.dataset import create_coco_dataset, data_to_mindrecord_byte_image

def setup_mindrecord_storage(prefix, storage_dir):
    """Initialize MindRecord storage directory."""
    if not os.path.isdir(storage_dir):
        os.makedirs(storage_dir)
    if config.dataset == "coco":
        if os.path.isdir(config.data_root):
            print("Creating MindRecord files...")
            data_to_mindrecord_byte_image("coco", True, prefix)
            print(f"MindRecord created at {storage_dir}")
        else:
            raise RuntimeError("Dataset root directory not found.")
    else:
        if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
            print("Creating MindRecord files...")
            data_to_mindrecord_byte_image("other", True, prefix)
            print(f"MindRecord created at {storage_dir}")
        else:
            raise RuntimeError("Image or annotation files not found.")
    
    # data_to_mindrecord_byte_image writes shard files named <prefix>0 .. <prefix>N;
    # wait until the first shard's index (.db) file exists before continuing.
    mindrecord_file = os.path.join(storage_dir, prefix + "0")
    while not os.path.exists(mindrecord_file + ".db"):
        time.sleep(5)
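
Conceptually, the conversion boils down to writing one row per image through MindSpore's FileWriter. The sketch below illustrates this with a simplified, assumed schema (encoded image bytes plus an N x 6 annotation array); the real data_to_mindrecord_byte_image helper also stores masks and handles sharding and error cases.

import numpy as np
from mindspore.mindrecord import FileWriter

def write_mindrecord_sketch(samples, file_name="MaskRcnn.mindrecord", shard_num=8):
    """samples: iterable of (jpeg_bytes, annotation ndarray of shape [N, 6])."""
    writer = FileWriter(file_name=file_name, shard_num=shard_num)
    schema = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 6]},
    }
    writer.add_schema(schema, "mask_rcnn_sketch")
    for image_bytes, annotation in samples:
        row = {"image": image_bytes,
               "annotation": np.asarray(annotation, dtype=np.int32)}
        writer.write_raw_data([row])
    writer.commit()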

Dataset Loading

device_target = config.device_target
rank_id = 0
num_devices = 1
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

print("Initializing dataset...")
record_prefix = "MaskRcnn.mindrecord"
record_dir = config.mindrecord_dir
record_path = os.path.join(record_dir, record_prefix + "0")

if rank_id == 0 and not os.path.exists(record_path):
    setup_mindrecord_storage(record_prefix, record_dir)

data_loader = create_coco_dataset(record_path, 
                                   batch_size=config.batch_size,
                                   device_num=num_devices,
                                   rank_id=rank_id)
dataset_size = data_loader.get_dataset_size()
print(f"Total images: {dataset_size}")
print("Dataset initialization complete.")

Dataset Visualization

import numpy as np
import matplotlib.pyplot as plt

sample_batch = next(data_loader.create_dict_iterator())
sample_images = sample_batch["image"].asnumpy()
print(f'Image dimensions: {sample_images.shape}')

plt.figure()
for idx in range(1, 3):
    plt.subplot(1, 2, idx)
    # CHW -> HWC so matplotlib can display the image, then clip to [0, 1]
    img_transposed = np.transpose(sample_images[idx - 1], (1, 2, 0))
    img_clipped = np.clip(img_transposed, 0, 1)
    plt.imshow(img_clipped)
    plt.axis("off")
plt.show()

Model Training

Training Hyperparameters

Parameter         Default     Description
workers           1           Parallel worker threads
device_target     GPU         Target compute device
learning_rate     0.002       Initial learning rate
weight_decay      1e-4        L2 regularization factor
total_epoch       13          Training duration (epochs)
batch_size        2           Samples per batch
checkpoint_path   ./ckpt_0    Model save directory
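
These values are exposed as attributes on the config object imported from src.utils.config. For reference, an equivalent stand-in could be built as follows; the keys and defaults mirror the table above, and everything else in the real config is omitted.

from types import SimpleNamespace

# Stand-in for src.utils.config.config, limited to the values listed above
config_sketch = SimpleNamespace(
    workers=1,
    device_target="GPU",
    learning_rate=0.002,
    weight_decay=1e-4,
    total_epoch=13,
    batch_size=2,
    checkpoint_path="./ckpt_0",
)
print(config_sketch.learning_rate)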

Checkpoint Loading Utility

def initialize_from_checkpoint(network, checkpoint_path, device_target):
    """
    Load pretrained weights into the network.
    
    Args:
        network: Target network architecture
        checkpoint_path: Path to checkpoint file
        device_target: Compute device specification
    """
    param_dict = load_checkpoint(checkpoint_path)
    
    if config.pretrain_epoch_size == 0:
        for key in list(param_dict.keys()):
            if not (key.startswith('backbone') or key.startswith('rcnn_mask')):
                param_dict.pop(key)
        
        if device_target == 'GPU':
            for key, value in param_dict.items():
                tensor = Tensor(value, mstype.float32)
                param_dict[key] = Parameter(tensor, key)
    
    load_param_into_net(network, param_dict)
    return network

Training Implementation

from src.utils.lr_schedule import dynamic_lr
# Project-local definitions of the network, loss, and training wrappers used below;
# the module paths here are assumed and may differ depending on the source tree layout.
from src.model.mask_rcnn_resnet50 import MaskRcnnResnet50
from src.model.loss import LossNet
from src.model.train_wrapper import WithLossCell, TrainOneStepCell, LossCallBack

set_seed(1)

def train_maskrcnn_model():
    """Execute Mask R-CNN training pipeline."""
    device_target = config.device_target
    rank_id = 0
    num_devices = 1
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

    print("Setting up dataset...")
    record_prefix = "MaskRcnn.mindrecord"
    record_dir = config.mindrecord_dir
    record_path = os.path.join(record_dir, record_prefix + "0")
    
    if rank_id == 0 and not os.path.exists(record_path):
        setup_mindrecord_storage(record_prefix, record_dir)
    
    data_loader = create_coco_dataset(record_path,
                                       batch_size=config.batch_size,
                                       device_num=num_devices,
                                       rank_id=rank_id)
    dataset_size = data_loader.get_dataset_size()
    print(f"Dataset contains {dataset_size} images")
    
    network = MaskRcnnResnet50(config=config)
    network.set_train(True)
    
    pretrained_path = config.pre_trained
    if pretrained_path:
        print("Loading pretrained ResNet50 weights")
        network = initialize_from_checkpoint(network, pretrained_path, device_target)
    
    loss_fn = LossNet()
    learning_rate = Tensor(
        dynamic_lr(config, rank_size=num_devices,
                   start_steps=config.pretrain_epoch_size * dataset_size),
        mstype.float32
    )
    
    optimizer = Momentum(
        params=network.trainable_params(),
        learning_rate=learning_rate,
        momentum=config.momentum,
        weight_decay=config.weight_decay,
        loss_scale=config.loss_scale
    )
    
    loss_network = WithLossCell(network, loss_fn)
    training_cell = TrainOneStepCell(loss_network, optimizer, sens=config.loss_scale)
    
    callbacks = [TimeMonitor(data_size=dataset_size), LossCallBack(rank_id=rank_id)]
    
    if config.save_checkpoint:
        save_steps = config.save_checkpoint_epochs * dataset_size
        save_config = CheckpointConfig(
            save_checkpoint_steps=save_steps,
            keep_checkpoint_max=config.keep_checkpoint_max
        )
        save_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank_id) + '/')
        callbacks.append(ModelCheckpoint(prefix='mask_rcnn', directory=save_path, config=save_config))
    
    model = Model(training_cell)
    model.train(config.epoch_size, data_loader, callbacks=callbacks, dataset_sink_mode=False)

if __name__ == '__main__':
    train_maskrcnn_model()

After loading pretrained weights, the loss rapidly stabilizes around 1.0, indicating successful model convergence. Checkpoint files are saved for subsequent fine-tuning and inference operations.
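
Once training finishes, the checkpoint directory can be inspected to confirm the files are in place. This is a quick, optional check; the directory name follows the ckpt_<rank> pattern used in the training script above.

import os

ckpt_dir = os.path.join(config.save_checkpoint_path, "ckpt_0")
checkpoints = sorted(f for f in os.listdir(ckpt_dir) if f.endswith(".ckpt"))
print(checkpoints)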

Model Evaluation

from pycocotools.coco import COCO
from src.utils.util import coco_eval, bbox2result_1image, results2json, get_seg_masks

set_seed(1)

def evaluate_model(data_path, checkpoint_path, annotation_path):
    """
    Evaluate trained Mask R-CNN on COCO validation set.
    
    Args:
        data_path: Path to evaluation dataset
        checkpoint_path: Trained model checkpoint
        annotation_path: Ground truth annotations
    """
    eval_loader = create_coco_dataset(
        data_path,
        batch_size=config.test_batch_size,
        is_training=False
    )
    
    network = MaskRcnnResnet50(config)
    params = load_checkpoint(checkpoint_path)
    load_param_into_net(network, params)
    network.set_train(False)
    
    total_batches = eval_loader.get_dataset_size()
    predictions = []
    coco_evaluator = COCO(annotation_path)
    
    print(f"Evaluating {total_batches} images...")
    max_detections = 128
    inference_start = time.time()
    
    for batch_idx, batch_data in enumerate(eval_loader.create_dict_iterator(output_numpy=True, num_epochs=1)):
        img_batch = batch_data['image']
        img_shapes = batch_data['image_shape']
        gt_boxes = batch_data['box']
        gt_labels = batch_data['label']
        gt_counts = batch_data['valid_num']
        gt_masks = batch_data["mask"]
        
        model_output = network(
            Tensor(img_batch),
            Tensor(img_shapes),
            Tensor(gt_boxes),
            Tensor(gt_labels),
            Tensor(gt_counts),
            Tensor(gt_masks)
        )
        
        bbox_out = model_output[0]
        label_out = model_output[1]
        mask_out = model_output[2]
        mask_features = model_output[3]
        
        for j in range(config.test_batch_size):
            bbox_squeezed = np.squeeze(bbox_out.asnumpy()[j, :, :])
            label_squeezed = np.squeeze(label_out.asnumpy()[j, :, :])
            mask_squeezed = np.squeeze(mask_out.asnumpy()[j, :, :])
            feat_squeezed = np.squeeze(mask_features.asnumpy()[j, :, :, :])
            
            valid_boxes = bbox_squeezed[mask_squeezed, :]
            valid_labels = label_squeezed[mask_squeezed]
            valid_features = feat_squeezed[mask_squeezed, :, :]
            
            if valid_boxes.shape[0] > max_detections:
                top_indices = np.argsort(-valid_boxes[:, -1])[:max_detections]
                valid_boxes = valid_boxes[top_indices]
                valid_labels = valid_labels[top_indices]
                valid_features = valid_features[top_indices]
            
            detection_results = bbox2result_1image(valid_boxes, valid_labels, config.num_classes)
            segmentation_results = get_seg_masks(
                valid_features, valid_boxes, valid_labels,
                img_shapes[j], True, config.num_classes
            )
            predictions.append((detection_results, segmentation_results))
    
    inference_time = time.time() - inference_start
    print(f"Evaluation completed in {inference_time:.2f} seconds")
    
    eval_types = ["bbox", "segm"]
    result_files = results2json(coco_evaluator, predictions, "./predictions.pkl")
    coco_eval(result_files, eval_types, coco_evaluator, single_result=False)

def run_evaluation():
    """Execute evaluation workflow."""
    device_target = config.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

    config.mindrecord_dir = os.path.join(config.data_root, config.mindrecord_dir)
    record_prefix = "MaskRcnn_eval.mindrecord"
    record_path = os.path.join(config.mindrecord_dir, record_prefix)

    if not os.path.exists(record_path):
        if not os.path.isdir(config.mindrecord_dir):
            os.makedirs(config.mindrecord_dir)
        
        if config.dataset == "coco" and os.path.isdir(config.data_root):
            print("Generating evaluation MindRecord...")
            data_to_mindrecord_byte_image("coco", False, record_prefix, file_num=1)
            print(f"MindRecord saved to {config.mindrecord_dir}")
        elif config.dataset == "other" and os.path.isdir(config.IMAGE_DIR):
            print("Generating evaluation MindRecord...")
            data_to_mindrecord_byte_image("other", False, record_prefix, file_num=1)
            print(f"MindRecord saved to {config.mindrecord_dir}")

    print("Starting evaluation...")
    evaluate_model(record_path, config.checkpoint_path, config.ann_file)
    print(f"Checkpoint: {config.checkpoint_path}")

if __name__ == '__main__':
    run_evaluation()

Inference Pipeline

import random
import colorsys
import matplotlib.pyplot as plt
import matplotlib.patches as patches

set_seed(1)

def create_axes(rows=1, cols=1, figsize=16):
    """Initialize matplotlib axes for visualization."""
    _, axis = plt.subplots(rows, cols, figsize=(figsize * cols, figsize * rows))
    return axis

def convert_to_rgb(img_batch):
    """Transform the first image of a MindRecord batch (CHW, normalized) to RGB."""
    sample_idx = 0
    channels = img_batch[sample_idx, :, :, :]
    # Stretch the normalized values to the displayable [0, 255] range
    normalized = (channels - np.min(channels)) * 255 / (np.max(channels) - np.min(channels))
    uint_img = normalized.astype(np.uint8)
    rgb_image = np.zeros([config.img_height, config.img_width, 3])
    rgb_image[:, :, 0] = uint_img[0, :, :]
    rgb_image[:, :, 1] = uint_img[1, :, :]
    rgb_image[:, :, 2] = uint_img[2, :, :]
    return rgb_image

def generate_distinct_colors(count, bright=True):
    """Create visually distinguishable colors using HSV color space."""
    brightness = 1.0 if bright else 0.7
    hsv_colors = [(i / count, 1, brightness) for i in range(count)]
    rgb_colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv_colors))
    random.shuffle(rgb_colors)
    return rgb_colors

def run_inference():
    """Perform inference on test images."""
    device_target = config.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

    record_dir = os.path.join(config.data_root, config.mindrecord_dir)
    record_prefix = "MaskRcnn_eval.mindrecord"
    record_path = os.path.join(record_dir, record_prefix)

    test_loader = create_coco_dataset(
        record_path,
        batch_size=config.test_batch_size,
        is_training=False
    )
    
    total_samples = test_loader.get_dataset_size()
    selected_idx = np.random.choice(total_samples, 1)
    
    model_path = config.checkpoint_path
    network = MaskRcnnResnet50(config)
    checkpoint_params = load_checkpoint(model_path)
    load_param_into_net(network, checkpoint_params)
    network.set_train(False)
    
    test_batch = list(test_loader.create_dict_iterator(output_numpy=True, num_epochs=1))[selected_idx[0]]
    print(f"Processing image index: {selected_idx[0]}")
    
    img_data = test_batch['image']
    img_metas = test_batch['image_shape']
    boxes = test_batch['box']
    labels = test_batch['label']
    valid_counts = test_batch['valid_num']
    masks = test_batch["mask"]
    
    rgb_image = convert_to_rgb(img_data)
    
    inference_start = time.time()
    predictions = network(
        Tensor(img_data),
        Tensor(img_metas),
        Tensor(boxes),
        Tensor(labels),
        Tensor(valid_counts),
        Tensor(masks)
    )
    inference_time = time.time() - inference_start
    print(f"Inference completed in {inference_time:.2f} seconds")
    
    return predictions, rgb_image, img_metas

def visualize_detections(model_output, rgb_image, image_metadata):
    """Render detection and segmentation results."""
    scale_factor = image_metadata[0, 2]
    
    bboxes = model_output[0][0].asnumpy()
    pred_labels = model_output[1][0].asnumpy()
    pred_masks = model_output[2][0].asnumpy()
    
    detection_indices = []
    detection_count = 0
    for idx, mask_flag in enumerate(pred_masks):
        if np.equal(mask_flag, True) and bboxes[idx, 4] > 0.8:
            detection_indices.append(idx)
            detection_count += 1
    
    print(f"Detected {detection_count} instances")
    
    color_palette = generate_distinct_colors(detection_count)
    
    height = config.img_height
    width = config.img_width
    axes = create_axes(1)
    axes.set_ylim(height + 10, -10)
    axes.set_xlim(-10, width + 10)
    axes.axis('off')
    axes.set_title("Detection Results")
    
    display_image = rgb_image.astype(np.uint32).copy()
    
    for inst_idx in range(detection_count):
        color = color_palette[inst_idx]
        box_idx = detection_indices[inst_idx]
        
        # Scale the box corners back to the displayed image size (the score is read separately below)
        x1, y1, x2, y2 = bboxes[box_idx, :4] * scale_factor
        
        bbox_rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, alpha=0.7,
            linestyle="dashed",
            edgecolor=color,
            facecolor='none'
        )
        axes.add_patch(bbox_rect)
        
        class_name = config.data_classes
        class_id = pred_labels[box_idx, 0].astype(np.uint8) + 1
        confidence = bboxes[box_idx, 4]
        label_text = class_name[class_id]
        
        caption = f"{label_text} {confidence:.3f}"
        axes.text(x1, y1 + 8, caption, color='w', size=11, backgroundcolor="none")
    
    axes.imshow(display_image.astype(np.uint8))
    plt.show()

if __name__ == '__main__':
    predictions, image_rgb, metadata = run_inference()
    visualize_detections(predictions, image_rgb, metadata)
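
The script above only draws bounding boxes. The mask branch's output (decoded to per-instance binary masks, as get_seg_masks does in the evaluation code) can also be blended onto the image with a small helper such as the one below; it is a generic overlay sketch that assumes a boolean mask with the same height and width as the displayed image.

def apply_mask(image, mask, color, alpha=0.5):
    """Blend a boolean instance mask onto an RGB uint8 image and return a copy."""
    blended = image.astype(np.float32).copy()
    for channel in range(3):
        blended[:, :, channel] = np.where(
            mask,
            blended[:, :, channel] * (1 - alpha) + alpha * color[channel] * 255,
            blended[:, :, channel],
        )
    return blended.astype(np.uint8)

# Example: overlay one instance mask before calling axes.imshow(...)
# display_image = apply_mask(display_image, instance_mask, color_palette[0])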
