Building a Convolutional Neural Network for MNIST Classification with PyTorch

Preparing the MNIST Dataset

The MNIST dataset consists of 28×28 grayscale images of handwritten digits, split into 60,000 training samples and 10,000 test samples. We use torchvision to download the data, convert each image to a tensor, and normalize pixel values to the range [-1, 1].

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

composed_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_dataset = datasets.MNIST(root='./mnist_data', train=True, download=True, transform=composed_transforms)
test_dataset  = datasets.MNIST(root='./mnist_data', train=False, download=True, transform=composed_transforms)

batch = 128
train_loader = DataLoader(train_dataset, batch_size=batch, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=batch, shuffle=False)
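Before moving on, it is worth checking that the loaders yield what we expect. A minimal sanity-check sketch, using nothing beyond standard PyTorch calls:

# Inspect one batch: shapes and the normalized value range
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([128, 1, 28, 28])
print(labels.shape)   # torch.Size([128])
print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0 after Normalize([0.5], [0.5])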

Defining the Network Architecture

The model stacks two convolutional blocks followed by a linear classifier. Each convolutional block includes convolution, batch normalization, ReLU activation, and max pooling. Because the padded 5×5 convolutions preserve spatial size and each pooling stage halves it, the 28×28 input ends up as 64 feature maps of size 7×7, which is why the first linear layer expects 64 * 7 * 7 inputs.

import torch.nn as nn

class DigitClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        feats = self.features(x)
        flattened = feats.view(feats.size(0), -1)
        out = self.classifier(flattened)
        return out
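A quick way to confirm the 64 × 7 × 7 arithmetic is to push a dummy batch through the network before training; a short sketch:

# Verify intermediate and final output shapes with random input
dummy = torch.randn(4, 1, 28, 28)
net = DigitClassifier()
print(net.features(dummy).shape)  # torch.Size([4, 64, 7, 7])
print(net(dummy).shape)           # torch.Size([4, 10])
print(sum(p.numel() for p in net.parameters()))  # total trainable parameter count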

Training and Evaluation Loop

We train with cross‑entropy loss and the Adam optimizer. After every epoch, the model is evaluated on the test set to measure generalization performance.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DigitClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

epochs = 8
for epoch_idx in range(epochs):
    model.train()
    epoch_loss, correct_train = 0.0, 0
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        predictions = model(images)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * images.size(0)
        correct_train += (predictions.argmax(1) == targets).sum().item()

    train_acc = correct_train / len(train_loader.dataset)

    model.eval()
    correct_test = 0
    with torch.no_grad():
        for images, targets in test_loader:
            images, targets = images.to(device), targets.to(device)
            preds = model(images)
            correct_test += (preds.argmax(1) == targets).sum().item()
    test_acc = correct_test / len(test_loader.dataset)

    print(f"Epoch {epoch_idx+1:02d} | Loss: {epoch_loss/len(train_loader.dataset):.4f} | "
          f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")

Visualizing Predictions

After training, we can inspect how the model performs on a small batch of test images.

import matplotlib.pyplot as plt

sample_images, sample_labels = next(iter(test_loader))
sample_images, sample_labels = sample_images[:16].to(device), sample_labels[:16]

model.eval()
with torch.no_grad():
    outputs = model(sample_images)
    _, pred_classes = outputs.max(1)

inverse_normalize = lambda img: (img * 0.5) + 0.5  # undo Normalize([0.5], [0.5]) for display
plt.figure(figsize=(10, 5))
for idx in range(16):
    plt.subplot(2, 8, idx+1)
    plt.imshow(inverse_normalize(sample_images[idx].cpu()).squeeze(), cmap='gray')
    plt.axis('off')
    plt.title(f"P:{pred_classes[idx].item()} T:{sample_labels[idx].item()}", fontsize=9)
plt.tight_layout()
plt.show()
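As an optional extra, you can list which of the 16 sampled images, if any, were misclassified:

# Report mismatches between predictions and labels in the sampled batch
mismatches = (pred_classes.cpu() != sample_labels).nonzero(as_tuple=True)[0]
for idx in mismatches.tolist():
    print(f"Image {idx}: predicted {pred_classes[idx].item()}, actual {sample_labels[idx].item()}")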

Saving the Trained Model

Saving only the model's state dict keeps the checkpoint small and is the recommended way to persist PyTorch models.

torch.save(model.state_dict(), 'mnist_digit_classifier.pt')
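To reuse the trained weights later, instantiate the same architecture and load the saved state dict; a minimal sketch assuming the file path above:

# Rebuild the architecture and restore the trained weights
restored = DigitClassifier()
restored.load_state_dict(torch.load('mnist_digit_classifier.pt', map_location='cpu'))
restored.eval()  # switch to inference mode before making predictions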
