Fading Coder

One Final Commit for the Last Sprint

Home > Tech > Content

Analyzing PyTorch GPU Memory Usage with Snapshot Tools

Tech 2

GPU memory snapshot tools in PyTorch enable detailed analysis of memory allocation and deallocation events during model execution. These tools help diagnose common issues such as out-of-memory (OOM) errors and provide insights into memory consumption patterns.

Core Functions

PyTorch provides internal APIs for capturing memory snapshots, available in version 2.1 and later.

  • Start Recording: torch.cuda.memory._record_memory_history(max_entries=N) begins tracking memory events, with max_entries defining the buffer capacity for events.
  • Save Snapshot: torch.cuda.memory._dump_snapshot(filename) saves the current memory state to a file.
  • Stop Recording: torch.cuda.memory._record_memory_history(enabled=None) halts the recording process.

Captured snapshots can be visualized using the PyTorch Memory Visualizer at https://pytorch.org/memory_viz.

Basic Usage Example

The following code demonstrates integrating memory snapshot capture into a training loop.

import torch

# Configuration
MAX_EVENTS = 100000          # capacity of the allocator-event ring buffer
SNAPSHOT_PREFIX = "training_run"

# Model, synthetic batch, optimizer, and loss criterion
# (MyModel is assumed to be defined elsewhere in the project.)
network = MyModel().cuda()
input_batch = torch.randn(32, 3, 224, 224, device='cuda')
target_labels = torch.randint(0, 1000, (32,), device='cuda')
optim = torch.optim.Adam(network.parameters())
loss_func = torch.nn.CrossEntropyLoss()

# Begin tracking CUDA caching-allocator events.
torch.cuda.memory._record_memory_history(max_entries=MAX_EVENTS)

for epoch in range(10):
    # Standard forward/backward/step cycle.
    loss = loss_func(network(input_batch), target_labels)
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)

    # Persist the allocator history once per epoch for offline inspection.
    snapshot_file = f"{SNAPSHOT_PREFIX}_epoch_{epoch}.pickle"
    torch.cuda.memory._dump_snapshot(snapshot_file)

# Disable event tracking.
torch.cuda.memory._record_memory_history(enabled=None)

Complete Implementation for ResNet-50

This example provides a more robust implementation with logging and utility functions.

import logging
import socket
from datetime import datetime
import torch
from torchvision import models

# Logging setup: timestamped messages at INFO level on the root logger.
logging.basicConfig(
    format='%(levelname)s:%(asctime)s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Constants
TIME_FORMAT = "%b_%d_%H_%M_%S"  # strftime pattern for snapshot file names, e.g. "Jan_01_12_00_00"
MAX_MEMORY_EVENTS = 100000      # max allocator events kept in the recording buffer

def start_memory_recording() -> None:
    """Turn on recording of CUDA allocator events (no-op without CUDA)."""
    if torch.cuda.is_available():
        logger.info("Starting memory history recording.")
        torch.cuda.memory._record_memory_history(max_entries=MAX_MEMORY_EVENTS)
    else:
        logger.info("CUDA unavailable. Recording skipped.")

def stop_memory_recording() -> None:
    """Turn off recording of CUDA allocator events (no-op without CUDA)."""
    if torch.cuda.is_available():
        logger.info("Stopping memory history recording.")
        # Passing enabled=None disables history collection.
        torch.cuda.memory._record_memory_history(enabled=None)

def save_memory_snapshot() -> None:
    """Export the current CUDA memory state to a timestamped pickle file.

    The file is named "<hostname>_<timestamp>.pickle" so snapshots from
    different hosts and runs do not collide. Does nothing when CUDA is
    unavailable; export failures are logged rather than raised.
    """
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Snapshot export skipped.")
        return
    host = socket.gethostname()
    timestamp = datetime.now().strftime(TIME_FORMAT)
    filename = f"{host}_{timestamp}.pickle"
    try:
        torch.cuda.memory._dump_snapshot(filename)
        # Fix: the original logged the placeholder text "(unknown)" instead
        # of interpolating the actual snapshot file name.
        logger.info(f"Snapshot saved: {filename}")
    except Exception as err:
        logger.error(f"Failed to save snapshot: {err}")

def train_resnet(iterations=5, device='cuda:0'):
    """Run a short ResNet-50 training loop while capturing memory snapshots.

    Args:
        iterations: number of forward/backward/step cycles to execute.
        device: device string the model and synthetic data are placed on.
    """
    net = models.resnet50().to(device)
    x = torch.randn(1, 3, 224, 224, device=device)
    # NOTE(review): targets are random dense values shaped like the logits,
    # which CrossEntropyLoss treats as class probabilities — confirm this
    # matches the intended demo setup.
    y = torch.rand_like(net(x))
    opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()

    start_memory_recording()
    for _ in range(iterations):
        loss = loss_fn(net(x), y)
        loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)
    save_memory_snapshot()
    stop_memory_recording()

if __name__ == "__main__":
    # Run the demo training loop only when executed as a script.
    train_resnet()

Enhanced Profiling with Torch Profiler

For more detailed analysis, including memory timelines correlated with operations, the PyTorch Profiler offers advanced capabilities.

import logging
import socket
from datetime import datetime
import torch
from torch.autograd.profiler import record_function
from torchvision import models

# Logging setup: timestamped messages at INFO level on the root logger.
logging.basicConfig(
    format='%(levelname)s:%(asctime)s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# strftime pattern for trace/timeline file names, e.g. "Jan_01_12_00_00".
TIME_FORMAT = "%b_%d_%H_%M_%S"

def save_profiler_output(profiler):
    """on_trace_ready callback: export the chrome trace and memory timeline.

    Output files are prefixed with "<hostname>_<timestamp>" so results from
    different machines and runs do not collide.
    """
    file_prefix = "_".join([socket.gethostname(),
                            datetime.now().strftime(TIME_FORMAT)])
    profiler.export_chrome_trace(f"{file_prefix}.json.gz")
    profiler.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")

def profile_resnet_training(iterations=5, device='cuda:0'):
    """Profiles the training loop of a ResNet-50 model.

    Runs `iterations` forward/backward/step cycles under torch.profiler with
    CPU+CUDA activities, shape recording, memory profiling, and stack capture
    enabled. save_profiler_output is invoked via on_trace_ready to export the
    chrome trace and memory timeline.
    """
    model = models.resnet50().to(device)
    inputs = torch.randn(1, 3, 224, 224, device=device)
    # Targets are random dense values shaped like the model's logits;
    # CrossEntropyLoss interprets them as class probabilities.
    labels = torch.rand_like(model(inputs))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        # NOTE(review): active=6 exceeds the default iterations=5, so the
        # active window is still open when the context exits — verify that
        # on_trace_ready fires on exit in the targeted PyTorch version.
        schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,       # stack traces let the memory timeline attribute allocations
        on_trace_ready=save_profiler_output,
    ) as prof:
        for i in range(iterations):
            # prof.step() is called at the top of the loop, advancing the
            # schedule before each iteration's work is recorded.
            prof.step()
            with record_function("forward_pass"):
                pred = model(inputs)
            with record_function("backward_pass"):
                loss_fn(pred, labels).backward()
            with record_function("optimizer_step"):
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

if __name__ == "__main__":
    # Warm-up run: lets one-time initialization (e.g. CUDA context, cuDNN
    # autotuning) happen outside the measured run.
    profile_resnet_training(iterations=1)
    # Main profiling run
    profile_resnet_training()
Tags: pytorch, GPU

Related Articles

Understanding Strong and Weak References in Java

Strong References Strong references are the most prevalent type of object referencing in Java. When an object has a strong reference pointing to it, the garbage collector will not reclaim its memory. F...

Comprehensive Guide to SSTI Explained with Payload Bypass Techniques

Introduction Server-Side Template Injection (SSTI) is a vulnerability in web applications where user input is improperly handled within the template engine and executed on the server. This exploit can r...

Implement Image Upload Functionality for Django Integrated TinyMCE Editor

Django’s Admin panel is highly user-friendly, and pairing it with TinyMCE, an effective rich text editor, simplifies content management significantly. Combining the two is particularly useful for bloggi...

Leave a Comment

Anonymous

◎Feel free to join the discussion and share your thoughts.