AlexNet Convolutional Neural Network for FashionMNIST Classification
Project File Structure
This project uses three core Python files to build and evaluate a FashionMNIST image classifier using a customized AlexNet architecture:
- model.py: defines the adapted AlexNet model for grayscale 28x28 FashionMNIST images
- model_train.py: manages data loading, training, validation, and training metric visualization
- model_test.py: handles model evaluation, confusion matrix generation, and sample prediction display
Custom AlexNet Model (model.py)
import torch
from torch import nn
from torchsummary import summary
import torch.nn.functional as F
class FashionAlexNet(nn.Module):
    """AlexNet adapted for single-channel FashionMNIST images resized to 227x227.

    Mirrors the classic AlexNet layout (5 conv blocks + 3 fully connected
    layers) except that the first convolution accepts 1 input channel.

    Args:
        num_classes: Number of output classes (default 10 for FashionMNIST).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.relu = nn.ReLU()
        # First convolutional block: 227 -> 55 (conv, stride 4) -> 27 (pool)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=96, kernel_size=11, stride=4)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        # Second convolutional block: 27 -> 27 (conv, pad 2) -> 13 (pool)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        # Third, fourth and fifth convolutional blocks: 13 -> 13 -> 13 -> 6 (pool)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
        # Flatten layer for fully connected layers
        self.flatten = nn.Flatten()
        # BUG FIX: the previous functional F.dropout(x, p=0.5) calls did not pass
        # training=self.training, so dropout remained active during model.eval(),
        # making validation/test outputs stochastic. A registered nn.Dropout
        # module respects train/eval mode. nn.Dropout has no parameters, so
        # previously saved state_dicts still load unchanged.
        self.dropout = nn.Dropout(p=0.5)
        # Fully connected layers with dropout regularization
        self.fc1 = nn.Linear(6 * 6 * 256, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Return class logits for a batch shaped (N, 1, 227, 227)."""
        x = self.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.relu(self.conv5(x))
        x = self.pool3(x)
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x
if __name__ == "__main__":
    # Quick sanity check: build the model and print its layer summary.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = FashionAlexNet().to(run_device)
    print(summary(net, input_size=(1, 227, 227)))
Training and Validation Pipeline (model_train.py)
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch.utils.data as data
import torch
from torch import nn
import time
import copy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from model import FashionAlexNet
from tqdm import tqdm
def get_fashion_mnist_loaders():
    """Build training and validation DataLoaders for FashionMNIST.

    Downloads the dataset to ./data if needed, resizes every image to
    227x227 (the input size the model expects), and splits the training
    set 80/20 into train and validation subsets.

    Returns:
        Tuple (train_loader, val_loader), both with batch size 16.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
    ])
    dataset = FashionMNIST(
        root="./data",
        train=True,
        transform=preprocessing,
        download=True,
    )
    n_train = int(0.8 * len(dataset))
    train_subset, val_subset = data.random_split(
        dataset, [n_train, len(dataset) - n_train]
    )
    train_loader = data.DataLoader(
        dataset=train_subset, batch_size=16, shuffle=True, num_workers=2
    )
    val_loader = data.DataLoader(
        dataset=val_subset, batch_size=16, shuffle=False, num_workers=2
    )
    return train_loader, val_loader
def train_evaluate_model(model, train_loader, val_loader, num_epochs=10):
    """Train *model* with Adam + cross-entropy and validate after every epoch.

    Tracks per-epoch loss and accuracy for both phases. The weights with the
    best validation accuracy so far are checkpointed to
    ./weight/best_model.pth; the final epoch's weights are saved to
    ./weight/<num_epochs>_epoch_model.pth.

    Args:
        model: nn.Module producing class logits.
        train_loader: DataLoader yielding (inputs, labels) training batches.
        val_loader: DataLoader yielding (inputs, labels) validation batches.
        num_epochs: Number of training epochs (default 10).

    Returns:
        pandas.DataFrame with columns epoch, train_loss, val_loss,
        train_accuracy, val_accuracy.
    """
    import os  # local import so the visible file stays self-contained

    # FIX: torch.save does not create missing directories; without this the
    # first checkpoint write crashes on a fresh clone.
    os.makedirs("./weight", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on device: {device}")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    model = model.to(device)

    best_weights = copy.deepcopy(model.state_dict())
    best_accuracy = 0.0
    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []
    start_time = time.time()

    for epoch in range(1, num_epochs + 1):
        epoch_start = time.time()
        print("-" * 20)
        print(f"Epoch {epoch}/{num_epochs}")

        # ---- training phase ----
        model.train()
        train_running_loss = 0.0
        train_running_correct = 0
        train_total = 0
        train_progress = tqdm(train_loader, desc=f"Training Epoch {epoch}", unit="batch")
        for inputs, labels in train_progress:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight loss by batch size so the epoch average is per-sample.
            train_running_loss += loss.item() * inputs.size(0)
            train_running_correct += torch.sum(preds == labels.data)
            train_total += inputs.size(0)
        epoch_train_loss = train_running_loss / train_total
        epoch_train_acc = train_running_correct.double().item() / train_total
        train_loss_history.append(epoch_train_loss)
        train_acc_history.append(epoch_train_acc)

        # ---- validation phase (no gradients) ----
        model.eval()
        val_running_loss = 0.0
        val_running_correct = 0
        val_total = 0
        val_progress = tqdm(val_loader, desc=f"Validation Epoch {epoch}", unit="batch")
        with torch.no_grad():
            for inputs, labels in val_progress:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = loss_fn(outputs, labels)
                val_running_loss += loss.item() * inputs.size(0)
                val_running_correct += torch.sum(preds == labels.data)
                val_total += inputs.size(0)
        epoch_val_loss = val_running_loss / val_total
        epoch_val_acc = val_running_correct.double().item() / val_total
        val_loss_history.append(epoch_val_loss)
        val_acc_history.append(epoch_val_acc)

        print(f"Training Loss: {epoch_train_loss:.4f} | Training Accuracy: {epoch_train_acc:.4f}")
        print(f"Validation Loss: {epoch_val_loss:.4f} | Validation Accuracy: {epoch_val_acc:.4f}")
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch} completed in {epoch_time // 60:.0f}m {epoch_time % 60:.0f}s")

        # Checkpoint whenever validation accuracy improves.
        if epoch_val_acc > best_accuracy:
            best_accuracy = epoch_val_acc
            best_weights = copy.deepcopy(model.state_dict())
            torch.save(best_weights, "./weight/best_model.pth")
        # Always save the final epoch's weights as well.
        if epoch == num_epochs:
            torch.save(model.state_dict(), f"./weight/{num_epochs}_epoch_model.pth")

    training_history = pd.DataFrame({
        "epoch": range(1, num_epochs + 1),
        "train_loss": train_loss_history,
        "val_loss": val_loss_history,
        "train_accuracy": train_acc_history,
        "val_accuracy": val_acc_history,
    })
    total_time = time.time() - start_time
    print(f"Total training time: {total_time //60:.0f}m {total_time %60:.0f}s")
    return training_history
def plot_training_metrics(training_history):
    """Plot and save the training/validation loss and accuracy curves.

    Writes two PNGs under ./result_picture (training_loss.png and
    training_accuracy.png), then displays the figure.

    Args:
        training_history: DataFrame returned by train_evaluate_model, with
            columns epoch, train_loss, val_loss, train_accuracy, val_accuracy.
    """
    import os  # local import so the visible file stays self-contained

    # FIX: plt.savefig does not create missing directories; without this the
    # first save crashes on a fresh clone.
    os.makedirs("./result_picture", exist_ok=True)

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(training_history["epoch"], training_history["train_loss"], "ro-", label="Training Loss")
    plt.plot(training_history["epoch"], training_history["val_loss"], "bs-", label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    # NOTE(review): this saves the full figure while the right-hand subplot is
    # still empty — confirm whether a loss-only image was intended here.
    plt.savefig("./result_picture/training_loss.png", bbox_inches="tight")
    plt.subplot(1, 2, 2)
    plt.plot(training_history["epoch"], training_history["train_accuracy"], "ro-", label="Training Accuracy")
    plt.plot(training_history["epoch"], training_history["val_accuracy"], "bs-", label="Validation Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.savefig("./result_picture/training_accuracy.png", bbox_inches="tight")
    plt.show()
if __name__ == "__main__":
    # Entry point: train a fresh model for 10 epochs and visualize metrics.
    network = FashionAlexNet()
    loaders = get_fashion_mnist_loaders()
    history = train_evaluate_model(network, loaders[0], loaders[1], num_epochs=10)
    plot_training_metrics(history)
Model Evaluation and Testing (model_test.py)
import torch
import torch.utils.data as data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import FashionAlexNet
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
def get_test_loader():
    """Build the FashionMNIST test-set DataLoader.

    Downloads the test split to ./data if needed and resizes images to the
    227x227 input size the model expects.

    Returns:
        DataLoader over the 10k test images with batch size 1.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
    ])
    test_split = FashionMNIST(
        root="./data",
        train=False,
        transform=preprocessing,
        download=True,
    )
    return data.DataLoader(
        dataset=test_split,
        batch_size=1,
        shuffle=True,
        num_workers=0,
    )
def evaluate_test_model(model, test_loader):
    """Evaluate *model* on the test set and produce accuracy artifacts.

    Computes overall test accuracy, saves a confusion-matrix heatmap to
    ./result_picture/confusion_matrix.png, and prints one sample prediction.

    Args:
        model: Trained nn.Module producing class logits.
        test_loader: DataLoader yielding (inputs, labels) test batches.

    Raises:
        ValueError: If no model is provided.
    """
    import os  # local import so the visible file stays self-contained

    # FIX: the original `if not model` relied on object truthiness; an
    # explicit None check states the intended guard.
    if model is None:
        raise ValueError("No model provided for evaluation")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    test_correct = 0
    test_total = 0
    all_predictions = []
    all_true_labels = []
    start_time = time.time()
    with torch.no_grad():
        test_progress = tqdm(test_loader, desc="Running Tests", unit="image")
        for test_inputs, test_labels in test_progress:
            test_inputs = test_inputs.to(device)
            test_labels = test_labels.to(device)
            outputs = model(test_inputs)
            _, preds = torch.max(outputs, 1)
            # Collect flat prediction/label lists for the confusion matrix.
            all_predictions.extend(preds.tolist())
            all_true_labels.extend(test_labels.tolist())
            test_correct += torch.sum(preds == test_labels.data)
            test_total += test_inputs.size(0)

    test_accuracy = test_correct.double().item() / test_total
    print(f"Total test samples: {test_total}")
    print(f"Correct predictions: {test_correct}")
    print(f"Test accuracy: {test_accuracy:.4f}")
    total_test_time = time.time() - start_time
    print(f"Total test time: {total_test_time //60:.0f}m {total_test_time %60:.0f}s")

    # FashionMNIST class names in label-index order (0..9).
    class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    conf_mat = confusion_matrix(all_true_labels, all_predictions)
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_mat, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Confusion Matrix for FashionMNIST")
    # FIX: plt.savefig does not create missing directories.
    os.makedirs("./result_picture", exist_ok=True)
    plt.savefig("./result_picture/confusion_matrix.png", bbox_inches="tight")
    plt.show()

    # Show a single example prediction from the (shuffled) test loader.
    print("\nSample Prediction:")
    with torch.no_grad():
        for sample_input, sample_label in test_loader:
            sample_input = sample_input.to(device)
            sample_label = sample_label.to(device)
            sample_output = model(sample_input)
            _, sample_pred = torch.max(sample_output, 1)
            print(f"Predicted: {class_names[sample_pred.item()]} | True: {class_names[sample_label.item()]}")
            break
if __name__ == "__main__":
    # Entry point: restore the best checkpoint and run the test evaluation.
    network = FashionAlexNet()
    network.load_state_dict(torch.load("./weight/best_model.pth"))
    evaluate_test_model(network, get_test_loader())