Analyzing PyTorch GPU Memory Usage with Snapshot Tools
GPU memory snapshot tools in PyTorch enable detailed analysis of memory allocation and deallocation events during model execution. These tools help diagnose common issues such as out-of-memory (OOM) errors and provide insights into memory consumption patterns.
Core Functions
PyTorch provides internal APIs for capturing memory snapshots, available in version 2.1 and later.
- Start Recording:
torch.cuda.memory._record_memory_history(max_entries=N) begins tracking memory events, with max_entries defining the buffer capacity for events. - Save Snapshot:
torch.cuda.memory._dump_snapshot(filename) saves the current memory state to a file. - Stop Recording:
torch.cuda.memory._record_memory_history(enabled=None) halts the recording process.
Captured snapshots can be visualized using the PyTorch Memory Visualizer at https://pytorch.org/memory_viz.
Basic Usage Example
The following code demonstrates integrating memory snapshot capture into a training loop.
import torch

# Configuration
MAX_EVENTS = 100000               # capacity of the memory-event buffer
SNAPSHOT_PREFIX = "training_run"  # prefix for the per-epoch snapshot files

# Initialize model, data, optimizer, and loss function
# NOTE(review): MyModel is assumed to be defined elsewhere — this snippet is
# illustrative and does not run standalone.
network = MyModel().cuda()
input_batch = torch.randn(32, 3, 224, 224, device='cuda')
target_labels = torch.randint(0, 1000, (32,), device='cuda')
optim = torch.optim.Adam(network.parameters())
loss_func = torch.nn.CrossEntropyLoss()

# Start recording memory history — must be enabled before the allocations
# of interest occur, or they will not appear in the snapshot.
torch.cuda.memory._record_memory_history(max_entries=MAX_EVENTS)

# Training loop with periodic snapshots
for epoch in range(10):
    predictions = network(input_batch)
    loss = loss_func(predictions, target_labels)
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)
    # Save a snapshot after each epoch so memory growth can be compared
    # epoch-over-epoch in the visualizer (https://pytorch.org/memory_viz).
    snapshot_file = f"{SNAPSHOT_PREFIX}_epoch_{epoch}.pickle"
    torch.cuda.memory._dump_snapshot(snapshot_file)

# Stop recording
torch.cuda.memory._record_memory_history(enabled=None)
Complete Implementation for ResNet-50
This example provides a more robust implementation with logging and utility functions.
import logging
import socket
from datetime import datetime
import torch
from torchvision import models
# Logging setup: timestamped console output so snapshot events can be
# correlated with wall-clock time.
logging.basicConfig(
    format='%(levelname)s:%(asctime)s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Constants
TIME_FORMAT = "%b_%d_%H_%M_%S"  # timestamp used in snapshot file names
MAX_MEMORY_EVENTS = 100000      # capacity of the memory-event buffer
def start_memory_recording() -> None:
    """Turn on capture of CUDA memory allocation/deallocation history.

    Logs and does nothing when CUDA is not available.
    """
    if torch.cuda.is_available():
        logger.info("Starting memory history recording.")
        torch.cuda.memory._record_memory_history(max_entries=MAX_MEMORY_EVENTS)
    else:
        logger.info("CUDA unavailable. Recording skipped.")
def stop_memory_recording() -> None:
    """Turn off capture of CUDA memory history (no-op without CUDA)."""
    if torch.cuda.is_available():
        logger.info("Stopping memory history recording.")
        torch.cuda.memory._record_memory_history(enabled=None)
def save_memory_snapshot() -> None:
    """Export the current CUDA memory state to a timestamped pickle file.

    The file is named "<hostname>_<timestamp>.pickle" so snapshots from
    repeated runs (or different machines) do not overwrite each other.
    Export failures are logged, never raised: snapshotting is best-effort
    diagnostics and must not abort training. No-op when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Snapshot export skipped.")
        return
    host = socket.gethostname()
    timestamp = datetime.now().strftime(TIME_FORMAT)
    filename = f"{host}_{timestamp}.pickle"
    try:
        torch.cuda.memory._dump_snapshot(filename)
    except Exception as err:
        logger.error("Failed to save snapshot to %s: %s", filename, err)
    else:
        # Bug fix: the original logged the literal text "(unknown)" instead
        # of the snapshot file name, making the log useless for locating
        # the output file.
        logger.info("Snapshot saved: %s", filename)
def train_resnet(iterations=5, device='cuda:0'):
    """Run a short ResNet-50 training loop and dump a CUDA memory snapshot.

    Memory-event recording is enabled around the whole loop; a single
    snapshot is exported after the final iteration, then recording stops.

    Args:
        iterations: Number of forward/backward/step iterations to run.
        device: Device string for model and data placement.
    """
    net = models.resnet50().to(device)
    batch = torch.randn(1, 3, 224, 224, device=device)
    # Random soft targets shaped like the network output (demo data only).
    targets = torch.rand_like(net(batch))
    sgd = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    criterion = torch.nn.CrossEntropyLoss()

    start_memory_recording()
    for _ in range(iterations):
        loss = criterion(net(batch), targets)
        loss.backward()
        sgd.step()
        sgd.zero_grad(set_to_none=True)
    save_memory_snapshot()
    stop_memory_recording()


if __name__ == "__main__":
    train_resnet()
Enhanced Profiling with Torch Profiler
For more detailed analysis, including memory timelines correlated with operations, the PyTorch Profiler offers advanced capabilities.
import logging
import socket
from datetime import datetime
import torch
from torch.autograd.profiler import record_function
from torchvision import models
# Timestamped console logging so profiler output can be correlated with
# wall-clock time.
logging.basicConfig(
    format='%(levelname)s:%(asctime)s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Timestamp format used in trace/timeline output file names.
TIME_FORMAT = "%b_%d_%H_%M_%S"
def save_profiler_output(profiler):
    """Persist a profiler's chrome trace and CUDA memory timeline.

    Intended as the profiler's on_trace_ready callback. Output files are
    prefixed with "<hostname>_<timestamp>" so runs remain distinguishable.
    """
    stamp = datetime.now().strftime(TIME_FORMAT)
    prefix = "_".join([socket.gethostname(), stamp])
    profiler.export_chrome_trace(f"{prefix}.json.gz")
    profiler.export_memory_timeline(f"{prefix}.html", device="cuda:0")
def profile_resnet_training(iterations=5, device='cuda:0'):
    """Profile a short ResNet-50 training loop with the PyTorch profiler.

    Captures CPU and CUDA activity, tensor shapes, stack traces, and memory
    events; traces are written by save_profiler_output via on_trace_ready.

    Args:
        iterations: Number of profiled training iterations.
        device: Device string for model and data placement.
    """
    net = models.resnet50().to(device)
    batch = torch.randn(1, 3, 224, 224, device=device)
    # Random soft targets shaped like the network output (demo data only).
    targets = torch.rand_like(net(batch))
    sgd = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    criterion = torch.nn.CrossEntropyLoss()

    profiler_config = dict(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        on_trace_ready=save_profiler_output,
    )
    with torch.profiler.profile(**profiler_config) as prof:
        for _ in range(iterations):
            # Advance the profiler schedule once per training iteration.
            prof.step()
            with record_function("forward_pass"):
                preds = net(batch)
            with record_function("backward_pass"):
                criterion(preds, targets).backward()
            with record_function("optimizer_step"):
                sgd.step()
                sgd.zero_grad(set_to_none=True)
if __name__ == "__main__":
    # Warm-up run — presumably to absorb one-time initialization costs
    # (CUDA context creation, kernel autotuning) before the measured run.
    profile_resnet_training(iterations=1)
    # Main profiling run
    profile_resnet_training()