Building GoogLeNet and ResNet Architectures in PyTorch
The Inception architecture processes feature maps through parallel convolutional branches with varying receptive fields. In PyTorch, an InceptionModule can be constructed from four distinct pathways: a standalone 1×1 convolution, a 1×1 bottleneck followed by a 5×5 convolution, a 1×1 bottleneck followed by two stacked 3×3 convolutions, and an average-pooling branch projected through a 1×1 convolution. The branch outputs are concatenated along the channel dimension.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
batch = 64
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_data = datasets.MNIST(root='./mnist_data', train=True, download=True, transform=preprocess)
train_loader = DataLoader(train_data, batch_size=batch, shuffle=True)
test_data = datasets.MNIST(root='./mnist_data', train=False, download=True, transform=preprocess)
test_loader = DataLoader(test_data, batch_size=batch, shuffle=False)
class InceptionModule(nn.Module):
    def __init__(self, in_planes):
        super().__init__()
        # Branch 1: standalone 1x1 convolution.
        self.conv_1x1 = nn.Conv2d(in_planes, 12, kernel_size=1)
        # Branch 2: 1x1 bottleneck followed by a 5x5 convolution.
        self.conv_5x5 = nn.Sequential(
            nn.Conv2d(in_planes, 8, kernel_size=1),
            nn.Conv2d(8, 16, kernel_size=5, padding=2)
        )
        # Branch 3: 1x1 bottleneck followed by two stacked 3x3 convolutions.
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(in_planes, 8, kernel_size=1),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.Conv2d(16, 16, kernel_size=3, padding=1)
        )
        # Branch 4: average pooling projected through a 1x1 convolution.
        self.pool_proj = nn.Conv2d(in_planes, 16, kernel_size=1)

    def forward(self, x):
        y1 = self.conv_1x1(x)
        y2 = self.conv_5x5(x)
        y3 = self.conv_3x3(x)
        y4 = F.avg_pool2d(x, 3, stride=1, padding=1)
        y4 = self.pool_proj(y4)
        # Concatenate along the channel dimension: 12 + 16 + 16 + 16 = 60 channels.
        return torch.cat([y1, y2, y3, y4], dim=1)
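# Illustrative sanity check (not part of the original training flow): every
# branch preserves the spatial size, so the four outputs concatenate cleanly
# to 12 + 16 + 16 + 16 = 60 channels. The 12x12 input size here is hypothetical.
_probe = InceptionModule(8)(torch.randn(1, 8, 12, 12))
assert _probe.shape == (1, 60, 12, 12)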
class GoogLeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(1, 8, kernel_size=5)
        self.incept1 = InceptionModule(8)       # emits 60 channels
        self.mid = nn.Conv2d(60, 16, kernel_size=5)
        self.incept2 = InceptionModule(16)      # emits 60 channels
        self.fc = nn.Linear(960, 10)            # 60 channels * 4 * 4 spatial

    def forward(self, x):
        b = x.size(0)
        x = F.max_pool2d(F.relu(self.stem(x)), 2)   # (b, 8, 12, 12)
        x = self.incept1(x)                          # (b, 60, 12, 12)
        x = F.max_pool2d(F.relu(self.mid(x)), 2)    # (b, 16, 4, 4)
        x = self.incept2(x)                          # (b, 60, 4, 4)
        x = x.view(b, -1)                            # (b, 960)
        return self.fc(x)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = GoogLeNet().to(device)
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
def fit_epoch(epoch_idx):
    total_loss = 0.0
    for batch_idx, (imgs, lbls) in enumerate(train_loader):
        imgs, lbls = imgs.to(device), lbls.to(device)
        opt.zero_grad()
        preds = net(imgs)
        loss = loss_fn(preds, lbls)
        loss.backward()
        opt.step()
        total_loss += loss.item()
        if batch_idx % 200 == 199:  # report the running average every 200 batches
            print(f'Epoch {epoch_idx+1} | Batch {batch_idx+1} | Avg Loss: {total_loss/200:.3f}')
            total_loss = 0.0
def assess():
    ok = 0
    n = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for imgs, lbls in test_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            outs = net(imgs)
            _, guess = torch.max(outs, 1)
            n += lbls.size(0)
            ok += (guess == lbls).sum().item()
    print(f'Test Accuracy: {100 * ok / n:.2f}%')
for e in range(80):
    fit_epoch(e)
    if e % 10 == 0:
        assess()
assess()  # final evaluation once training completes
The GoogLeNet class begins with a stem convolution and max-pooling, passes the result through the first inception module, applies another convolution and downsampling stage, then feeds the tensor into a second inception module before flattening and classification.
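To see where the 960 input features of the final linear layer come from, a dry run with a dummy MNIST-shaped batch can trace the intermediate shapes. This is a minimal sketch; the tensor sizes and variable names are illustrative, not part of the training script.

m = GoogLeNet()
x = torch.randn(2, 1, 28, 28)               # dummy MNIST-shaped batch
x = F.max_pool2d(F.relu(m.stem(x)), 2)      # (2, 8, 12, 12)
x = m.incept1(x)                            # (2, 60, 12, 12)
x = F.max_pool2d(F.relu(m.mid(x)), 2)       # (2, 16, 4, 4)
x = m.incept2(x)                            # (2, 60, 4, 4)
print(x.flatten(1).shape)                   # torch.Size([2, 960]) -> fc expects 960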
Residual Networks reformulate stacked layers as residual functions with respect to their inputs: instead of learning a mapping H(x) directly, a block learns the residual F(x) = H(x) - x and outputs x + F(x). A SkipBlock encapsulates a small convolutional stack whose output is added back to the block's input, which helps mitigate vanishing gradients as depth increases.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
batch = 64
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_data = datasets.MNIST(root='./mnist_data', train=True, download=True, transform=preprocess)
train_loader = DataLoader(train_data, batch_size=batch, shuffle=True)
test_data = datasets.MNIST(root='./mnist_data', train=False, download=True, transform=preprocess)
test_loader = DataLoader(test_data, batch_size=batch, shuffle=False)
class SkipBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Two 3x3 convolutions; padding=1 preserves the spatial size so the
        # residual can be added back to the input.
        self.res_path = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        )

    def forward(self, x):
        # Identity shortcut: add the residual to the input, then apply ReLU.
        return F.relu(x + self.res_path(x))
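# Illustrative check (not part of the original training flow): padding=1 keeps
# H and W fixed, so a SkipBlock maps (N, dim, H, W) -> (N, dim, H, W) and the
# addition in forward() is always shape-compatible. Sizes below are hypothetical.
_blk = SkipBlock(12)
assert _blk(torch.randn(2, 12, 12, 12)).shape == (2, 12, 12, 12)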
class ResNetMini(nn.Module):
    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(1, 12, kernel_size=5)
        self.skip1 = SkipBlock(12)
        self.c2 = nn.Conv2d(12, 24, kernel_size=5)
        self.skip2 = SkipBlock(24)
        self.head = nn.Linear(384, 10)   # 24 channels * 4 * 4 spatial

    def forward(self, x):
        b = x.size(0)
        x = F.max_pool2d(F.relu(self.c1(x)), 2)   # (b, 12, 12, 12)
        x = self.skip1(x)                          # shape preserved
        x = F.max_pool2d(F.relu(self.c2(x)), 2)   # (b, 24, 4, 4)
        x = self.skip2(x)                          # shape preserved
        x = x.view(b, -1)                          # (b, 384)
        return self.head(x)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNetMini().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
def fit(epoch_id):
    cumulative = 0.0
    for step, (samples, targets) in enumerate(train_loader):
        samples, targets = samples.to(device), targets.to(device)
        optimizer.zero_grad()
        logits = model(samples)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()
        cumulative += batch_loss.item()
        if step % 200 == 199:  # report the running average every 200 batches
            print(f'[{epoch_id+1}, {step+1}] loss: {cumulative/200:.3f}')
            cumulative = 0.0
def evaluate():
    hit = 0
    count = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for samples, targets in test_loader:
            samples, targets = samples.to(device), targets.to(device)
            logits = model(samples)
            _, forecast = torch.max(logits, 1)  # .data is unnecessary under no_grad
            count += targets.size(0)
            hit += (forecast == targets).sum().item()
    print(f'Accuracy: {100 * hit / count:.2f}%')
if __name__ == '__main__':
    for epoch in range(100):
        fit(epoch)
        if epoch % 10 == 0:
            evaluate()
    evaluate()  # final evaluation once training completes
The ResNetMini network places skip blocks after pooled convolutional stages. During the forward pass, each block computes a residual correction that is fused with the incoming activation map via element-wise addition followed by a nonlinearity.
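The same kind of dry run confirms the 384 input features of the classification head. As before, this is a minimal sketch under the MNIST input assumption; the names and sizes are illustrative.

m = ResNetMini()
x = torch.randn(2, 1, 28, 28)              # dummy MNIST-shaped batch
x = F.max_pool2d(F.relu(m.c1(x)), 2)       # (2, 12, 12, 12)
x = m.skip1(x)                             # (2, 12, 12, 12), shape preserved
x = F.max_pool2d(F.relu(m.c2(x)), 2)       # (2, 24, 4, 4)
x = m.skip2(x)                             # (2, 24, 4, 4), shape preserved
print(x.flatten(1).shape)                  # torch.Size([2, 384]) -> head expects 384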