
Backpropagation Mechanics, Higher-Order Derivatives, and Multi-GPU Model Partitioning


Neural network training relies on two distinct phases within the computational graph. Forward propagation sequences calculations from the input layer toward the output, storing intermediate states. Conversely, backpropagation traverses the graph in reverse, computing gradients for parameters and intermediate variables starting from the loss function at the output layer. These processes are intrinsically linked during model optimization, though the training phase demands significantly more memory than inference due to the need to retain intermediate activations for gradient calculation.
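As a minimal sketch of these two phases (tensor shapes here are illustrative, not from the text), the following PyTorch snippet builds the graph during the forward computation and traverses it in reverse with .backward(); wrapping the computation in torch.no_grad() skips recording intermediates, which is why inference needs less memory than training.

import torch

# Minimal sketch of the two training phases; tensor shapes are illustrative.
x = torch.randn(4, 3)                      # input batch
w = torch.randn(3, 2, requires_grad=True)  # trainable parameter

# Forward pass: builds the graph and retains intermediates needed for gradients.
h = x @ w
loss = torch.relu(h).sum()

# Backward pass: traverses the graph from the loss back to the parameters.
loss.backward()
print(w.grad.shape)  # torch.Size([3, 2])

# Inference: no graph is recorded, so intermediate activations are not retained.
with torch.no_grad():
    _ = torch.relu(x @ w)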

Incorporating Bias Terms in Hidden Layers

When a multilayer perceptron is extended to include bias vectors in its hidden layers, the computational flow changes only slightly; bias terms are conventionally excluded from regularization (weight-decay) penalties.

Forward Pass Formulation Consider an input matrix $\mathbf{X}$ with dimensions $n \times m$ (a code sketch follows the list).

  1. Hidden Layer Pre-activation: $$ \mathbf{H} = \mathbf{X}\mathbf{W}_1 + \mathbf{b}_1 $$ Here, $\mathbf{W}_1$ represents weights ($m \times h$) and $\mathbf{b}_1$ is the bias vector ($1 \times h$). The resulting $\mathbf{H}$ has shape $n \times h$.
  2. Activation: $$ \mathbf{A} = \phi(\mathbf{H}) $$ $\mathbf{A}$ retains the $n \times h$ dimensions.
  3. Output Layer: $$ \mathbf{O} = \mathbf{A}\mathbf{W}_2 + \mathbf{b}_2 $$ Weights $\mathbf{W}_2$ are $h \times o$, bias $\mathbf{b}_2$ is $1 \times o$, yielding output $\mathbf{O}$ of size $n \times o$.
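A minimal sketch of this forward pass, assuming $\phi$ is ReLU and using placeholder dimensions ($n=64$, $m=8$, $h=16$, $o=4$):

import torch

# Forward pass sketch; dimensions are illustrative placeholders.
n, m, h, o = 64, 8, 16, 4
X  = torch.randn(n, m)
W1 = torch.randn(m, h, requires_grad=True)
b1 = torch.zeros(1, h, requires_grad=True)
W2 = torch.randn(h, o, requires_grad=True)
b2 = torch.zeros(1, o, requires_grad=True)

H = X @ W1 + b1        # (n, h) pre-activation
A = torch.relu(H)      # (n, h) activation, phi = ReLU here
O = A @ W2 + b2        # (n, o) output
print(H.shape, A.shape, O.shape)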

Backward Pass Derivation Given a loss function $\mathcal{L}(\mathbf{O}, \mathbf{T})$ where $\mathbf{T}$ is the target ($n \times o$), the gradients are as follows (a manual implementation is sketched after the list):

  1. Output Layer Gradients:

    • Loss gradient w.r.t output: $\frac{\partial \mathcal{L}}{\partial \mathbf{O}}$ ($n \times o$).
    • Weight gradient: $\frac{\partial \mathcal{L}}{\partial \mathbf{W}_2} = \mathbf{A}^\top \frac{\partial \mathcal{L}}{\partial \mathbf{O}}$ ($h \times o$).
    • Bias gradient: $\frac{\partial \mathcal{L}}{\partial \mathbf{b}_2} = \sum_{i=1}^{n} \frac{\partial \mathcal{L}}{\partial \mathbf{O}_i}$ ($1 \times o$).
    • Activation gradient: $\frac{\partial \mathcal{L}}{\partial \mathbf{A}} = \frac{\partial \mathcal{L}}{\partial \mathbf{O}} \mathbf{W}_2^\top$ ($n \times h$).
  2. Hidden Layer Gradients:

    • Pre-activation gradient: $\frac{\partial \mathcal{L}}{\partial \mathbf{H}} = \frac{\partial \mathcal{L}}{\partial \mathbf{A}} \odot \phi'(\mathbf{H})$ ($n \times h$).
    • Weight gradient: $\frac{\partial \mathcal{L}}{\partial \mathbf{W}_1} = \mathbf{X}^\top \frac{\partial \mathcal{L}}{\partial \mathbf{H}}$ ($m \times h$).
    • Bias gradient: $\frac{\partial \mathcal{L}}{\partial \mathbf{b}_1} = \sum_{i=1}^{n} \frac{\partial \mathcal{L}}{\partial \mathbf{H}_i}$ ($1 \times h$).
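A sketch that implements these gradients by hand and checks them against autograd; the squared-error loss, ReLU activation, and dimensions are assumptions chosen for illustration:

import torch

# Manual backward pass for the equations above, checked against autograd.
n, m, h, o = 64, 8, 16, 4
X = torch.randn(n, m)
T = torch.randn(n, o)
W1 = torch.randn(m, h, requires_grad=True)
b1 = torch.zeros(1, h, requires_grad=True)
W2 = torch.randn(h, o, requires_grad=True)
b2 = torch.zeros(1, o, requires_grad=True)

# Forward pass (phi = ReLU) and an illustrative squared-error loss.
H = X @ W1 + b1
A = torch.relu(H)
O = A @ W2 + b2
loss = 0.5 * ((O - T) ** 2).sum()
loss.backward()

# Manual gradients following the derivation.
dO  = O.detach() - T                     # dL/dO   (n, o)
dW2 = A.detach().T @ dO                  # dL/dW2  (h, o)
db2 = dO.sum(dim=0, keepdim=True)        # dL/db2  (1, o)
dA  = dO @ W2.detach().T                 # dL/dA   (n, h)
dH  = dA * (H.detach() > 0).float()      # dL/dH   (n, h), ReLU derivative
dW1 = X.T @ dH                           # dL/dW1  (m, h)
db1 = dH.sum(dim=0, keepdim=True)        # dL/db1  (1, h)

print(torch.allclose(dW2, W2.grad, atol=1e-4),
      torch.allclose(db1, b1.grad, atol=1e-4))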

Computational Complexity of Second-Order Derivatives

Calculating second-order derivatives (Hessian matrix) significantly alters the computational graph requirements. The process involves:

  1. Standard forward pass to compute loss $\mathcal{L}$.
  2. Standard backward pass to obtain first-order gradients $\nabla \mathcal{L}$.
  3. A secondary backward pass performed on each component of the first-order gradient to populate the Hessian $\mathbf{H}_{\mathcal{L}}$.

For a model with $n$ parameters the Hessian contains $n^2$ entries, so both computation and storage grow roughly as $O(n^2)$, leading to substantial increases in execution time and memory consumption compared to first-order optimization.
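A small sketch of that per-component second backward pass, using torch.autograd.grad with create_graph=True; the five-parameter loss function below is an arbitrary placeholder:

import torch

# Hessian of a scalar loss w.r.t. a small parameter vector, built by running
# one extra backward pass per first-order gradient component.
w = torch.randn(5, requires_grad=True)           # 5 parameters (illustrative)
x = torch.randn(5)
loss = torch.sin(w @ x) + 0.5 * (w ** 2).sum()   # arbitrary smooth scalar loss

# First backward pass; create_graph=True keeps the graph for further passes.
(grad,) = torch.autograd.grad(loss, w, create_graph=True)

# One additional backward pass per gradient component: n extra passes and
# n^2 entries to store, hence the quadratic blow-up described above.
hessian = torch.stack([
    torch.autograd.grad(g, w, retain_graph=True)[0] for g in grad
])
print(hessian.shape)  # torch.Size([5, 5])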

Managing Memory Constraints via Model Parallelism

When a computational graph exceeds the VRAM capacity of a single GPU, the model architecture can be partitioned across multiple devices.

Implementation Strategy The network layers are distributed sequentially across available hardware. Intermediate activations must be transferred between devices during both forward and backward passes.

import torch
import torch.nn as nn

class DistributedMLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        # Define specific devices for each stage
        self.dev_1 = torch.device('cuda:0')
        self.dev_2 = torch.device('cuda:1')
        
        # Assign layers to specific devices
        self.layer_stage_1 = nn.Linear(in_features, hidden_features).to(self.dev_1)
        self.layer_stage_2 = nn.Linear(hidden_features, out_features).to(self.dev_2)
        
    def forward(self, x):
        # Move input to first device
        x = x.to(self.dev_1)
        h = torch.relu(self.layer_stage_1(x))
        
        # Transfer intermediate tensor to second device
        h = h.to(self.dev_2)
        out = self.layer_stage_2(h)
        return out

# Instantiate and run
model = DistributedMLP(2048, 1024, 10)
dummy_input = torch.randn(64, 2048)
output = model(dummy_input)

Because CUDA kernels and device-to-device copies are launched asynchronously, torch.cuda.synchronize() can be used to block the host until all pending GPU work has completed, which matters chiefly when measuring the wall-clock time of a pipeline stage.
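An illustrative timing sketch, reusing the DistributedMLP class above and assuming at least two GPUs are available:

import time
import torch

# Timing sketch (assumes the DistributedMLP defined above and >= 2 GPUs).
if torch.cuda.device_count() >= 2:
    model = DistributedMLP(2048, 1024, 10)
    batch = torch.randn(64, 2048)

    torch.cuda.synchronize()   # wait for any previously queued GPU work
    start = time.perf_counter()
    out = model(batch)
    torch.cuda.synchronize()   # wait for both GPUs to finish the forward pass
    print(f"forward pass: {time.perf_counter() - start:.4f}s")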

Trade-offs: Model Parallelism vs. Mini-batch Training

Model Parallelism Advantages:

  • Capacity: Enables training of architectures too large for single-device memory.
  • Resource Utilization: Leverages aggregate VRAM and compute power across multiple GPUs.
  • Memory Efficiency: Distributes parameter storage, mitigating OOM errors on individual cards.

Model Parallelism Disadvantages:

  • Communication Latency: Frequent tensor transfers between devices introduce overhead, potentially slowing down iteration times.
    • Engineering Overhead: Requires careful management of device placement and data movement, increasing code complexity.
  • Load Balancing: Uneven layer sizes can cause idle time on faster devices while waiting for slower stages.

Mini-batch Training Comparison:

  • Pros: Simpler implementation, stable gradient estimation through aggregation, lower communication needs within a single device.
  • Cons: Limited by single-device memory capacity, potentially slower per-iteration throughput on large datasets compared to distributed compute setups.

