Fading Coder


Building a Padding-Equipped U-Net for Biomedical Image Segmentation

U-Net, formulated by Ronneberger et al. in 2015, adopts an encoder-decoder framework with skip connections to fuse high-resolution feature maps from the contraction path with upsampled outputs in the expansion path. The architecture excels in tasks requiring precise localization, such as biomedical image segmentation.

The contracting path follows a typical convolutional network topology. It applies two successive 3×3 convolutions (without padding in the original design), each followed by a ReLU activation, and then a 2×2 max-pooling operation with stride 2 for downsampling. At each downsampling step, the number of feature channels is doubled. The expansive path consists of an upsampling of the feature map followed by a 2×2 convolution ("up-convolution") that halves the number of feature channels, concatenation with the correspondingly cropped feature map from the contracting path, and two 3×3 convolutions, each succeeded by a ReLU. The final layer uses a 1×1 convolution to map each 64-component feature vector to the desired number of classes. In total, the network has 23 convolutional layers.
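To make the unpadded arithmetic concrete, the following sketch traces the spatial size through the original valid-convolution network: each unpadded 3×3 convolution trims the feature map by 2 pixels per dimension and each 2×2 pooling halves it, which is why the paper's 572×572 input tile produces a 388×388 output.

```python
def conv3x3_valid(size):
    # an unpadded 3x3 convolution trims one pixel from each border
    return size - 2

size = 572            # input tile size used in the original paper
skips = []
for _ in range(4):    # contracting path
    size = conv3x3_valid(conv3x3_valid(size))
    skips.append(size)          # resolution handed to the skip connection
    size //= 2                  # 2x2 max pooling, stride 2
size = conv3x3_valid(conv3x3_valid(size))   # bottleneck
for _ in range(4):    # expansive path
    size *= 2                   # 2x2 up-convolution
    size = conv3x3_valid(conv3x3_valid(size))
print(size)  # 388 -- smaller than the 572x572 input, hence the cropping
```

This shrinkage is exactly what the padded variant below avoids.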

Why U-Net performs well on medical images:

  • Skip connections recover fine-grained spatial details that are lost during encoding. Deep layers capture coarse, semantic context, while shallow layers retain edges, textures, and other mid-level cues. Concatenating them gives the decoder access to multi-scale information, which is vital for segmenting small structures.
  • The architecture is relatively simple and parameter-efficient, making it well-suited for limited medical datasets where large models tend to overfit.

Common modifications

Two practical adjustments make the model more convenient to implement and often improve accuracy:

  1. Padding the convolutions: Setting padding to 1 in all 3×3 convolutions preserves spatial dimensions, eliminating the need to crop feature maps before skip connections. The output resolution equals the input resolution, which simplifies data loading and loss computation.
  2. Batch normalization: Inserting batch normalization after each convolution accelerates training and stabilizes learning, especially when using higher learning rates.
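As a quick check of modification 1 (tensor sizes here are arbitrary), a 3×3 convolution with padding set to 1 leaves the spatial dimensions untouched, so encoder and decoder feature maps line up without any cropping:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 96, 96)                     # dummy input batch
conv = nn.Conv2d(3, 64, kernel_size=3, padding=1) # "same"-size convolution
y = conv(x)
print(y.shape)  # torch.Size([1, 64, 96, 96]) -- H and W preserved
```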

The following implementation incorporates both modifications. The code is structured as a set of modular building blocks.

Building blocks

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, in_c, out_c, mid_c=None, use_bn=True):
        super().__init__()
        if mid_c is None:
            mid_c = out_c
        layers = [
            nn.Conv2d(in_c, mid_c, 3, padding=1, bias=False),
            nn.BatchNorm2d(mid_c) if use_bn else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_c, out_c, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_c) if use_bn else nn.Identity(),
            nn.ReLU(inplace=True)
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class EncoderStage(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.down = nn.MaxPool2d(2)
        self.conv = ConvBlock(in_c, out_c)

    def forward(self, x):
        return self.conv(self.down(x))

class DecoderStage(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, in_c // 2, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_c, out_c)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        if x1.shape[2:] != x2.shape[2:]:
            diff_y = x2.size(2) - x1.size(2)
            diff_x = x2.size(3) - x1.size(3)
            x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                            diff_y // 2, diff_y - diff_y // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
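Even with padded convolutions, odd input sizes can leave the upsampled map a pixel smaller than its skip connection after repeated pooling. The F.pad call in DecoderStage splits the difference between opposite borders; a small standalone check (sizes chosen arbitrarily):

```python
import torch
import torch.nn.functional as F

x1 = torch.zeros(1, 8, 24, 24)   # upsampled decoder features
x2 = torch.zeros(1, 8, 25, 25)   # skip connection from the encoder
diff_y = x2.size(2) - x1.size(2)
diff_x = x2.size(3) - x1.size(3)
# F.pad takes (left, right, top, bottom); an odd difference places the
# extra pixel on the right/bottom border
x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                diff_y // 2, diff_y - diff_y // 2])
print(x1.shape)  # torch.Size([1, 8, 25, 25]) -- now matches x2
```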

class OutputLayer(nn.Module):
    def __init__(self, in_c, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_c, num_classes, 1)

    def forward(self, x):
        return self.conv(x)

Assembling the U-Net

class UNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=2, base_channels=64):
        super().__init__()
        self.inc = ConvBlock(in_channels, base_channels)
        self.enc1 = EncoderStage(base_channels, base_channels * 2)
        self.enc2 = EncoderStage(base_channels * 2, base_channels * 4)
        self.enc3 = EncoderStage(base_channels * 4, base_channels * 8)
        self.enc4 = EncoderStage(base_channels * 8, base_channels * 16)
        self.dec1 = DecoderStage(base_channels * 16, base_channels * 8)
        self.dec2 = DecoderStage(base_channels * 8, base_channels * 4)
        self.dec3 = DecoderStage(base_channels * 4, base_channels * 2)
        self.dec4 = DecoderStage(base_channels * 2, base_channels)
        self.outc = OutputLayer(base_channels, num_classes)

    def forward(self, x):
        e1 = self.inc(x)       # [B, C, H, W]
        e2 = self.enc1(e1)     # [B, 2C, H/2, W/2]
        e3 = self.enc2(e2)     # [B, 4C, H/4, W/4]
        e4 = self.enc3(e3)     # [B, 8C, H/8, W/8]
        e5 = self.enc4(e4)     # [B, 16C, H/16, W/16]
        d1 = self.dec1(e5, e4) # [B, 8C, H/8, W/8]
        d2 = self.dec2(d1, e3) # [B, 4C, H/4, W/4]
        d3 = self.dec3(d2, e2) # [B, 2C, H/2, W/2]
        d4 = self.dec4(d3, e1) # [B, C, H, W]
        logits = self.outc(d4) # [B, num_classes, H, W]
        return logits
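As a rough capacity check, the trainable parameter count of this configuration can be tallied analytically: each ConvBlock contributes two bias-free 3×3 convolutions plus BatchNorm weight and bias per channel, each up-convolution a 2×2 transposed convolution with bias. The sketch below mirrors the default UNet(3, 2, 64) above.

```python
def unet_param_count(in_ch=3, classes=2, c=64):
    """Tally trainable parameters of the UNet configuration above."""
    def conv_block(i, o):
        # two bias-free 3x3 convs (mid channels = out channels),
        # plus BatchNorm weight+bias for each of the two conv outputs
        return i * o * 9 + o * o * 9 + 4 * o

    def up(i):
        # ConvTranspose2d(i, i // 2, kernel_size=2, stride=2) with bias
        return i * (i // 2) * 4 + i // 2

    total = conv_block(in_ch, c)                        # stem
    for k in range(4):                                  # encoder stages
        total += conv_block(c * 2**k, c * 2**(k + 1))
    for k in range(4, 0, -1):                           # decoder stages
        total += up(c * 2**k) + conv_block(c * 2**k, c * 2**(k - 1))
    total += c * classes + classes                      # 1x1 output conv
    return total

print(unet_param_count())  # about 31 million parameters
```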

Data loading and augmentation

For demonstration, the DRIVE retinal vessel segmentation dataset is used. The data pipeline includes random resizing, horizontal/vertical flips, random cropping, normalization, and conversion to tensors. The dataset class reads images and manual annotations, merges them with region-of-interest masks so that pixels outside the field of view can be ignored during training, and applies the transformations.

from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class VesselDataset(Dataset):
    def __init__(self, root, train=True, transforms=None):
        self.mode = 'training' if train else 'test'
        # ... path setup, gather images, manual, mask files ...
        self.transforms = transforms

    def __getitem__(self, idx):
        img = Image.open(self.img_paths[idx]).convert('RGB')
        # manual annotation: vessel pixels (255) -> label 1, background -> 0
        manual = np.array(Image.open(self.manual_paths[idx]).convert('L')) / 255
        # ROI mask: invert so that pixels outside the field of view become 255
        mask = Image.open(self.mask_paths[idx]).convert('L')
        mask = 255 - np.array(mask)
        # merged target: 0 = background, 1 = vessel, 255 = ignore (outside ROI)
        combined = np.clip(manual + mask, 0, 255).astype(np.uint8)
        combined = Image.fromarray(combined)
        if self.transforms:
            img, combined = self.transforms(img, combined)
        return img, combined

    def __len__(self):
        return len(self.img_paths)
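The label arithmetic in __getitem__ is easier to see on a toy array (values chosen by hand): dividing the manual annotation by 255 maps vessels to 1, and adding the inverted ROI mask pushes everything outside the field of view to 255, which the loss can later treat as an ignore index.

```python
import numpy as np

manual = np.array([[0, 255],
                   [0,   0]], dtype=np.uint8)    # vessel in the top-right
roi    = np.array([[255, 255],
                   [255,   0]], dtype=np.uint8)  # bottom-right outside the ROI

target = np.clip(manual / 255 + (255 - roi), 0, 255).astype(np.uint8)
print(target)
# [[  0   1]
#  [  0 255]]  -> 0 background, 1 vessel, 255 ignored
```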

Training loop

A standard PyTorch training script uses SGD with momentum, a learning rate scheduler (warmup + cosine decay), and optional mixed precision. Dice coefficient measured on the validation set guides model selection.
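The warmup-plus-cosine schedule can be expressed as a multiplicative factor on the base learning rate, suitable for torch.optim.lr_scheduler.LambdaLR; the step counts below are placeholders, not values from this training run.

```python
import math

def lr_factor(step, warmup_steps, total_steps):
    """Linear warmup followed by cosine decay to zero."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# e.g. lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer, lambda s: lr_factor(s, warmup_steps=500, total_steps=20000))
```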

for epoch in range(start_epoch, num_epochs):
    model.train()
    epoch_loss = 0
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        epoch_loss += loss.item()
    # validation and checkpoint saving ...
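The Dice coefficient used for model selection is not shown above; a minimal NumPy version for binary masks, skipping the 255 ignore label from the dataset, might look like this:

```python
import numpy as np

def dice_coefficient(pred, target, ignore_index=255, eps=1e-8):
    """Dice = 2|A ∩ B| / (|A| + |B|), computed over pixels inside the ROI."""
    valid = target != ignore_index
    pred, target = pred[valid].astype(bool), target[valid].astype(bool)
    inter = np.logical_and(pred, target).sum()
    return (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)

p = np.array([[1, 1], [0, 0]])
t = np.array([[1, 0], [0, 255]])
print(dice_coefficient(p, t))  # 2*1 / (2 + 1), about 0.667
```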

The code above provides a complete, modern U-Net suitable for binary or multi-class segmentation tasks.
