Visualizing PyTorch Training with TensorBoard
Environment Setup
A working PyTorch installation is required. Verify the environment with:
import torch
print(torch.__version__)
print(torch.cuda.is_available())
If the output shows a version string, the installation works. torch.cuda.is_available() prints True only when a CUDA-capable GPU and compatible drivers are present; CPU-only setups will print False but still work with TensorBoard.
Install TensorBoard using pip:
pip install tensorboard
Once installed, the tensorboard command becomes available in the terminal. PyTorch integrates directly with TensorBoard, allowing logging of scalars, images, histograms, graphs, and embeddings for later inspection in the browser-based UI.
Basic Workflow Example
A minimal example demonstrates the logging pipeline:
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
writer = SummaryWriter()
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
model = torchvision.models.resnet50(weights=None)  # pretrained=False is deprecated
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
sample_images, sample_labels = next(iter(trainloader))
image_grid = torchvision.utils.make_grid(sample_images)
writer.add_image('images', image_grid, 0)
writer.add_graph(model, sample_images)
writer.close()
The SummaryWriter instance stores event files in a timestamped subdirectory of ./runs/ by default. After executing the script, launch TensorBoard by pointing it at the log directory:
tensorboard --logdir=runs
Open http://localhost:6006/ in a browser. The dashboard will display the IMAGES and GRAPHS tabs populated by the logged data. The computation graph can be downloaded as a PNG from the GRAPHS tab for documentation purposes.
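When training is run more than once, it helps to give each run its own subdirectory under runs/ so TensorBoard can overlay them for comparison. A minimal sketch, assuming a timestamp-based naming convention (the scheme itself is arbitrary):

```python
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

# Each run gets its own subdirectory under ./runs/; TensorBoard treats
# every subdirectory under --logdir as a separate, comparable run.
run_name = datetime.now().strftime('%Y%m%d-%H%M%S')
writer = SummaryWriter(log_dir=f'runs/{run_name}')

# ... log scalars, images, graphs here ...

writer.close()
```

Launching tensorboard --logdir=runs then lists each run separately in the UI, with checkboxes to toggle which runs are plotted.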
Logging Training Metrics
Scalars such as loss and accuracy can be tracked across epochs and displayed in real time. The following loop illustrates the pattern:
writer = SummaryWriter()
train_loss_history = []
train_acc_history = []
test_loss_history = []
test_acc_history = []
# Assumes model, criterion, optimizer, train_loader, test_loader,
# train_dataset, and test_dataset have been defined beforehand.
for epoch_idx in range(num_epochs):
    # ---- Training Phase ----
    model.train()
    cumulative_loss, cumulative_correct = 0.0, 0
    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.cuda()
        batch_targets = batch_targets.cuda()
        predictions = model(batch_inputs)
        batch_loss = criterion(predictions, batch_targets)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        cumulative_loss += batch_loss.item() * batch_targets.size(0)
        _, predicted_class = torch.max(predictions, dim=1)
        cumulative_correct += (predicted_class == batch_targets).sum().item()
    avg_train_loss = cumulative_loss / len(train_dataset)
    avg_train_acc = cumulative_correct / len(train_dataset)
    train_loss_history.append(avg_train_loss)
    train_acc_history.append(avg_train_acc)

    # ---- Evaluation Phase ----
    model.eval()
    eval_loss_sum, eval_correct_sum = 0.0, 0
    with torch.no_grad():
        for val_inputs, val_targets in test_loader:
            val_inputs = val_inputs.cuda()
            val_targets = val_targets.cuda()
            val_outputs = model(val_inputs)
            eval_loss_sum += criterion(val_outputs, val_targets).item() * val_targets.size(0)
            _, val_preds = torch.max(val_outputs, dim=1)
            eval_correct_sum += (val_preds == val_targets).sum().item()
    avg_test_loss = eval_loss_sum / len(test_dataset)
    avg_test_acc = eval_correct_sum / len(test_dataset)
    test_loss_history.append(avg_test_loss)
    test_acc_history.append(avg_test_acc)

    # ---- Log Scalars ----
    writer.add_scalar('loss/train', avg_train_loss, epoch_idx)
    writer.add_scalar('loss/test', avg_test_loss, epoch_idx)
    writer.add_scalar('accuracy/train', avg_train_acc, epoch_idx)
    writer.add_scalar('accuracy/test', avg_test_acc, epoch_idx)

    # ---- Log Sample Images (last validation batch of the epoch) ----
    writer.add_images('validation_samples', val_inputs, epoch_idx, dataformats='NCHW')
    print(f'Epoch {epoch_idx+1} | Train Loss: {avg_train_loss:.4f} Acc: {avg_train_acc:.4f} | Test Loss: {avg_test_loss:.4f} Acc: {avg_test_acc:.4f}')
Each call to add_scalar stores a named time series. The forward slash in the tag (e.g., loss/train) creates hierarchical groupings in the TensorBoard interface, making it easier to compare related metrics side by side. The add_images method visualizes actual input batches, helping verify that inputs look as expected, for example that normalization and channel ordering are correct.
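Related series can also be overlaid on shared axes with add_scalars, which takes a main tag and a dictionary mapping sub-tags to values. A minimal sketch with placeholder loss values (the numbers are illustrative only):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

# add_scalars plots every entry of the dict on the same chart under the
# main tag 'loss'; each sub-tag becomes its own line.
for epoch_idx, (train_loss, test_loss) in enumerate([(0.9, 1.0), (0.6, 0.7), (0.4, 0.6)]):
    writer.add_scalars('loss', {'train': train_loss, 'test': test_loss}, epoch_idx)

writer.close()
```

Note that add_scalars writes each sub-tag into its own subdirectory of the log dir, so it creates more event files than separate add_scalar calls.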
Start TensorBoard and navigate to the SCALARS tab to observe real-time updating plots of loss and accuracy. The images appear under the IMAGES tab. Use the smoothing slider in the scalar view to reduce noise and reveal underlying trends in the metrics.
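The HISTOGRAMS tab mentioned earlier is populated via add_histogram, which is useful for watching weight distributions evolve during training. A minimal sketch using a small stand-in linear model; any nn.Module's named_parameters() can be logged the same way inside the epoch loop above:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
model = torch.nn.Linear(10, 2)  # stand-in for the actual training model

# Log every parameter tensor once per epoch; TensorBoard stacks the
# per-step histograms so distribution drift is visible over time.
for epoch_idx in range(3):
    for name, param in model.named_parameters():
        writer.add_histogram(name, param, epoch_idx)

writer.close()
```

Collapsing or exploding weight distributions in this view often signal learning-rate or initialization problems before the loss curves make them obvious.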