Fading Coder

One Final Commit for the Last Sprint

Home > Tech > Content

AlexNet Convolutional Neural Network for FashionMNIST Classification

Tech 1

Project File Structure

This project uses three core Python files to build and evaluate a FashionMNIST image classifier using a customized AlexNet architecture:

  1. model.py: Defines the adapted AlexNet model for grayscale 28x28 FashionMNIST images
  2. model_train.py: Manages data loading, training, validation, and training metric visualization
  3. model_test.py: Handles model evaluation, confusion matrix generation, and sample prediction display

Custom AlexNet Model (model.py)

import torch
from torch import nn
from torchsummary import summary
import torch.nn.functional as F

class FashionAlexNet(nn.Module):
    """AlexNet adapted for FashionMNIST: 1 input channel, `num_classes` logits.

    Expects inputs resized to 227x227 (the data pipeline in this project
    resizes the native 28x28 FashionMNIST images accordingly).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.relu = nn.ReLU()
        # First convolutional block: 227x227 -> 55x55 (conv) -> 27x27 (pool)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=96, kernel_size=11, stride=4)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        # Second convolutional block: 27x27 -> 13x13 after pooling
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        # Third, fourth and fifth convolutional blocks (padding=1 keeps 13x13)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2)  # 13x13 -> 6x6
        # Flatten layer for fully connected layers
        self.flatten = nn.Flatten()
        # Fully connected layers with dropout regularization
        self.fc1 = nn.Linear(6 * 6 * 256, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        x = self.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.relu(self.conv5(x))
        x = self.pool3(x)

        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        # BUG FIX: F.dropout defaults to training=True, so the original code
        # applied dropout even in eval() mode, making inference stochastic.
        # Tie dropout to the module's train/eval state instead.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.relu(self.fc2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc3(x)

        return x

if __name__ == "__main__":
    # Quick sanity check: build the model and print a layer-by-layer summary.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = FashionAlexNet().to(run_device)
    print(summary(net, input_size=(1, 227, 227)))

Training and Validation Pipeline (model_train.py)

from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch.utils.data as data
import torch
from torch import nn
import time
import copy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from model import FashionAlexNet
from tqdm import tqdm

def get_fashion_mnist_loaders():
    """Return (train_loader, val_loader) over FashionMNIST with an 80/20 split.

    Images are resized to 227x227 so they match the AlexNet input size.
    The dataset is downloaded to ./data on first use.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor()
    ])

    dataset = FashionMNIST(
        root="./data",
        train=True,
        transform=preprocessing,
        download=True
    )

    # Random 80/20 train/validation split of the official training set.
    n_train = int(0.8 * len(dataset))
    split_sizes = [n_train, len(dataset) - n_train]
    train_subset, val_subset = data.random_split(dataset, split_sizes)

    loader_kwargs = {"batch_size": 16, "num_workers": 2}
    train_loader = data.DataLoader(dataset=train_subset, shuffle=True, **loader_kwargs)
    val_loader = data.DataLoader(dataset=val_subset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader

def train_evaluate_model(model, train_loader, val_loader, num_epochs=10):
    """Train `model` and run validation after every epoch.

    Tracks per-epoch loss/accuracy for both splits, keeps the weights with
    the best validation accuracy, and saves both the best and the final
    weights under ./weight/.

    Args:
        model: the nn.Module to train (moved to GPU when available).
        train_loader: DataLoader yielding (inputs, labels) training batches.
        val_loader: DataLoader yielding (inputs, labels) validation batches.
        num_epochs: number of full passes over the training data.

    Returns:
        pandas.DataFrame with columns epoch, train_loss, val_loss,
        train_accuracy, val_accuracy (one row per epoch).
    """
    import os  # local import: only needed here for the checkpoint directory

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on device: {device}")

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()

    model = model.to(device)

    best_weights = copy.deepcopy(model.state_dict())
    best_accuracy = 0.0

    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []

    start_time = time.time()

    for epoch in range(1, num_epochs + 1):
        epoch_start = time.time()
        print("-" * 20)
        print(f"Epoch {epoch}/{num_epochs}")

        # ---- training phase ----
        model.train()
        train_running_loss = 0.0
        train_running_correct = 0
        train_total = 0

        train_progress = tqdm(train_loader, desc=f"Training Epoch {epoch}", unit="batch")
        for inputs, labels in train_progress:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = loss_fn(outputs, labels)

            loss.backward()
            optimizer.step()

            # loss.item() is the batch mean; re-weight by batch size so the
            # epoch average is exact even when the last batch is smaller.
            train_running_loss += loss.item() * inputs.size(0)
            train_running_correct += torch.sum(preds == labels.data)
            train_total += inputs.size(0)

        epoch_train_loss = train_running_loss / train_total
        epoch_train_acc = train_running_correct.double().item() / train_total
        train_loss_history.append(epoch_train_loss)
        train_acc_history.append(epoch_train_acc)

        # ---- validation phase ----
        model.eval()
        val_running_loss = 0.0
        val_running_correct = 0
        val_total = 0

        val_progress = tqdm(val_loader, desc=f"Validation Epoch {epoch}", unit="batch")
        with torch.no_grad():
            for inputs, labels in val_progress:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = loss_fn(outputs, labels)

                val_running_loss += loss.item() * inputs.size(0)
                val_running_correct += torch.sum(preds == labels.data)
                val_total += inputs.size(0)

        epoch_val_loss = val_running_loss / val_total
        epoch_val_acc = val_running_correct.double().item() / val_total
        val_loss_history.append(epoch_val_loss)
        val_acc_history.append(epoch_val_acc)

        print(f"Training Loss: {epoch_train_loss:.4f} | Training Accuracy: {epoch_train_acc:.4f}")
        print(f"Validation Loss: {epoch_val_loss:.4f} | Validation Accuracy: {epoch_val_acc:.4f}")
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch} completed in {epoch_time // 60:.0f}m {epoch_time % 60:.0f}s")

        # Keep a snapshot of the weights with the best validation accuracy.
        if epoch_val_acc > best_accuracy:
            best_accuracy = epoch_val_acc
            best_weights = copy.deepcopy(model.state_dict())

    # BUG FIX: create the checkpoint directory first — torch.save fails if
    # ./weight does not exist (the original assumed it was pre-created).
    os.makedirs("./weight", exist_ok=True)
    torch.save(best_weights, "./weight/best_model.pth")
    # After the loop `epoch` always equals num_epochs, so the original
    # `if epoch == num_epochs:` guard was a no-op; save unconditionally.
    torch.save(model.state_dict(), f"./weight/{num_epochs}_epoch_model.pth")

    training_history = pd.DataFrame({
        "epoch": range(1, num_epochs + 1),
        "train_loss": train_loss_history,
        "val_loss": val_loss_history,
        "train_accuracy": train_acc_history,
        "val_accuracy": val_acc_history
    })

    total_time = time.time() - start_time
    print(f"Total training time: {total_time //60:.0f}m {total_time %60:.0f}s")

    return training_history

def plot_training_metrics(training_history):
    """Plot loss and accuracy curves and save each to its own image file.

    BUG FIX: the original drew both curves as subplots of one figure but
    called savefig twice — training_loss.png was written while the accuracy
    subplot was still empty, and training_accuracy.png duplicated the loss
    subplot. Each metric now gets its own figure so both files contain
    exactly one clean plot. The output directory is created if missing.

    Args:
        training_history: DataFrame with columns epoch, train_loss,
            val_loss, train_accuracy, val_accuracy (as returned by
            train_evaluate_model).
    """
    import os  # local import: only needed for the output directory
    os.makedirs("./result_picture", exist_ok=True)

    epochs = training_history["epoch"]

    # Loss curves
    plt.figure(figsize=(6, 4))
    plt.plot(epochs, training_history["train_loss"], "ro-", label="Training Loss")
    plt.plot(epochs, training_history["val_loss"], "bs-", label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig("./result_picture/training_loss.png", bbox_inches="tight")

    # Accuracy curves
    plt.figure(figsize=(6, 4))
    plt.plot(epochs, training_history["train_accuracy"], "ro-", label="Training Accuracy")
    plt.plot(epochs, training_history["val_accuracy"], "bs-", label="Validation Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.savefig("./result_picture/training_accuracy.png", bbox_inches="tight")
    plt.show()

if __name__ == "__main__":
    # Build the network, load the data, train, then plot the metric curves.
    net = FashionAlexNet()
    loaders = get_fashion_mnist_loaders()
    history = train_evaluate_model(net, loaders[0], loaders[1], num_epochs=10)
    plot_training_metrics(history)

Model Evaluation and Testing (model_test.py)

import torch
import torch.utils.data as data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import FashionAlexNet
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

def get_test_loader():
    """Return a DataLoader over the FashionMNIST test split, one image per batch.

    Images are resized to 227x227 to match the AlexNet input size; the
    dataset is downloaded to ./data on first use.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor()
    ])

    test_set = FashionMNIST(
        root="./data",
        train=False,
        transform=preprocessing,
        download=True
    )

    return data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=True,
        num_workers=0
    )

def evaluate_test_model(model, test_loader):
    """Evaluate `model` on the FashionMNIST test set.

    Prints overall accuracy and timing, saves a confusion-matrix heatmap to
    ./result_picture/confusion_matrix.png, and prints one sample prediction.

    Args:
        model: trained nn.Module producing class logits.
        test_loader: DataLoader yielding (input, label) test batches.

    Raises:
        ValueError: if `model` is None.
    """
    import os  # local import: only needed to create the output directory

    # Explicit None check — `not model` could misfire on a module whose
    # truthiness is container-like (e.g. an empty nn.ModuleList).
    if model is None:
        raise ValueError("No model provided for evaluation")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    test_correct = 0
    test_total = 0
    all_predictions = []
    all_true_labels = []

    start_time = time.time()

    # ---- accuracy pass over the whole test set ----
    with torch.no_grad():
        test_progress = tqdm(test_loader, desc="Running Tests", unit="image")
        for test_inputs, test_labels in test_progress:
            test_inputs = test_inputs.to(device)
            test_labels = test_labels.to(device)

            outputs = model(test_inputs)
            _, preds = torch.max(outputs, 1)

            all_predictions.extend(preds.tolist())
            all_true_labels.extend(test_labels.tolist())

            test_correct += torch.sum(preds == test_labels.data)
            test_total += test_inputs.size(0)

    test_accuracy = test_correct.double().item() / test_total
    print(f"Total test samples: {test_total}")
    print(f"Correct predictions: {test_correct}")
    print(f"Test accuracy: {test_accuracy:.4f}")

    total_test_time = time.time() - start_time
    print(f"Total test time: {total_test_time //60:.0f}m {total_test_time %60:.0f}s")

    # ---- confusion matrix ----
    class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    conf_mat = confusion_matrix(all_true_labels, all_predictions)
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_mat, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Confusion Matrix for FashionMNIST")
    # BUG FIX: create the output directory first — savefig fails on a
    # fresh checkout if ./result_picture does not exist.
    os.makedirs("./result_picture", exist_ok=True)
    plt.savefig("./result_picture/confusion_matrix.png", bbox_inches="tight")
    plt.show()

    # ---- show a single sample prediction ----
    print("\nSample Prediction:")
    with torch.no_grad():
        for sample_input, sample_label in test_loader:
            sample_input = sample_input.to(device)
            sample_label = sample_label.to(device)
            sample_output = model(sample_input)
            _, sample_pred = torch.max(sample_output, 1)
            print(f"Predicted: {class_names[sample_pred.item()]} | True: {class_names[sample_label.item()]}")
            break

if __name__ == "__main__":
    # Restore the best checkpoint and run the full test-set evaluation.
    net = FashionAlexNet()
    state = torch.load("./weight/best_model.pth")
    net.load_state_dict(state)
    evaluate_test_model(net, get_test_loader())

Related Articles

Understanding Strong and Weak References in Java

Strong References Strong references are the most prevalent type of object referencing in Java. When an object has a strong reference pointing to it, the garbage collector will not reclaim its memory. F...

Comprehensive Guide to SSTI Explained with Payload Bypass Techniques

Introduction Server-Side Template Injection (SSTI) is a vulnerability in web applications where user input is improperly handled within the template engine and executed on the server. This exploit can r...

Implement Image Upload Functionality for Django Integrated TinyMCE Editor

Django’s Admin panel is highly user-friendly, and pairing it with TinyMCE, an effective rich text editor, simplifies content management significantly. Combining the two is particularly useful for bloggi...

Leave a Comment

Anonymous

◎Feel free to join the discussion and share your thoughts.