Qwen2 Transformer Architecture: A Comprehensive Technical Breakdown
The Qwen2 model is hosted in the QwenLM/Qwen2 GitHub repository and has been integrated into Hugging Face Transformers starting from version 4.37.0, with its implementation located in the transformers/models/qwen2 directory. Like its predecessor Qwen, Qwen2 follows a decoder-only Transformer architecture, leveraging components such as RMSNorm, the SwiGLU activation function, Rotary Position Embedding (RoPE), and multi-head attention mechanisms.
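As a quick orientation, a published Qwen2 checkpoint can be loaded directly through this integration. A minimal sketch (the checkpoint name below is one of the released Qwen2 models):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Minimal sketch: load a published Qwen2 checkpoint through the
# Transformers integration (requires transformers >= 4.37.0)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct")
print(model.config.model_type)  # "qwen2"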
RMS Normalization
Qwen2 uses Root Mean Square Layer Normalization (RMSNorm) for layer normalization. This technique normalizes the output of a neural network layer using the root mean square (RMS) of the layer's activations, followed by scaling with a learnable gain parameter. The formula for RMSNorm is:
$$\bar{a}_i = \frac{a_i}{\text{RMS}(\mathbf{a})} g_i, \quad \text{where } \text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n} \sum_{i=1}^n a_i^2}$$
Here, $\mathbf{a} \in \mathbb{R}^n$ is the output vector of the layer, and $\mathbf{g} \in \mathbb{R}^n$ is the learnable gain parameter initialized to 1. The implementation of Qwen2RMSNorm is as follows:
import torch
import torch.nn as nn

class Qwen2RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """RMSNorm implementation for Qwen2, equivalent to T5LayerNorm"""
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Compute in float32 to avoid numerical instability in low precision
        hidden_states = hidden_states.to(torch.float32)
        # Mean of squares along the feature dimension
        variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
        # Normalize with eps inside the square root (as in T5LayerNorm),
        # then apply the learnable scale and restore the original dtype
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.scale * hidden_states.to(input_dtype)
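A quick sanity check (an illustrative snippet, not part of the Qwen2 source): since the gain is initialized to 1, a freshly constructed layer should produce outputs with unit RMS along the feature dimension, regardless of the input's scale.

norm = Qwen2RMSNorm(dim=8)
x = torch.randn(2, 4, 8) * 5.0   # arbitrary input scale
y = norm(x)
print(y.pow(2).mean(dim=-1).sqrt())  # ~1.0 everywhere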
Placement of RMSNorm
Qwen2 applies RMSNorm in three key locations within its architecture:
- Before Self-Attention: In each decoder layer, RMSNorm is applied to the input hidden states before feeding them into the self-attention sublayer:
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.input_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Other initializations omitted for brevity

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_value: tuple[torch.Tensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: torch.LongTensor = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]:
        residual = hidden_states
        # Pre-norm: normalize before the self-attention sublayer
        hidden_states = self.input_norm(hidden_states)
        # Self-attention computation
        attn_output, attn_weights, present_kv = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        # Residual connection around the attention sublayer
        hidden_states = residual + attn_output
        # Rest of forward pass omitted
- Before Feed-Forward Network: After the self-attention sublayer, RMSNorm is applied again before the hidden states enter the feed-forward network (FFN):
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.post_attn_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Other initializations omitted for brevity

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_value: tuple[torch.Tensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: torch.LongTensor = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]:
        # Self-attention pass omitted

        # Feed-forward network pass
        residual = hidden_states
        # Pre-norm: normalize before the FFN sublayer
        hidden_states = self.post_attn_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        # Residual connection around the FFN sublayer
        hidden_states = residual + hidden_states
        # Rest of forward pass omitted
- After Final Decoder Layer: RMSNorm is applied to the output of the last decoder layer to produce the final hidden states:
class Qwen2Model(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.final_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Other initializations omitted for brevity

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_values: list[torch.FloatTensor] = None,
        inputs_embeds: torch.FloatTensor = None,
        use_cache: bool = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        cache_position: torch.LongTensor = None,
    ) -> tuple[torch.FloatTensor, list[torch.FloatTensor]]:
        # Decoder layer processing omitted
        # Final normalization of the last decoder layer's output
        hidden_states = self.final_norm(hidden_states)
        # Rest of forward pass omitted
Position Encoding with RoPE
Qwen2 uses Rotary Position Embedding (RoPE) to inject position information into the attention mechanism. Unlike traditional position embeddings that are added to token embeddings, RoPE encodes relative position information directly into the self-attention calculations by rotating query and key vectors using trigonometric functions.
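Concretely, for a vector at position $m$, RoPE rotates each feature pair $(x^{(1)}, x^{(2)})$ associated with frequency $\theta_i$ by the angle $m\theta_i$:

$$\begin{pmatrix} x'^{(1)} \\ x'^{(2)} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x^{(1)} \\ x^{(2)} \end{pmatrix}, \quad \theta_i = 10000^{-2i/d}$$

Because rotations compose by angle addition, the dot product between a query rotated at position $m$ and a key rotated at position $n$ depends only on the offset $m - n$, which is exactly the relative-position property the attention mechanism needs.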
The RoPE implementation for Qwen2 precomputes cosine and sine values for positions, which are then applied to query and key tensors:
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, max_seq_len: int = 2048, base: int = 10000, device=None):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base
        # Compute inverse frequencies for the rotary embeddings
        inv_freq = 1.0 / (self.base ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Precompute the cosine and sine cache
        self._update_cos_sin_cache(max_seq_len, device, torch.get_default_dtype())

    def _update_cos_sin_cache(self, seq_len: int, device, dtype):
        self.cache_seq_len = seq_len
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        # Outer product: one row of angles per position
        freqs = torch.outer(positions, self.inv_freq)
        # Duplicate the frequencies to match the full head dimension
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cache", emb.sin().to(dtype), persistent=False)

    def forward(self, x: torch.Tensor, current_seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Grow the cache if the current sequence length exceeds the cached maximum
        if current_seq_len > self.cache_seq_len:
            self._update_cos_sin_cache(current_seq_len, x.device, x.dtype)
        return (
            self.cos_cache[:current_seq_len].to(x.dtype),
            self.sin_cache[:current_seq_len].to(x.dtype),
        )
The apply_rotary_pos_emb function applies the precomputed cosine and sine values to rotate query and key tensors:
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Splits the tensor into two halves and rotates the second half into the first"""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.LongTensor,
    unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Applies Rotary Position Embedding to the query and key tensors"""
    # Gather per-position angles and broadcast over the head dimension
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated
Note that Qwen2's RoPE implementation differs slightly from the original paper: instead of rotating consecutive element pairs $(x_{2i}, x_{2i+1})$, it pairs element $i$ of the first half with element $i$ of the second half. The two layouts differ only by a fixed permutation of the feature dimensions, so they are functionally equivalent, and the half-split form simplifies the implementation.
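The relative-position property can be verified numerically with the classes above (an illustrative snippet, not part of the Qwen2 source): attention scores should match whenever the query-key offset is the same.

rope = Qwen2RotaryEmbedding(head_dim=64, max_seq_len=128)
q = torch.randn(1, 1, 1, 64)  # (batch, n_heads, seq_len, head_dim)
k = torch.randn(1, 1, 1, 64)
cos, sin = rope(q, 128)

def score(q_pos: int, k_pos: int) -> torch.Tensor:
    # Rotate query and key at their respective positions, then take the dot product
    q_rot, _ = apply_rotary_pos_emb(q, q, cos, sin, torch.tensor([[q_pos]]))
    _, k_rot = apply_rotary_pos_emb(k, k, cos, sin, torch.tensor([[k_pos]]))
    return (q_rot * k_rot).sum()

# Same offset (2) between query and key positions => same attention score
print(torch.allclose(score(3, 1), score(10, 8), atol=1e-4))  # True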
SwiGLU Activation Function
Qwen2 uses the SwiGLU activation function in its feed-forward network (FFN). SwiGLU is a gated linear unit variant that replaces the sigmoid function in GLU with the SiLU (Sigmoid Linear Unit, also called Swish) function. The SiLU function is defined as $\text{SiLU}(x) = x \cdot \sigma(x)$, where $\sigma(x)$ is the sigmoid function.
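The full FFN computation can then be written as

$$\text{FFN}(x) = W_{\text{down}}\left(\text{SiLU}(W_{\text{gate}} x) \odot W_{\text{up}} x\right)$$

where $\odot$ denotes elementwise multiplication and all three projections are bias-free.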
The FFN implementation for Qwen2 is as follows:
from transformers.activations import ACT2FN
class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.hidden_size
        self.intermediate_dim = config.intermediate_size
        self.gate_linear = nn.Linear(self.d_model, self.intermediate_dim, bias=False)
        self.up_linear = nn.Linear(self.d_model, self.intermediate_dim, bias=False)
        self.down_linear = nn.Linear(self.intermediate_dim, self.d_model, bias=False)
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gated activation: SiLU(gate) multiplied elementwise with the up projection
        gate_output = self.activation(self.gate_linear(x))
        up_output = self.up_linear(x)
        return self.down_linear(gate_output * up_output)
In Qwen2's configuration, config.hidden_act is set to "silu", so the activation function used is SiLU. The FFN computes the output by multiplying the gated activation output with the up-projection output, then projecting the result back to the model's hidden size.
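A minimal usage sketch (the sizes below are arbitrary illustrations, not Qwen2's real configuration values):

from types import SimpleNamespace

config = SimpleNamespace(hidden_size=64, intermediate_size=176, hidden_act="silu")
mlp = Qwen2MLP(config)
x = torch.randn(2, 10, 64)   # (batch, seq_len, hidden_size)
print(mlp(x).shape)          # torch.Size([2, 10, 64])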
Self-Attention Mechanisms
Qwen2 supports multiple implementations and variants of self-attention to balance performance and flexibility.
Attention Implementations
The model offers three options for attention computation:
QWEN2_ATTENTION_CLASSES = {
    "eager": Qwen2Attention,
    "flash_attention_2": Qwen2FlashAttention2,
    "sdpa": Qwen2SdpaAttention,
}
- Eager: A custom, manually implemented multi-head attention layer for maximum compatibility.
- FlashAttention-2: An optimized implementation using the FlashAttention-2 library, which supports sliding window attention and improves memory efficiency.
- SDPA: Uses PyTorch's native scaled_dot_product_attention function for hardware-accelerated attention computation.
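The implementation is chosen when the model is loaded; a sketch, assuming a recent Transformers release where the attn_implementation argument is available:

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    attn_implementation="sdpa",   # or "eager" / "flash_attention_2"
    torch_dtype=torch.bfloat16,   # FlashAttention-2 requires fp16/bf16
)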
Attention Variants
Qwen2 supports three self-attention variants, controlled by the num_key_value_heads configuration parameter:
- Multi-Head Attention (MHA): When num_key_value_heads equals num_attention_heads, each query head has its own dedicated key and value heads.
- Multi-Query Attention (MQA): When num_key_value_heads is set to 1, all query heads share a single pair of key and value heads.
- Grouped-Query Attention (GQA): When num_key_value_heads is between 1 and num_attention_heads, query heads are divided into groups, with each group sharing a single key-value head pair.
The attention layer initialization snippet below shows how these variants are handled:
class Qwen2Attention(nn.Module):
    def __init__(self, config, layer_idx: int = None):
        super().__init__()
        self.d_model = config.hidden_size
        self.n_heads = config.num_attention_heads
        self.head_dim = self.d_model // self.n_heads
        self.n_kv_heads = config.num_key_value_heads
        # Number of query heads sharing each key-value head pair
        self.n_groups = self.n_heads // self.n_kv_heads
        # Other initializations omitted
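At attention time, the key and value tensors are expanded so that each group of query heads attends with its shared key-value head. A sketch of this expansion, mirroring the repeat_kv helper in the Transformers implementation:

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expands (batch, n_kv_heads, seq_len, head_dim) to (batch, n_kv_heads * n_rep, seq_len, head_dim),
    equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1)"""
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states  # MHA case: nothing to expand
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, n_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, n_kv_heads * n_rep, seq_len, head_dim)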