Fading Coder

Building a Convolutional Neural Network for MNIST Classification with PyTorch

Tech May 7 11

Preparing the MNIST Dataset

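The MNIST dataset consists of 28×28 grayscale images of handwritten digits, split into 60,000 training samples and 10,000 test samples. We use torchvision to download the data. ToTensor() converts each image to a float tensor with values in [0, 1], and Normalize([0.5], [0.5]) then rescales those values to [-1, 1].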

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

composed_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_dataset = datasets.MNIST(root='./mnist_data', train=True, download=True, transform=composed_transforms)
test_dataset  = datasets.MNIST(root='./mnist_data', train=False, download=True, transform=composed_transforms)

batch = 128
train_loader = DataLoader(train_dataset, batch_size=batch, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=batch, shuffle=False)
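
The two transforms and the loader sizes can be checked by hand. This plain-Python sketch reproduces what ToTensor() followed by Normalize([0.5], [0.5]) does to one 8-bit pixel, and counts the mini-batches each loader yields with batch = 128. The helper names are hypothetical, and only the arithmetic mirrors torchvision.

```python
import math

# Mirrors ToTensor() followed by Normalize(mean=[0.5], std=[0.5]) for one pixel:
# an 8-bit value in [0, 255] is scaled to [0, 1], then mapped to (x - mean) / std.
def normalize_pixel(value: int, mean: float = 0.5, std: float = 0.5) -> float:
    scaled = value / 255.0          # what ToTensor() does to an 8-bit pixel
    return (scaled - mean) / std    # what Normalize() does per channel

# Inverse mapping (x * std + mean), useful when displaying normalized images.
def denormalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    return x * std + mean

# Mini-batches per epoch; the final partial batch is kept because
# DataLoader's drop_last defaults to False.
def num_batches(num_samples: int, batch_size: int) -> int:
    return math.ceil(num_samples / batch_size)

print(normalize_pixel(0), normalize_pixel(255))            # -1.0 1.0
print(num_batches(60_000, 128), num_batches(10_000, 128))  # 469 79
```

With drop_last left at its default, the training loader therefore yields 468 full batches plus one final batch of 96 images per epoch.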

Defining the Network Architecture

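The model stacks two convolutional blocks followed by a fully connected classifier. Each convolutional block applies a 5×5 convolution (padding 2, so the spatial size is preserved), batch normalization, a ReLU activation, and 2×2 max pooling, which halves the height and width. A 28×28 input therefore shrinks to 14×14 and then to 7×7, so the flattened feature vector passed to the classifier has 64 × 7 × 7 = 3,136 entries.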

import torch.nn as nn

class DigitClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        feats = self.features(x)
        flattened = feats.view(feats.size(0), -1)
        out = self.classifier(flattened)
        return out
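
The 64 * 7 * 7 input size of the first nn.Linear layer follows from the standard output-size formula for convolution and pooling. The plain-Python sketch below (out_size is a hypothetical helper) traces one spatial dimension through the feature extractor and also tallies the trainable parameters layer by layer.

```python
# Output size of Conv2d / MaxPool2d along one spatial dimension (dilation 1):
# floor((n + 2 * padding - kernel) / stride) + 1
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    return (n + 2 * padding - kernel) // stride + 1

side = 28
side = out_size(side, kernel=5, stride=1, padding=2)  # conv1 -> 28
side = out_size(side, kernel=2, stride=2)             # pool1 -> 14
side = out_size(side, kernel=5, stride=1, padding=2)  # conv2 -> 14
side = out_size(side, kernel=2, stride=2)             # pool2 -> 7
flat_features = 64 * side * side                      # inputs to the first nn.Linear

# Trainable parameters: conv = in * out * k * k + out (bias);
# BatchNorm2d = 2 * channels (weight and bias); linear = in * out + out.
params = {
    "conv1": 1 * 32 * 5 * 5 + 32,
    "bn1": 2 * 32,
    "conv2": 32 * 64 * 5 * 5 + 64,
    "bn2": 2 * 64,
    "fc1": flat_features * 128 + 128,
    "fc2": 128 * 10 + 10,
}
total_params = sum(params.values())
print(side, flat_features, total_params)  # 7 3136 455114
```

The total should agree with sum(p.numel() for p in DigitClassifier().parameters()), since BatchNorm's running mean and variance are buffers rather than parameters. About 88% of the weights sit in the first linear layer.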

Training and Evaluation Loop

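We train with cross-entropy loss and the Adam optimizer (learning rate 0.001, weight decay 1e-4). After every epoch, the model is evaluated on the test set to measure how well it generalizes. The calls to model.train() and model.eval() matter here because they switch the Dropout and BatchNorm layers between training and inference behavior, while torch.no_grad() disables gradient tracking during evaluation.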
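
To see what nn.CrossEntropyLoss computes from the raw scores returned by DigitClassifier, here is a plain-Python sketch of the same formula: for one sample, loss = logsumexp(z) - z[target], the negative log of the softmax probability of the true class, averaged over the batch as with the default reduction='mean'. It is an illustration rather than PyTorch's implementation, and the function names are hypothetical. It also shows why the model ends in a plain nn.Linear: the loss applies log-softmax itself.

```python
import math

def cross_entropy(logits: list[float], target: int) -> float:
    """-log softmax(logits)[target], computed with a stable log-sum-exp."""
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def batch_cross_entropy(batch_logits: list[list[float]], targets: list[int]) -> float:
    """Mean of the per-sample losses, like reduction='mean'."""
    losses = [cross_entropy(z, t) for z, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

# Equal scores over 10 classes give log(10); a confident correct score gives a small loss.
print(round(cross_entropy([0.0] * 10, target=3), 4))   # 2.3026
print(round(cross_entropy([2.0, 0.0], target=0), 4))   # 0.1269
```

A loss near log(10) ≈ 2.30 is what an untrained 10-class classifier with near-uniform outputs would produce, which makes it a handy sanity check for the first epoch's printed loss.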

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DigitClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

epochs = 8
for epoch_idx in range(epochs):
    model.train()
    epoch_loss, correct_train = 0.0, 0
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        predictions = model(images)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * images.size(0)
        correct_train += (predictions.argmax(1) == targets).sum().item()

    train_acc = correct_train / len(train_loader.dataset)

    model.eval()
    correct_test = 0
    with torch.no_grad():
        for images, targets in test_loader:
            images, targets = images.to(device), targets.to(device)
            preds = model(images)
            correct_test += (preds.argmax(1) == targets).sum().item()
    test_acc = correct_test / len(test_loader.dataset)

    print(f"Epoch {epoch_idx+1:02d} | Loss: {epoch_loss/len(train_loader.dataset):.4f} | "
          f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")
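
For intuition about the optimizer settings used above, the following plain-Python sketch performs one Adam update on a single scalar parameter. It follows the update rule documented for torch.optim.Adam with its default betas=(0.9, 0.999) and eps=1e-8, plus the lr=0.001 and weight_decay=1e-4 values from this article. In torch.optim.Adam the weight decay is added to the gradient as an L2 term before the moment estimates (torch.optim.AdamW instead decays the weights directly). The name adam_step is hypothetical, and AMSGrad and other options are omitted.

```python
import math

def adam_step(param, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
              weight_decay=1e-4):
    """One Adam update for a scalar parameter; the step count t starts at 1."""
    beta1, beta2 = betas
    grad = grad + weight_decay * param         # coupled L2 penalty, as in torch.optim.Adam
    m = beta1 * m + (1 - beta1) * grad         # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad  # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)               # bias corrections for the zero init
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# First step from zero moment estimates, with a small and a large gradient.
p1, m1, v1 = adam_step(1.0, 0.5, 0.0, 0.0, t=1)
p2, _, _ = adam_step(1.0, 100.0, 0.0, 0.0, t=1)
print(round(p1, 6), round(p2, 6))  # 0.999 0.999
```

Because of the bias correction, the first step moves the parameter by roughly lr regardless of the gradient's scale. This is one reason Adam is forgiving about the choice of learning rate in small experiments like this one.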

Visualizing Predictions

After training, we can compare the model's predictions (P) with the true labels (T) for the first 16 images of the test set.

import matplotlib.pyplot as plt

sample_images, sample_labels = next(iter(test_loader))
sample_images, sample_labels = sample_images[:16].to(device), sample_labels[:16]

model.eval()
with torch.no_grad():
    outputs = model(sample_images)
    _, pred_classes = outputs.max(1)

inverse_normalize = lambda img: (img * 0.5) + 0.5
plt.figure(figsize=(10, 5))
for idx in range(16):
    plt.subplot(2, 8, idx+1)
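    # Undo Normalize([0.5], [0.5]) so pixels are displayed in their original [0, 1] range
    plt.imshow(inverse_normalize(sample_images[idx].cpu().squeeze()), cmap='gray', vmin=0, vmax=1)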
    plt.axis('off')
    plt.title(f"P:{pred_classes[idx].item()} T:{sample_labels[idx].item()}", fontsize=9)
plt.tight_layout()
plt.show()
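
Beyond eyeballing the grid, it helps to list exactly which images were misclassified. The sketch below uses a hypothetical helper that works on plain Python label lists; the toy lists in the example are illustrative values, not results from the model.

```python
def summarize_predictions(predicted: list[int], actual: list[int]) -> tuple[float, list[int]]:
    """Return (accuracy, indices of wrong predictions) for two equal-length label lists."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    if not actual:
        raise ValueError("label lists must not be empty")
    wrong = [i for i, (p, t) in enumerate(zip(predicted, actual)) if p != t]
    accuracy = 1 - len(wrong) / len(actual)
    return accuracy, wrong

# Toy example: the fourth prediction is wrong.
print(summarize_predictions([7, 2, 1, 0], [7, 2, 1, 6]))  # (0.75, [3])
```

With the tensors from the visualization code, the call would be summarize_predictions(pred_classes.cpu().tolist(), sample_labels.tolist()), and the returned indices point at the subplots worth a closer look.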

Saving the Trained Model

torch.save(model.state_dict(), 'mnist_digit_classifier.pt')
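
Because torch.save here stores only the state_dict (the learned tensors), reusing the weights requires rebuilding the architecture first and then calling load_state_dict. The sketch below shows that round trip on a small stand-in nn.Sequential model written to a temporary directory; make_model and checkpoint.pt are hypothetical names, and PyTorch is assumed to be installed. For this article you would construct DigitClassifier() instead and load 'mnist_digit_classifier.pt'.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in architecture; the same pattern applies to DigitClassifier.
def make_model() -> nn.Module:
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

torch.manual_seed(0)
trained = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pt")
    torch.save(trained.state_dict(), path)         # saves tensors only, not the class

    restored = make_model()                        # rebuild the same architecture first
    state = torch.load(path, map_location="cpu")   # lets a GPU-trained checkpoint load on CPU
    restored.load_state_dict(state)                # raises if keys or shapes do not match
    restored.eval()                                # inference behavior for Dropout/BatchNorm

same_weights = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(same_weights)  # True
```

Saving the state_dict rather than the whole model object keeps the checkpoint independent of the exact Python class location, at the cost of needing the model definition available when loading.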
