Building a Convolutional Neural Network for MNIST Classification with PyTorch
Preparing the MNIST Dataset
The MNIST dataset consists of 28×28 grayscale images of handwritten digits, split into 60,000 training samples and 10,000 test samples. We use torchvision to download and transform the data.
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ToTensor maps pixels into [0, 1]; Normalize([0.5], [0.5]) then shifts and
# scales the single grayscale channel into [-1, 1].
transform_pipeline = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Download (if needed) and wrap the official MNIST train/test splits.
train_dataset = datasets.MNIST(
    root='./mnist_data', train=True, download=True, transform=transform_pipeline
)
test_dataset = datasets.MNIST(
    root='./mnist_data', train=False, download=True, transform=transform_pipeline
)

BATCH_SIZE = 128
# Shuffle only the training stream; evaluation order is irrelevant but kept fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
Defining the Network Architecture
The model stacks two convolutional blocks followed by a linear classifier. Each convolutional block includes convolution, batch normalization, ReLU activation, and max pooling.
import torch.nn as nn
class DigitClassifier(nn.Module):
    """CNN for 28x28 single-channel digit images.

    Two conv blocks (conv -> batch norm -> ReLU -> 2x2 max pool) shrink the
    spatial map 28 -> 14 -> 7 while widening channels 1 -> 32 -> 64, and a
    dropout-regularized two-layer head produces unnormalized class logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # padding=2 with a 5x5 kernel keeps spatial size; each pool halves it.
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Head input is the flattened 64-channel 7x7 feature map.
        self.classifier = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return (N, num_classes) logits for a (N, 1, 28, 28) batch."""
        extracted = self.features(x)
        return self.classifier(extracted.view(extracted.size(0), -1))
Training and Evaluation Loop
We train with cross‑entropy loss and the Adam optimizer. After every epoch, the model is evaluated on the test set to measure generalization performance.
# Pick the fastest available device, then build model, loss, and optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DigitClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

epochs = 8
for epoch_idx in range(epochs):
    # --- training pass: dropout and batch-norm in train mode ---
    model.train()
    running_loss = 0.0
    train_hits = 0
    for batch_images, batch_targets in train_loader:
        batch_images = batch_images.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the per-batch mean loss by batch size so the epoch
        # average below is exact even for a ragged final batch.
        running_loss += batch_loss.item() * batch_images.size(0)
        train_hits += (logits.argmax(1) == batch_targets).sum().item()
    train_acc = train_hits / len(train_loader.dataset)

    # --- evaluation pass: no gradients, eval-mode statistics ---
    model.eval()
    test_hits = 0
    with torch.no_grad():
        for batch_images, batch_targets in test_loader:
            batch_images = batch_images.to(device)
            batch_targets = batch_targets.to(device)
            test_hits += (model(batch_images).argmax(1) == batch_targets).sum().item()
    test_acc = test_hits / len(test_loader.dataset)

    epoch_loss = running_loss
    print(f"Epoch {epoch_idx+1:02d} | Loss: {epoch_loss/len(train_loader.dataset):.4f} | "
          f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")
Visualizing Predictions
After training, we can inspect how the model performs on a small batch of test images.
import matplotlib.pyplot as plt
import numpy as np

# Take one test batch and keep the first 16 samples for the grid below.
sample_images, sample_labels = next(iter(test_loader))
sample_images, sample_labels = sample_images[:16].to(device), sample_labels[:16]

# Predict in eval mode with gradients disabled.
model.eval()
with torch.no_grad():
    outputs = model(sample_images)
    _, pred_classes = outputs.max(1)

# Invert Normalize([0.5], [0.5]): maps pixel values from [-1, 1] back to [0, 1].
inverse_normalize = lambda img: (img * 0.5) + 0.5

plt.figure(figsize=(10, 5))
for idx in range(16):
    plt.subplot(2, 8, idx + 1)
    # Fix: inverse_normalize was previously defined but never applied, so the
    # images were displayed in the normalized [-1, 1] range rather than being
    # mapped back to the original [0, 1] pixel range as intended.
    plt.imshow(inverse_normalize(sample_images[idx].cpu()).squeeze(), cmap='gray')
    plt.axis('off')
    # Title shows predicted (P) vs. true (T) digit for each thumbnail.
    plt.title(f"P:{pred_classes[idx].item()} T:{sample_labels[idx].item()}", fontsize=9)
plt.tight_layout()
plt.show()
Saving the Trained Model
torch.save(model.state_dict(), 'mnist_digit_classifier.pt')