Estimating GPU Memory Consumption and Parameter Counts in PyTorch Models
When deploying large language models such as LLaMA-7B, estimating GPU memory (VRAM) requirements becomes critical. In standard FP32 precision, each parameter consumes 4 bytes of storage, so the memory occupied by the weights follows the formula: Total Parameters × 4 Bytes. For accurate estimation, note that 1 MB = 1024² Bytes. Keep in mind this covers only the parameters themselves — training additionally requires memory for gradients, optimizer states, and activations. Below, three distinct approaches to quantifying parameter counts and memory footprints are demonstrated using a custom convolutional architecture.
Custom Network Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiScaleModule(nn.Module):
    """Inception-style block: four parallel convolution pathways whose
    outputs are concatenated along the channel axis (12 + 18 + 14 + 8 = 52
    output channels). Spatial dimensions are preserved by every pathway.
    """

    def __init__(self, input_channels):
        super(MultiScaleModule, self).__init__()
        # Pathway 1: plain 1x1 projection.
        self.path_1x1 = nn.Conv2d(input_channels, 12, kernel_size=1)
        # Pathway 2: 1x1 channel reduction, then a padded 3x3 convolution.
        self.path_3x3_reduce = nn.Conv2d(input_channels, 10, kernel_size=1)
        self.path_3x3_expand = nn.Conv2d(10, 18, kernel_size=3, padding=1)
        # Pathway 3: 1x1 channel reduction, then a padded 5x5 convolution.
        self.path_5x5_reduce = nn.Conv2d(input_channels, 8, kernel_size=1)
        self.path_5x5_expand = nn.Conv2d(8, 14, kernel_size=5, padding=2)
        # Pathway 4: 3x3 average pooling followed by a 1x1 projection.
        self.path_pool = nn.Conv2d(input_channels, 8, kernel_size=1)

    def forward(self, x):
        """Run the four pathways on x and concatenate their outputs."""
        out_1x1 = self.path_1x1(x)
        out_3x3 = self.path_3x3_expand(self.path_3x3_reduce(x))
        out_5x5 = self.path_5x5_expand(self.path_5x5_reduce(x))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.path_pool(pooled)
        # Channel dimension after concatenation: 12 + 18 + 14 + 8 = 52.
        return torch.cat([out_1x1, out_3x3, out_5x5, out_pool], dim=1)
class HybridNetwork(nn.Module):
    """Small LeNet-like CNN with two multi-scale (inception) stages.

    Expects single-channel 28x28 inputs (e.g. MNIST) and produces
    10 class scores per example.
    """

    def __init__(self):
        super(HybridNetwork, self).__init__()
        self.entry_conv = nn.Conv2d(1, 6, kernel_size=5)
        self.stage_a = MultiScaleModule(6)
        # Transition convolution: compresses 52 channels down to 12.
        self.transition = nn.Conv2d(52, 12, kernel_size=5)
        self.stage_b = MultiScaleModule(12)
        self.spatial_reduce = nn.MaxPool2d(kernel_size=2)
        # Spatial trace for a 28x28 input: 5x5 conv -> 24, pool -> 12,
        # 5x5 conv -> 8, pool -> 4.  Final feature map: 52 * 4 * 4 = 832.
        self.predictor = nn.Linear(832, 10)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) class scores."""
        batch_size = x.size(0)
        features = F.relu(self.spatial_reduce(self.entry_conv(x)))
        features = self.stage_a(features)
        features = F.relu(self.spatial_reduce(self.transition(features)))
        features = self.stage_b(features)
        return self.predictor(features.view(batch_size, -1))
Approach 1: Aggregated Memory Analysis
def analyze_memory_footprint(network, bytes_per_element=4):
    """Print total/trainable parameter counts and their memory cost.

    Args:
        network: Instantiated PyTorch model.
        bytes_per_element: Storage size per parameter
            (4 for FP32, 2 for FP16/BF16).
    """
    mb = 1024 ** 2  # 1 MB = 1024^2 bytes
    total_count = 0
    trainable_count = 0
    for tensor in network.parameters():
        count = tensor.numel()
        total_count += count
        if tensor.requires_grad:
            trainable_count += count
    print(f"Total Parameters: {total_count:,}")
    print(f"Total Memory: {total_count * bytes_per_element / mb:.2f} MB")
    print(f"Trainable Parameters: {trainable_count:,}")
    print(f"Trainable Memory: {trainable_count * bytes_per_element / mb:.2f} MB")
# Instantiate the network and report its FP32 memory footprint.
model = HybridNetwork()
analyze_memory_footprint(model)
Approach 2: Iterative Dimension Enumeration
def enumerate_parameters(model):
    """Manually count parameters by multiplying each tensor's dimensions.

    Illustrates what ``Tensor.numel()`` does under the hood: a tensor's
    element count is the product of its axis sizes.

    Args:
        model: Instantiated PyTorch model.

    Returns:
        int: The total number of parameter elements (also printed).
    """
    cumulative_params = 0
    # model.parameters() yields each weight/bias tensor in turn; no need
    # to materialize a list or track an index.
    for weight_matrix in model.parameters():
        dimension_volume = 1
        for axis_size in weight_matrix.size():
            dimension_volume *= axis_size
        cumulative_params += dimension_volume
    print(f"Enumerated Parameter Count: {cumulative_params}")
    return cumulative_params
# Cross-check: should match the total reported by analyze_memory_footprint.
enumerate_parameters(model)
Approach 3: Automated Profiling Libraries
Using torchstat
# torchstat prints a per-layer table of parameters, memory, and FLOPs.
# NOTE: stat() takes the input size as (C, H, W) — no batch dimension.
from torchstat import stat
analysis_model = HybridNetwork()
stat(analysis_model, (1, 28, 28))
Using thop
# thop traces the model with a concrete dummy tensor, so here the input
# DOES include the batch dimension: (N, C, H, W).
from thop import profile
evaluation_model = HybridNetwork()
dummy_input = torch.randn(1, 1, 28, 28)
# profile() returns (total multiply-accumulate ops, total parameter count).
flops, param_count = profile(evaluation_model, inputs=(dummy_input,))
print(f"Computational Cost: {flops/1e6:.2f} MFLOPs")
print(f"Parameter Count: {int(param_count):,}")
The first approach provides reusable utility functions for quick estimation. The second illustrates the underlying mechanism of PyTorch's parameter storage. The third offers comprehensive profiling including FLOPs analysis, though it requires careful specification of input tensor dimensions matching the model's expected input shape.