Fading Coder

One Final Commit for the Last Sprint

Home > Tech > Content

Transfer Learning with ResNet50 for Dogs and Wolves Classification

Tech May 8 5

Transfer Learning with ResNet50

When training deep learning models for specific tasks, collecting large amounts of training data is often impractical. Instead of training from scratch, a common approach is to use a pre-trained model—typically one trained on a large foundational dataset—and adapt it for the target application.

This tutorial demonstrates transfer learning by fine-tuning a ResNet50 model pre-trained on ImageNet to classify images of wolves and dogs.

Dataset Preparation

Downloading the Dataset

The Canidae dataset contains images extracted from ImageNet, with approximately 120 training images and 30 validation images per category.

from download import download

# Canidae (dogs vs. wolves) image archive hosted on the MindSpore website.
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"

# Download and unzip into ./datasets-Canidae, replacing any previous copy.
download(dataset_url, "./datasets-Canidae", kind="zip", replace=True)

Loading the Dataset

Using MindSpore's ImageFolderDataset interface to load the dataset with appropriate augmentation for training and inference.

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# Hyperparameters shared by the dataset pipeline and the training loop.
batch_size = 18  # images per mini-batch
image_size = 224  # square crop size fed to the network
num_epochs = 5  # number of fine-tuning epochs
learning_rate = 0.001  # learning rate for the Momentum optimizer
momentum = 0.9  # optimizer momentum term
num_workers = 4  # parallel workers for dataset loading and map ops

# Folder-per-class layout expected by ImageFolderDataset.
train_data_path = "./datasets-Canidae/data/Canidae/train/"
val_data_path = "./datasets-Canidae/data/Canidae/val/"

def create_dataset_canidae(dataset_path, usage):
    """Create a shuffled, batched Canidae image dataset.

    Args:
        dataset_path: Root folder whose subfolders name the classes.
        usage: "train" enables random-resized-crop/flip augmentation;
            any other value applies deterministic resize + center-crop.

    Returns:
        A batched ``ImageFolderDataset`` yielding normalized CHW images.
    """
    dataset = ds.ImageFolderDataset(dataset_path,
                                    num_parallel_workers=num_workers,
                                    shuffle=True)

    # ImageNet channel statistics, scaled from [0, 1] to [0, 255] pixels.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    scale_offset = 32

    if usage == "train":
        # Decode + random resized crop + random horizontal flip.
        ops = [
            vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
        ]
    else:
        # Deterministic evaluation preprocessing.
        ops = [
            vision.Decode(),
            vision.Resize(image_size + scale_offset),
            vision.CenterCrop(image_size),
        ]
    # Normalization and HWC->CHW layout conversion are shared by both paths.
    ops += [vision.Normalize(mean=mean, std=std), vision.HWC2CHW()]

    dataset = dataset.map(operations=ops,
                          input_columns='image',
                          num_parallel_workers=num_workers)
    return dataset.batch(batch_size)

# Training pipeline and its number of batches per epoch.
train_dataset = create_dataset_canidae(train_data_path, "train")
train_steps = train_dataset.get_dataset_size()

# Validation pipeline (deterministic preprocessing) and its batch count.
val_dataset = create_dataset_canidae(val_data_path, "val")
val_steps = val_dataset.get_dataset_size()

Visualizing the Data

import matplotlib.pyplot as plt
import numpy as np

# Pull one batch from the training pipeline for a quick visual sanity check.
sample = next(train_dataset.create_dict_iterator())
sample_images = sample["image"]
sample_labels = sample["label"]

print("Image tensor shape:", sample_images.shape)
print("Labels:", sample_labels)
class_labels = {0: "dogs", 1: "wolves"}

# Inverse of the Normalize transform (in [0, 1] units) for display.
# FIX: these arrays are loop-invariant; the original rebuilt them on every
# iteration. The unused local `lbl` was also removed.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

plt.figure(figsize=(5, 5))
for i in range(4):
    # CHW -> HWC, then undo normalization and clamp into displayable range.
    img = np.transpose(sample_images[i].asnumpy(), (1, 2, 0))
    img = np.clip(std * img + mean, 0, 1)
    plt.subplot(2, 2, i + 1)
    plt.imshow(img)
    plt.title(class_labels[int(sample_labels[i].asnumpy())])
    plt.axis("off")

plt.show()

Building the ResNet50 Model

The implementation defines residual building blocks and constructs the complete ResNet architecture.

from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal

# Weight initializers shared by every conv / batch-norm layer defined below.
weightInitializer = Normal(mean=0, sigma=0.02)  # conv weights ~ N(0, 0.02)
gammaInitializer = Normal(mean=1, sigma=0.02)  # batch-norm gamma ~ N(1, 0.02)
class BasicResidualBlock(nn.Cell):
    """Basic residual block (3x3 conv -> norm -> ReLU -> 3x3 conv -> norm).

    Used by the shallower ResNet variants; ``expansion = 1`` means the block
    does not change the channel count relative to ``out_channels``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (downsamples when > 1).
        norm_layer: Optional caller-supplied normalization cell.
        downsample: Optional projection applied to the identity path when
            the main path changes shape.
    """

    expansion = 1

    def __init__(self, in_channels: int, out_channels: int,
                 stride: int = 1, norm_layer: Optional[nn.Cell] = None,
                 downsample: Optional[nn.Cell] = None) -> None:
        super(BasicResidualBlock, self).__init__()
        # BUG FIX: the original created a single BatchNorm2d and applied it
        # after BOTH convolutions, so the two normalization sites shared one
        # set of gamma/beta parameters and running statistics. The reference
        # ResNet basic block uses an independent norm per convolution.
        if not norm_layer:
            self.norm1 = nn.BatchNorm2d(out_channels)
            self.norm2 = nn.BatchNorm2d(out_channels)
        else:
            # A caller-supplied cell is kept shared, matching the original
            # behavior for this (rare) code path.
            self.norm1 = norm_layer
            self.norm2 = norm_layer

        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride,
                               weight_init=weightInitializer)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, weight_init=weightInitializer)
        self.activation = nn.ReLU()
        self.downsample = downsample

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.norm2(out)

        # Project the shortcut when the main path changed shape.
        if self.downsample is not None:
            identity = self.downsample(x)
        out = out + identity
        out = self.activation(out)

        return out
class BottleneckBlock(nn.Cell):
    """ResNet bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand, plus a skip path.

    The final 1x1 convolution widens the features by ``expansion`` (4), so
    the block maps ``in_channels`` maps to ``out_channels * 4``.
    """

    expansion = 4

    def __init__(self, in_channels: int, out_channels: int,
                 stride: int = 1, downsample: Optional[nn.Cell] = None) -> None:
        super(BottleneckBlock, self).__init__()

        # 1x1 conv: shrink down to the bottleneck width.
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=1, weight_init=weightInitializer)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # 3x3 conv: carries the (optional) spatial stride.
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=stride,
                               weight_init=weightInitializer)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 conv: expand back out by the expansion factor.
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, weight_init=weightInitializer)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.activation = nn.ReLU()
        self.downsample = downsample

    def construct(self, x):
        # Shortcut path: projected when shape/stride changes, else identity.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.activation(self.bn1(self.conv1(x)))
        y = self.activation(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        return self.activation(y + shortcut)
def make_layer(prev_out_channels, block: Type[Union[BasicResidualBlock, BottleneckBlock]],
               channels: int, num_blocks: int, stride: int = 1):
    """Stack ``num_blocks`` residual blocks into one ResNet stage.

    Args:
        prev_out_channels: Channel count entering the stage.
        block: Residual block class to instantiate.
        channels: Base channel width of the stage.
        num_blocks: How many blocks to stack.
        stride: Stride applied by the first block (downsamples when > 1).

    Returns:
        An ``nn.SequentialCell`` containing the stacked blocks.
    """
    out_channels = channels * block.expansion

    # A projection shortcut is needed whenever the first block changes
    # either the spatial resolution (stride) or the channel count.
    downsample = None
    if stride != 1 or prev_out_channels != out_channels:
        downsample = nn.SequentialCell([
            nn.Conv2d(prev_out_channels, out_channels,
                      kernel_size=1, stride=stride, weight_init=weightInitializer),
            nn.BatchNorm2d(out_channels, gamma_init=gammaInitializer)
        ])

    # First block handles stride/projection; the rest keep the shape fixed.
    stage = [block(prev_out_channels, channels, stride=stride, downsample=downsample)]
    for _ in range(num_blocks - 1):
        stage.append(block(out_channels, channels))

    return nn.SequentialCell(stage)
from mindspore import load_checkpoint, load_param_into_net

class ResNet(nn.Cell):
    """Configurable ResNet backbone with a dense classification head.

    Attribute names are significant: they determine the parameter names
    matched by ``load_param_into_net`` when loading checkpoints.

    Args:
        block: Residual block type (basic or bottleneck).
        layer_counts: Number of blocks in each of the four stages.
        num_classes: Output dimension of the final dense layer.
        input_channels: Flattened feature size feeding the dense layer.
    """

    def __init__(self, block: Type[Union[BasicResidualBlock, BottleneckBlock]],
                 layer_counts: List[int], num_classes: int, input_channels: int) -> None:
        super(ResNet, self).__init__()

        # Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool.
        self.activation = nn.ReLU()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weightInitializer)
        self.bn = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Four residual stages; stages 2-4 halve the spatial resolution.
        self.layer1 = make_layer(64, block, 64, layer_counts[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_counts[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_counts[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_counts[3], stride=2)

        # Head: average pool -> flatten -> dense classifier.
        self.avgpool = nn.AvgPool2d()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(in_channels=input_channels, out_channels=num_classes)

    def construct(self, x):
        # Stem.
        feat = self.maxpool(self.activation(self.bn(self.conv1(x))))

        # Residual stages.
        feat = self.layer1(feat)
        feat = self.layer2(feat)
        feat = self.layer3(feat)
        feat = self.layer4(feat)

        # Classification head.
        feat = self.flatten(self.avgpool(feat))
        return self.fc(feat)


def _resnet(model_url: str, block: Type[Union[BasicResidualBlock, BottleneckBlock]],
            layers: List[int], num_classes: int, pretrained: bool, ckpt_path: str,
            input_channels: int):
    """Instantiate a ResNet and optionally load pretrained weights.

    Args:
        model_url: URL of the checkpoint to fetch when ``pretrained``.
        block: Residual block class for all four stages.
        layers: Per-stage block counts.
        num_classes: Classifier output dimension.
        pretrained: When True, download ``model_url`` and load its weights.
        ckpt_path: Local path where the checkpoint is stored.
        input_channels: Flattened feature size feeding the dense head.

    Returns:
        The constructed (and possibly weight-loaded) ``ResNet``.
    """
    network = ResNet(block, layers, num_classes, input_channels)
    if not pretrained:
        return network

    # Fetch the checkpoint and copy its parameters into the fresh network.
    download(url=model_url, path=ckpt_path, replace=True)
    load_param_into_net(network, load_checkpoint(ckpt_path))
    return network


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    """Build a ResNet50 (bottleneck blocks, stage depths 3-4-6-3)."""
    url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"
    ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    # 2048 = 512 * BottleneckBlock.expansion, the flattened feature width.
    return _resnet(url, BottleneckBlock, [3, 4, 6, 3], num_classes,
                   pretrained, ckpt, 2048)

Training with Frozen Features

When using a pre-trained model as a fixed feature extractor, freeze all layers except the final classification layer by setting requires_grad = False.

import mindspore as ms
import time

# Load ResNet50 with ImageNet-pretrained weights.
network = resnet50(pretrained=True)

# Replace the 1000-way ImageNet head with a 2-way dogs/wolves classifier.
input_dim = network.fc.in_channels
classifier = nn.Dense(input_dim, 2)
network.fc = classifier

# Give the average pool an explicit window for the pre-head feature maps.
network.avgpool = nn.AvgPool2d(kernel_size=7)

# Freeze every parameter except the new classifier head, so the backbone
# acts as a fixed feature extractor.
for param in network.get_parameters():
    if param.name not in ["fc.weight", "fc.bias"]:
        param.requires_grad = False

# FIX: use the `momentum` hyperparameter declared alongside the other
# settings; the original hardcoded momentum=0.5 here while the module-level
# `momentum = 0.9` was never used anywhere.
optimizer = nn.Momentum(params=network.trainable_params(), learning_rate=learning_rate, momentum=momentum)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')


def forward_fn(inputs, targets):
    """Compute the classification loss for one batch."""
    return loss_fn(network(inputs), targets)

# Differentiate forward_fn w.r.t. the optimizer's (trainable) parameters.
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)


def train_step(inputs, targets):
    """Apply one gradient update and return the batch loss."""
    batch_loss, gradients = grad_fn(inputs, targets)
    optimizer(gradients)
    return batch_loss

# Model wrapper, used below to compute validation accuracy via model.eval().
model = train.Model(network, loss_fn, optimizer, metrics={"Accuracy": train.Accuracy()})

Training and Evaluation Loop

# Rebuild fresh dataset pipelines for the training loop (the earlier ones
# were partially consumed by the visualization iterator).
train_dataset = create_dataset_canidae(train_data_path, "train")
train_steps = train_dataset.get_dataset_size()

val_dataset = create_dataset_canidae(val_data_path, "val")
val_steps = val_dataset.get_dataset_size()

# Location where the best-accuracy checkpoint is saved during training.
checkpoint_dir = "./BestCheckpoint"
best_checkpoint_path = "./BestCheckpoint/resnet50-best-freezing-param.ckpt"

import os  # FIX: `os` was never imported, so saving a checkpoint raised NameError

print("Starting Training Loop ...")

best_accuracy = 0

for epoch in range(num_epochs):
    epoch_losses = []
    network.set_train()

    epoch_start = time.time()

    # One pass over the training set. The unused `batch_idx` from the
    # original enumerate() was dropped. Labels are cast to int32 for the
    # sparse cross-entropy loss.
    for images, labels in train_dataset.create_tuple_iterator():
        labels = labels.astype(ms.int32)
        loss = train_step(images, labels)
        epoch_losses.append(loss)

    # Validation accuracy after each epoch.
    accuracy = model.eval(val_dataset)['Accuracy']

    epoch_end = time.time()
    epoch_duration = (epoch_end - epoch_start) * 1000  # milliseconds
    step_duration = epoch_duration / train_steps

    print("-" * 20)
    print("Epoch: [%3d/%3d], Train Loss: [%5.3f], Accuracy: [%5.3f]" % (
        epoch + 1, num_epochs, sum(epoch_losses) / len(epoch_losses), accuracy
    ))
    print("Epoch time: %5.3f ms, Per step time: %5.3f ms" % (
        epoch_duration, step_duration
    ))

    # Keep only the checkpoint with the best validation accuracy so far.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        # makedirs(exist_ok=True) replaces the original exists()/mkdir()
        # pair: no NameError (os is imported) and no check-then-create race.
        os.makedirs(checkpoint_dir, exist_ok=True)
        ms.save_checkpoint(network, best_checkpoint_path)

print("=" * 80)
print(f"Best Accuracy: {best_accuracy:5.3f}, Checkpoint saved to {best_checkpoint_path}", flush=True)

Visualizing Predictions

import matplotlib.pyplot as plt
import mindspore as ms


def visualize_predictions(checkpoint_path, validation_ds):
    """Plot a 2x2 grid of validation images with their predicted classes.

    Titles are blue for correct predictions and red for wrong ones.

    Args:
        checkpoint_path: Path to the fine-tuned 2-class checkpoint.
        validation_ds: Batched validation dataset (batch size >= 4).
    """
    # Rebuild the fine-tuned topology and restore the saved weights.
    network = resnet50()
    network.fc = nn.Dense(network.fc.in_channels, 2)
    network.avgpool = nn.AvgPool2d(kernel_size=7)

    param_dict = ms.load_checkpoint(checkpoint_path)
    ms.load_param_into_net(network, param_dict)
    inference_model = train.Model(network)

    sample = next(validation_ds.create_dict_iterator())
    images = sample["image"].asnumpy()
    labels = sample["label"].asnumpy()
    class_names = {0: "dogs", 1: "wolves"}

    # Predicted class = argmax over the 2 output logits, per image.
    predictions = inference_model.predict(ms.Tensor(sample['image']))
    pred_classes = np.argmax(predictions.asnumpy(), axis=1)

    # Inverse of the Normalize transform for display.
    # FIX: these arrays are loop-invariant; the original rebuilt them on
    # every iteration of the plotting loop.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    plt.figure(figsize=(5, 5))
    for i in range(4):
        plt.subplot(2, 2, i + 1)
        color = 'blue' if pred_classes[i] == labels[i] else 'red'
        plt.title(f'Prediction: {class_names[pred_classes[i]]}', color=color)
        # CHW -> HWC, de-normalize, clamp into displayable [0, 1] range.
        display_img = np.transpose(images[i], (1, 2, 0))
        display_img = np.clip(std * display_img + mean, 0, 1)
        plt.imshow(display_img)
        plt.axis('off')

    plt.show()


visualize_predictions(best_checkpoint_path, val_dataset)

Related Articles

Understanding Strong and Weak References in Java

Strong References Strong references are the most prevalent type of object referencing in Java. When an object has a strong reference pointing to it, the garbage collector will not reclaim its memory. F...

Comprehensive Guide to SSTI Explained with Payload Bypass Techniques

Introduction Server-Side Template Injection (SSTI) is a vulnerability in web applications where user input is improperly handled within the template engine and executed on the server. This exploit can r...

SBUS Signal Analysis and Communication Implementation Using STM32 with Fus Remote Controller

Overview In a recent project, I utilized the SBUS protocol with the Fus remote controller to control a vehicle's basic operations, including movement, lights, and mode switching. This article is aimed...

Leave a Comment

Anonymous

◎Feel free to join the discussion and share your thoughts.