Fading Coder

One Final Commit for the Last Sprint

Qwen2 Transformer Architecture: A Comprehensive Technical Breakdown

Tech · May 13

The Qwen2 model is hosted in the QwenLM/Qwen2 GitHub repository and has been integrated into Hugging Face Transformers starting from version 4.37.0, with its implementation located in the transformers/models/qwen2 directory. Like its predecessor Qwen, Qwen2 follows a decoder-only Transformer architecture, leveraging components such as RMSNorm, the SwiGLU activation function, Rotary Position Embedding (RoPE), and multi-head attention mechanisms.

RMS Normalization

Qwen2 uses Root Mean Square Layer Normalization (RMSNorm) for layer normalization. This technique normalizes the output of a neural network layer using the root mean square (RMS) of the layer's activations, followed by scaling with a learnable gain parameter. The formula for RMSNorm is:

$$\bar{a}_i = \frac{a_i}{\text{RMS}(\mathbf{a})} g_i, \quad \text{where } \text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} a_i^2}$$

Here, $\mathbf{a} \in \mathbb{R}^n$ is the output vector of the layer, and $\mathbf{g} \in \mathbb{R}^n$ is the learnable gain parameter initialized to 1. The implementation of Qwen2RMSNorm is as follows:

import torch
import torch.nn as nn

class Qwen2RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """RMSNorm implementation for Qwen2, equivalent to T5LayerNorm"""
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Convert to float32 to avoid numerical instability
        hiddens_float = hidden_states.to(torch.float32)
        # Compute the mean of squares, then normalize with eps inside the sqrt
        # (matching T5LayerNorm; adding eps outside the sqrt would be a subtle bug)
        variance = hiddens_float.pow(2).mean(dim=-1, keepdim=True)
        normalized = hiddens_float * torch.rsqrt(variance + self.eps)
        return self.scale * normalized.to(input_dtype)
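As a quick sanity check of the formula above, here is a framework-free sketch in plain Python (independent of the torch class): with a unit gain, the normalized output vector itself has an RMS of approximately 1.

```python
import math

def rms_norm(a, g=None, eps=1e-6):
    # Plain-Python RMSNorm: a_i / RMS(a) * g_i, with eps added for stability
    rms = math.sqrt(sum(x * x for x in a) / len(a) + eps)
    g = g if g is not None else [1.0] * len(a)
    return [x / rms * gi for x, gi in zip(a, g)]

a = [1.0, 2.0, 3.0, 4.0]
out = rms_norm(a)
# With unit gain, the output should have RMS ~= 1
rms_out = math.sqrt(sum(x * x for x in out) / len(out))
```

Note that, unlike LayerNorm, no mean is subtracted: RMSNorm only rescales, so the relative proportions of the input elements are preserved.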

Placement of RMSNorm

Qwen2 applies RMSNorm in three key locations within its architecture:

  1. Before Self-Attention: In each decoder layer, RMSNorm is applied to the input hidden states before feeding them into the self-attention sublayer:
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.input_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Other initializations omitted for brevity

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_value: tuple[torch.Tensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: torch.LongTensor = None,
        **kwargs
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]:
        residual = hidden_states
        hidden_states = self.input_norm(hidden_states)

        # Self-Attention computation
        attn_output, attn_weights, present_kv = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = residual + attn_output
        # Rest of forward pass omitted
  2. Before Feed-Forward Network: After the self-attention sublayer, RMSNorm is applied again before the hidden states enter the feed-forward network (FFN):
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.post_attn_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Other initializations omitted for brevity

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_value: tuple[torch.Tensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: torch.LongTensor = None,
        **kwargs
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]:
        # Self-Attention pass omitted

        # Feed-Forward Network pass
        residual = hidden_states
        hidden_states = self.post_attn_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        # Rest of forward pass omitted
  3. After Final Decoder Layer: RMSNorm is applied to the output of the last decoder layer to produce the final hidden states:
class Qwen2Model(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.final_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Other initializations omitted for brevity

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_values: list[torch.FloatTensor] = None,
        inputs_embeds: torch.FloatTensor = None,
        use_cache: bool = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        cache_position: torch.LongTensor = None,
    ) -> tuple[torch.FloatTensor, list[torch.FloatTensor]]:
        # Decoder layer processing omitted

        hidden_states = self.final_norm(hidden_states)
        # Rest of forward pass omitted
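Putting the first two placements together, the pre-norm residual pattern of a single decoder layer can be sketched schematically (plain Python with the sublayers passed in as callables; an illustration of the dataflow, not the actual Hugging Face code):

```python
def pre_norm_block(x, norm_attn, attn, norm_ffn, ffn):
    # Normalize *before* each sublayer, then add the residual back in
    x = x + attn(norm_attn(x))   # self-attention sublayer
    x = x + ffn(norm_ffn(x))     # feed-forward sublayer
    return x

# Demo with identity sublayers: each residual addition doubles the value
identity = lambda v: v
y = pre_norm_block(1.0, identity, identity, identity, identity)
```

Because normalization happens inside each residual branch, the residual stream itself is never normalized until the final RMSNorm after the last layer, which is why that third placement is needed.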

Position Encoding with RoPE

Qwen2 uses Rotary Position Embedding (RoPE) to inject position information into the attention mechanism. Unlike traditional position embeddings that are added to token embeddings, RoPE encodes relative position information directly into the self-attention calculations by rotating query and key vectors using trigonometric functions.

The RoPE implementation for Qwen2 precomputes cosine and sine values for positions, which are then applied to query and key tensors:

class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, max_seq_len: int = 2048, base: int = 10000, device=None):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Compute inverse frequency for rotary embeddings
        inv_freq = 1.0 / (self.base ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Precompute cosine and sine cache
        self._update_cos_sin_cache(max_seq_len, device, torch.get_default_dtype())

    def _update_cos_sin_cache(self, seq_len: int, device, dtype):
        self.cache_seq_len = seq_len
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)

        # Compute frequency values
        freqs = torch.outer(positions, self.inv_freq)
        # Duplicate frequencies to match head dimension
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cache", emb.sin().to(dtype), persistent=False)

    def forward(self, x: torch.Tensor, current_seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Update cache if current sequence length exceeds cached max
        if current_seq_len > self.cache_seq_len:
            self._update_cos_sin_cache(current_seq_len, x.device, x.dtype)
        
        return (
            self.cos_cache[:current_seq_len].to(x.dtype),
            self.sin_cache[:current_seq_len].to(x.dtype)
        )
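The inverse-frequency computation in `__init__` can be reproduced in plain Python as a sanity check. It yields one frequency per rotated pair (`head_dim // 2` values), with the first frequency always 1.0 and the rest decaying geometrically:

```python
def rope_inv_freq(head_dim, base=10000.0):
    # Mirrors: 1.0 / base ** (arange(0, head_dim, 2) / head_dim)
    return [1.0 / (base ** (i / head_dim)) for i in range(0, head_dim, 2)]

# For head_dim=8: 4 frequencies, strictly decreasing from 1.0
freqs = rope_inv_freq(8)
```

Low-index pairs rotate quickly (capturing fine-grained position differences) while high-index pairs rotate slowly (capturing coarse, long-range position information).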

The apply_rotary_pos_emb function applies the precomputed cosine and sine values to rotate query and key tensors:

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Splits tensor into two halves and rotates the second half"""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.LongTensor,
    unsqueeze_dim: int = 1
) -> tuple[torch.Tensor, torch.Tensor]:
    """Applies Rotary Position Embedding to query and key tensors"""
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated

Note that Qwen2's RoPE implementation differs slightly from the original paper: instead of rotating consecutive pairs of elements in the embedding vector, it splits the vector into two equal halves and rotates the second half relative to the first. This alternative arrangement is functionally equivalent to the paper's method and simplifies implementation.
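The key property RoPE provides can be verified numerically with a small framework-free sketch: after rotating query and key vectors by their absolute positions, their dot product depends only on the relative offset. The pairing below follows the half-split layout of `rotate_half` (element `i` is paired with element `i + head_dim // 2`, and both halves share the same frequency):

```python
import math

def rope(vec, pos, base=10000.0):
    # Rotate each (vec[i], vec[i + half]) pair by pos * inv_freq[i],
    # matching the half-split rotate_half formulation
    d = len(vec)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        theta = pos / (base ** (2 * i / d))
        c, s = math.cos(theta), math.sin(theta)
        out[i] = vec[i] * c - vec[i + half] * s
        out[i + half] = vec[i + half] * c + vec[i] * s
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]
# Same relative offset (3) at different absolute positions
d1 = dot(rope(q, 5), rope(k, 2))
d2 = dot(rope(q, 9), rope(k, 6))
```

Since `d1` and `d2` agree, the attention score between two tokens is a function of their distance, not their absolute positions; rotations also preserve vector norms, so RoPE does not change the scale of queries and keys.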

SwiGLU Activation Function

Qwen2 uses the SwiGLU activation function in its feed-forward network (FFN). SwiGLU is a gated linear unit variant that replaces the sigmoid function in GLU with the SiLU (Sigmoid Linear Unit, also called Swish) function. The SiLU function is defined as $\text{SiLU}(x) = x \cdot \sigma(x)$, where $\sigma(x)$ is the sigmoid function.

The FFN implementation for Qwen2 is as follows:

from transformers.activations import ACT2FN

class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.hidden_size
        self.intermediate_dim = config.intermediate_size
        self.gate_linear = nn.Linear(self.d_model, self.intermediate_dim, bias=False)
        self.up_linear = nn.Linear(self.d_model, self.intermediate_dim, bias=False)
        self.down_linear = nn.Linear(self.intermediate_dim, self.d_model, bias=False)
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_output = self.activation(self.gate_linear(x))
        up_output = self.up_linear(x)
        return self.down_linear(gate_output * up_output)

In Qwen2's configuration, config.hidden_act is set to "silu", so the activation function used is SiLU. The FFN computes the output by multiplying the gated activation output with the up-projection output, then projecting the result back to the model's hidden size.
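The same computation can be written out in scalar form as a framework-free sketch (weights here are plain floats standing in for the three linear projections, chosen arbitrarily for illustration):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def silu(x):
    # SiLU / Swish: x * sigmoid(x)
    return x * sigmoid(x)

def swiglu_ffn(x, w_gate, w_up, w_down):
    # down( silu(gate(x)) * up(x) ), with scalar "projections"
    return w_down * (silu(w_gate * x) * (w_up * x))

y = swiglu_ffn(1.0, 0.5, 2.0, 1.5)
```

The gate branch passes through SiLU while the up branch stays linear; their elementwise product lets the network learn to suppress or pass each intermediate dimension, which is the "gated" part of the gated linear unit.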

Self-Attention Mechanisms

Qwen2 supports multiple implementations and variants of self-attention to balance performance and flexibility.

Attention Implementations

The model offers three options for attention computation:

QWEN2_ATTENTION_CLASSES = {
    "eager": Qwen2Attention,
    "flash_attention_2": Qwen2FlashAttention2,
    "sdpa": Qwen2SdpaAttention,
}
  • Eager: A custom, manually implemented multi-head attention layer for maximum compatibility.
  • FlashAttention-2: An optimized implementation using the FlashAttention-2 library, which supports sliding window attention and improves memory efficiency.
  • SDPA: Uses PyTorch's native scaled_dot_product_attention function for hardware-accelerated attention computations.

Attention Variants

Qwen2 supports three self-attention variants, controlled by the num_key_value_heads configuration parameter:

  • Multi-Head Attention (MHA): When num_key_value_heads equals num_attention_heads, each query head has its own dedicated key and value heads.
  • Multi-Query Attention (MQA): When num_key_value_heads is set to 1, all query heads share a single pair of key and value heads.
  • Grouped-Query Attention (GQA): When num_key_value_heads is between 1 and num_attention_heads, query heads are divided into groups, with each group sharing a single key-value head pair.

The attention layer initialization snippet below shows how these variants are handled:

class Qwen2Attention(nn.Module):
    def __init__(self, config, layer_idx: int = None):
        super().__init__()
        self.d_model = config.hidden_size
        self.n_heads = config.num_attention_heads
        self.head_dim = self.d_model // self.n_heads
        self.n_kv_heads = config.num_key_value_heads
        self.n_groups = self.n_heads // self.n_kv_heads
        # Other initializations omitted
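The grouping logic behind all three variants can be illustrated with a small sketch (a hypothetical helper, not part of the actual implementation): each query head is assigned a KV head by integer division with the group size `n_heads // n_kv_heads`.

```python
def kv_head_for_query_head(q_head, n_heads, n_kv_heads):
    # Consecutive query heads share one KV head; group size = n_heads // n_kv_heads
    group_size = n_heads // n_kv_heads
    return q_head // group_size

# GQA: 8 query heads, 4 KV heads -> groups of 2
gqa = [kv_head_for_query_head(h, 8, 4) for h in range(8)]
# MQA: num_key_value_heads = 1 -> every query head maps to KV head 0
mqa = [kv_head_for_query_head(h, 8, 1) for h in range(8)]
# MHA: num_key_value_heads = num_attention_heads -> one KV head per query head
mha = [kv_head_for_query_head(h, 8, 8) for h in range(8)]
```

Shrinking num_key_value_heads shrinks the KV cache proportionally, which is the main motivation for GQA and MQA during inference.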
