Transfer Learning with ResNet50 for Dogs and Wolves Classification
Transfer Learning with ResNet50
When training deep learning models for specific tasks, collecting large amounts of training data is often impractical. Instead of training from scratch, a common approach is to use a pre-trained model—typically one trained on a large foundational dataset—and adapt it for the target application.
This tutorial demonstrates transfer learning by fine-tuning a ResNet50 model pre-trained on ImageNet to classify images of wolves and dogs.
Dataset Preparation
Downloading the Dataset
The Canidae dataset contains images extracted from ImageNet, with approximately 120 training images and 30 validation images per category.
from download import download

# Canidae archive: dog vs. wolf images sampled from ImageNet
# (~120 training and ~30 validation images per class).
_CANIDAE_URL = (
    "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/"
    "notebook/datasets/intermediate/Canidae_data.zip"
)
download(_CANIDAE_URL, "./datasets-Canidae", kind="zip", replace=True)
Loading the Dataset
Using MindSpore's ImageFolderDataset interface to load the dataset with appropriate augmentation for training and inference.
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
# ----- Training hyper-parameters -----
batch_size = 18        # samples per batch
image_size = 224       # ResNet50 expects 224x224 inputs
num_epochs = 5         # fine-tuning epochs
learning_rate = 0.001  # optimizer step size
momentum = 0.9         # SGD momentum term
num_workers = 4        # parallel workers for data loading

# ----- Dataset locations (class-per-folder layout) -----
train_data_path = "./datasets-Canidae/data/Canidae/train/"
val_data_path = "./datasets-Canidae/data/Canidae/val/"
def create_dataset_canidae(dataset_path, usage):
    """Build a batched image dataset from a class-per-folder directory.

    Args:
        dataset_path: root folder whose sub-folders name the classes.
        usage: "train" enables random augmentation; anything else uses
            the deterministic resize + center-crop evaluation pipeline.

    Returns:
        A batched MindSpore dataset yielding normalized CHW images.
    """
    dataset = ds.ImageFolderDataset(dataset_path,
                                    num_parallel_workers=num_workers,
                                    shuffle=True)

    # ImageNet channel statistics, scaled to the 0-255 pixel range that
    # the decoded images use.
    norm_mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    norm_std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    resize_margin = 32  # extra pixels before the evaluation center-crop

    if usage == "train":
        ops = [
            vision.RandomCropDecodeResize(size=image_size,
                                          scale=(0.08, 1.0),
                                          ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Normalize(mean=norm_mean, std=norm_std),
            vision.HWC2CHW(),
        ]
    else:
        ops = [
            vision.Decode(),
            vision.Resize(image_size + resize_margin),
            vision.CenterCrop(image_size),
            vision.Normalize(mean=norm_mean, std=norm_std),
            vision.HWC2CHW(),
        ]

    dataset = dataset.map(operations=ops,
                          input_columns='image',
                          num_parallel_workers=num_workers)
    return dataset.batch(batch_size)
# Build the training and validation pipelines and record how many
# batches (steps) each yields per epoch.
train_dataset = create_dataset_canidae(train_data_path, "train")
train_steps = train_dataset.get_dataset_size()
val_dataset = create_dataset_canidae(val_data_path, "val")
val_steps = val_dataset.get_dataset_size()
Visualizing the Data
import matplotlib.pyplot as plt
import numpy as np
# Preview one training batch: print the tensor shape and labels, then
# show four de-normalized sample images with their class names.
batch = next(train_dataset.create_dict_iterator())
batch_images = batch["image"]
batch_labels = batch["label"]
print("Image tensor shape:", batch_images.shape)
print("Labels:", batch_labels)

class_labels = {0: "dogs", 1: "wolves"}
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

plt.figure(figsize=(5, 5))
for idx in range(4):
    # Tensors are CHW; convert to HWC and undo the ImageNet normalization
    # so the pixels land back in [0, 1] for display.
    picture = np.transpose(batch_images[idx].asnumpy(), (1, 2, 0))
    picture = np.clip(imagenet_std * picture + imagenet_mean, 0, 1)
    plt.subplot(2, 2, idx + 1)
    plt.imshow(picture)
    plt.title(class_labels[int(batch_labels[idx].asnumpy())])
    plt.axis("off")
plt.show()
Building the ResNet50 Model
The implementation defines residual building blocks and constructs the complete ResNet architecture.
from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal
# Shared initializers for all conv weights and batch-norm gamma terms
# defined below. NOTE(review): these names are referenced throughout the
# model code, so they must not be renamed in isolation.
weightInitializer = Normal(mean=0, sigma=0.02)
gammaInitializer = Normal(mean=1, sigma=0.02)
class BasicResidualBlock(nn.Cell):
    """Basic 3x3 + 3x3 residual block used by shallow ResNet variants.

    Args:
        in_channels: channels of the block input.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (spatial down-sampling).
        norm_layer: optional normalization cell. When given it is used
            after both convolutions (preserving the original contract);
            when ``None``, two independent BatchNorm2d layers are built.
        downsample: optional cell that projects the identity branch so
            its shape matches the residual branch.
    """
    expansion = 1  # basic blocks do not widen their output

    def __init__(self, in_channels: int, out_channels: int,
                 stride: int = 1, norm_layer: Optional[nn.Cell] = None,
                 downsample: Optional[nn.Cell] = None) -> None:
        super(BasicResidualBlock, self).__init__()
        if not norm_layer:
            # Fix: the original shared a single BatchNorm2d between both
            # convolutions, which mixes the running statistics of two
            # different activations and ties their affine parameters.
            # Each convolution gets its own normalization layer.
            self.norm1 = nn.BatchNorm2d(out_channels)
            self.norm2 = nn.BatchNorm2d(out_channels)
        else:
            # A caller-supplied layer keeps the original shared behavior.
            self.norm1 = norm_layer
            self.norm2 = norm_layer
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride,
                               weight_init=weightInitializer)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, weight_init=weightInitializer)
        self.activation = nn.ReLU()
        self.downsample = downsample

    def construct(self, x):
        """Forward pass: conv-norm-relu, conv-norm, add identity, relu."""
        identity = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.norm2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out = out + identity
        out = self.activation(out)
        return out
class BottleneckBlock(nn.Cell):
    """Bottleneck residual block (1x1 reduce, 3x3, 1x1 expand) for ResNet50+.

    NOTE(review): attribute names (conv1/bn1/...) become the parameter
    names the pretrained checkpoint is matched against, so they must not
    be renamed.
    """
    # Output width is out_channels * expansion.
    expansion = 4

    def __init__(self, in_channels: int, out_channels: int,
                 stride: int = 1, downsample: Optional[nn.Cell] = None) -> None:
        super(BottleneckBlock, self).__init__()
        # 1x1 conv: reduce channels to `out_channels`.
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=1, weight_init=weightInitializer)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # 3x3 conv: spatial processing; carries the block's stride.
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=stride,
                               weight_init=weightInitializer)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 conv: expand back to `out_channels * expansion`.
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, weight_init=weightInitializer)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.activation = nn.ReLU()
        # Optional projection for the identity branch so shapes match.
        self.downsample = downsample

    def construct(self, x):
        """Forward pass: residual branch plus (optionally projected) identity."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.activation(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out = out + identity
        out = self.activation(out)
        return out
def make_layer(prev_out_channels, block: Type[Union[BasicResidualBlock, BottleneckBlock]],
               channels: int, num_blocks: int, stride: int = 1):
    """Stack ``num_blocks`` residual blocks into one sequential stage.

    Only the first block may down-sample (via ``stride``) or widen the
    features; in that case a 1x1 conv + BatchNorm projection is attached
    to its identity branch so the two branch shapes match.
    """
    projection = None
    if stride != 1 or prev_out_channels != channels * block.expansion:
        projection = nn.SequentialCell([
            nn.Conv2d(prev_out_channels, channels * block.expansion,
                      kernel_size=1, stride=stride, weight_init=weightInitializer),
            nn.BatchNorm2d(channels * block.expansion, gamma_init=gammaInitializer)
        ])

    stage = [block(prev_out_channels, channels, stride=stride,
                   downsample=projection)]
    width = channels * block.expansion
    stage.extend(block(width, channels) for _ in range(num_blocks - 1))
    return nn.SequentialCell(stage)
from mindspore import load_checkpoint, load_param_into_net
class ResNet(nn.Cell):
    """ResNet backbone: stem conv + four residual stages + pooled classifier.

    Args:
        block: residual block class (basic or bottleneck).
        layer_counts: number of blocks in each of the four stages.
        num_classes: output width of the final Dense classifier.
        input_channels: flattened feature width fed to the classifier
            (512 * block.expansion for 224x224 inputs).

    NOTE(review): attribute names (conv1/bn/layer1/.../fc) become the
    checkpoint parameter names — do not rename.
    """

    def __init__(self, block: Type[Union[BasicResidualBlock, BottleneckBlock]],
                 layer_counts: List[int], num_classes: int, input_channels: int) -> None:
        super(ResNet, self).__init__()
        self.activation = nn.ReLU()
        # Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool (224 -> 56).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weightInitializer)
        self.bn = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        # Four residual stages; stages 2-4 halve the spatial size.
        self.layer1 = make_layer(64, block, 64, layer_counts[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_counts[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_counts[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_counts[3], stride=2)
        # Fix: nn.AvgPool2d() defaults to kernel_size=1, which leaves the
        # final 7x7 map intact so the flattened width would be 49x larger
        # than `input_channels` and the Dense layer below would fail. For
        # the 224x224 inputs used here the last feature map is 7x7, so
        # pool with kernel_size=7 to collapse it to 1x1. (Callers that
        # re-assign avgpool with kernel_size=7 are unaffected.)
        self.avgpool = nn.AvgPool2d(kernel_size=7)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(in_channels=input_channels, out_channels=num_classes)

    def construct(self, x):
        """Forward pass from input image batch to class logits."""
        x = self.conv1(x)
        x = self.bn(x)
        x = self.activation(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
def _resnet(model_url: str, block: Type[Union[BasicResidualBlock, BottleneckBlock]],
            layers: List[int], num_classes: int, pretrained: bool, ckpt_path: str,
            input_channels: int):
    """Instantiate a ResNet and optionally load pretrained weights.

    When ``pretrained`` is True the checkpoint is downloaded from
    ``model_url`` into ``ckpt_path`` and loaded into the network.
    """
    net = ResNet(block, layers, num_classes, input_channels)
    if not pretrained:
        return net
    download(url=model_url, path=ckpt_path, replace=True)
    checkpoint_params = load_checkpoint(ckpt_path)
    load_param_into_net(net, checkpoint_params)
    return net
def resnet50(num_classes: int = 1000, pretrained: bool = False):
    """Build a ResNet50 (bottleneck blocks, [3, 4, 6, 3] stage layout)."""
    return _resnet(
        "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt",
        BottleneckBlock,
        [3, 4, 6, 3],
        num_classes,
        pretrained,
        "./LoadPretrainedModel/resnet50_224_new.ckpt",
        2048,  # 512 * BottleneckBlock.expansion
    )
Training with Frozen Features
When using a pre-trained model as a fixed feature extractor, freeze all layers except the final classification layer by setting requires_grad = False.
import mindspore as ms
import time
# Build the pretrained backbone and adapt it for 2-way classification.
network = resnet50(pretrained=True)

# Replace the 1000-class ImageNet head with a fresh 2-class classifier.
input_dim = network.fc.in_channels
classifier = nn.Dense(input_dim, 2)
network.fc = classifier
# 224x224 inputs leave a 7x7 feature map; pool it down to 1x1.
network.avgpool = nn.AvgPool2d(kernel_size=7)

# Freeze every parameter except the new classification head so the
# backbone acts as a fixed feature extractor.
for param in network.get_parameters():
    if param.name not in ["fc.weight", "fc.bias"]:
        param.requires_grad = False

# Fix: use the `momentum` hyper-parameter configured with the other
# settings (0.9) instead of the stray hard-coded 0.5 that silently
# ignored it.
optimizer = nn.Momentum(params=network.trainable_params(), learning_rate=learning_rate, momentum=momentum)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

def forward_fn(inputs, targets):
    """Return the mean cross-entropy loss for one batch."""
    predictions = network(inputs)
    loss = loss_fn(predictions, targets)
    return loss

# Differentiate w.r.t. the optimizer's (trainable) parameters only.
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

def train_step(inputs, targets):
    """One optimization step: forward, backward, parameter update."""
    loss, grads = grad_fn(inputs, targets)
    optimizer(grads)
    return loss

# Model wrapper used below for validation-set accuracy evaluation.
model = train.Model(network, loss_fn, optimizer, metrics={"Accuracy": train.Accuracy()})
Training and Evaluation Loop
# Rebuild the pipelines so training starts from fresh iterators
# (the earlier datasets were partially consumed during visualization).
train_dataset = create_dataset_canidae(train_data_path, "train")
train_steps = train_dataset.get_dataset_size()
val_dataset = create_dataset_canidae(val_data_path, "val")
val_steps = val_dataset.get_dataset_size()
# Fix: `os` is used below to create the checkpoint directory but was
# never imported anywhere in the file (NameError on first save).
import os

checkpoint_dir = "./BestCheckpoint"
best_checkpoint_path = "./BestCheckpoint/resnet50-best-freezing-param.ckpt"

print("Starting Training Loop ...")
best_accuracy = 0

for epoch in range(num_epochs):
    epoch_losses = []
    network.set_train()
    epoch_start = time.time()

    # One pass over the training data; labels cast to int32 for the
    # sparse cross-entropy loss.
    for batch_idx, (images, labels) in enumerate(train_dataset.create_tuple_iterator()):
        labels = labels.astype(ms.int32)
        loss = train_step(images, labels)
        epoch_losses.append(loss)

    # Evaluate on the validation set after every epoch.
    accuracy = model.eval(val_dataset)['Accuracy']

    epoch_end = time.time()
    epoch_duration = (epoch_end - epoch_start) * 1000
    step_duration = epoch_duration / train_steps

    print("-" * 20)
    print("Epoch: [%3d/%3d], Train Loss: [%5.3f], Accuracy: [%5.3f]" % (
        epoch + 1, num_epochs, sum(epoch_losses) / len(epoch_losses), accuracy
    ))
    print("Epoch time: %5.3f ms, Per step time: %5.3f ms" % (
        epoch_duration, step_duration
    ))

    # Keep only the checkpoint with the best validation accuracy.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        # makedirs(exist_ok=True) replaces the exists()+mkdir pair and
        # avoids the race between the check and the creation.
        os.makedirs(checkpoint_dir, exist_ok=True)
        ms.save_checkpoint(network, best_checkpoint_path)

print("=" * 80)
print(f"Best Accuracy: {best_accuracy:5.3f}, Checkpoint saved to {best_checkpoint_path}", flush=True)
Visualizing Predictions
import matplotlib.pyplot as plt
import mindspore as ms
def visualize_predictions(checkpoint_path, validation_ds):
    """Show four validation images with their predicted class names.

    Titles are blue when the prediction matches the ground-truth label
    and red otherwise.
    """
    # Rebuild the fine-tuned architecture and load the saved weights.
    network = resnet50()
    network.fc = nn.Dense(network.fc.in_channels, 2)
    network.avgpool = nn.AvgPool2d(kernel_size=7)
    ms.load_param_into_net(network, ms.load_checkpoint(checkpoint_path))
    inference_model = train.Model(network)

    batch = next(validation_ds.create_dict_iterator())
    image_array = batch["image"].asnumpy()
    label_array = batch["label"].asnumpy()
    class_names = {0: "dogs", 1: "wolves"}

    logits = inference_model.predict(ms.Tensor(batch['image']))
    pred_classes = np.argmax(logits.asnumpy(), axis=1)

    # De-normalization constants (ImageNet mean/std in 0-1 range).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    plt.figure(figsize=(5, 5))
    for idx in range(4):
        plt.subplot(2, 2, idx + 1)
        color = 'blue' if pred_classes[idx] == label_array[idx] else 'red'
        plt.title(f'Prediction: {class_names[pred_classes[idx]]}', color=color)
        shown = np.transpose(image_array[idx], (1, 2, 0))
        shown = np.clip(std * shown + mean, 0, 1)
        plt.imshow(shown)
        plt.axis('off')
    plt.show()
visualize_predictions(best_checkpoint_path, val_dataset)