
Architectural Fundamentals of Sparse Mixture-of-Experts and Key-Value Caching


Mixture-of-Experts Architecture Overview

The Mixture-of-Experts (MoE) paradigm fundamentally modifies the standard Transformer decoder layer by replacing the monolithic feed-forward network (FFN) with a dynamic routing mechanism and a collection of specialized sub-networks. This architecture consists of two primary components:

  1. Sparse Expert Layers: These replace traditional FFNs. Each expert operates as an independent neural module, typically implemented as a SwiGLU feed-forward block. Experts can be nested or arranged hierarchically, allowing for highly specialized representation learning.
  2. Gating and Routing Mechanism: A trainable projection layer evaluates each input token and computes a probability distribution over available experts. The router dynamically directs tokens to a selected subset of experts. Tokens may be assigned to one or multiple experts, and the routing weights scale their respective outputs before summation. Crucially, the router parameters are optimized jointly with the expert weights during pretraining.

The core distinction lies in activation density. In dense MoE, every input token passes through all expert networks, so the computational cost grows with the full parameter count. Sparse MoE instead activates only the top-k experts per token, drastically reducing per-token FLOPs while letting the total parameter count scale far beyond what any single token touches.
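
To make the savings concrete, here is a back-of-the-envelope comparison under assumed values (64 experts, top-2 routing, hypothetical dimensions chosen purely for illustration):

hidden_dim, inter_dim = 4096, 14336       # hypothetical model dimensions
num_experts, top_k = 64, 2                # assumed sparse routing configuration

params_per_expert = 3 * hidden_dim * inter_dim     # gate, up, and down projections
total_params = num_experts * params_per_expert
active_params = top_k * params_per_expert

print(f"total expert parameters: {total_params / 1e9:.1f}B")
print(f"active per token:        {active_params / 1e9:.2f}B ({active_params / total_params:.1%})")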

Routing Mechanism Implementation

The routing layer computes assignment scores via a linear projection followed by a normalizing activation (softmax or sigmoid). For sparse routing, only the highest-scoring experts participate in the forward pass: after top-k selection, the corresponding scores are gathered, optionally renormalized, and scaled to form the combination weights.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class ExpertRouter(nn.Module):
    """
    Dynamic gating module for sparse expert routing.
    Computes assignment scores and selects top-k experts per token.
    """
    def __init__(self, hidden_dim: int, total_experts: int, active_k: int,
                 num_groups: int = 1, group_limit: int = 0, scale: float = 1.0,
                 use_sigmoid: bool = False, use_bias: bool = False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.active_k = active_k
        self.num_groups = num_groups
        self.group_limit = group_limit
        self.scale_factor = scale
        self.activate_mode = "sigmoid" if use_sigmoid else "softmax"

        self.routing_matrix = nn.Parameter(torch.empty(total_experts, hidden_dim))
        nn.init.kaiming_uniform_(self.routing_matrix, a=0, nonlinearity='relu')

        # Optional additive bias on routing scores (used by some load-balancing
        # schemes); enabled explicitly rather than inferred from a magic hidden size.
        self.bias = nn.Parameter(torch.zeros(total_experts)) if use_bias else None

    def forward(self, token_embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Linear projection over flattened tokens: (num_tokens, total_experts)
        raw_scores = torch.matmul(token_embeddings, self.routing_matrix.T)

        # Normalization step
        if self.activate_mode == "softmax":
            norm_scores = F.softmax(raw_scores, dim=-1, dtype=torch.float32)
        else:
            norm_scores = torch.sigmoid(raw_scores)

        # Preserve the pre-bias scores: the bias steers expert selection, but
        # the final combination weights are gathered from the unbiased distribution.
        original_norm = norm_scores.clone()

        if self.bias is not None:
            norm_scores = norm_scores + self.bias

        if self.num_groups > 1:
            # Group-aware routing with mask generation
            grouped = norm_scores.view(token_embeddings.size(0), self.num_groups, -1)
            if self.bias is None:
                group_vals = grouped.amax(dim=-1)
            else:
                top2 = grouped.topk(2, dim=-1).values.sum(dim=-1)
                group_vals = top2

            # Select top groups
            selected_groups = group_vals.topk(self.group_limit, dim=-1).indices
            group_mask = torch.zeros_like(grouped[..., 0]).scatter_(1, selected_groups, 1.0)
            norm_scores = (grouped * group_mask.unsqueeze(-1)).view(token_embeddings.size(0), -1)

        # Final top-k expert selection
        final_indices = torch.topk(norm_scores, self.active_k, dim=-1).indices
        assigned_weights = original_norm.gather(1, final_indices)

        if self.activate_mode == "sigmoid":
            assigned_weights = assigned_weights / assigned_weights.sum(dim=-1, keepdim=True)

        assigned_weights = assigned_weights * self.scale_factor
        return assigned_weights.to(token_embeddings.dtype), final_indices
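
A quick smoke test of the router in isolation, using hypothetical dimensions; the flattened (num_tokens, hidden_dim) input shape matches how the MoE layer below calls it:

router = ExpertRouter(hidden_dim=512, total_experts=8, active_k=2)
tokens = torch.randn(16, 512)              # 16 flattened tokens
weights, indices = router(tokens)
print(weights.shape, indices.shape)        # torch.Size([16, 2]) torch.Size([16, 2])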

Expert Feed-Forward Structure

Individual expert networks typically use the SwiGLU pattern: a gate projection passed through SiLU, an elementwise product with a parallel up-projection, and a down-projection that restores the model dimensionality.

class SwiGLUExpert(nn.Module):
    def __init__(self, input_dim: int, intermediate_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.up_proj = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.down_proj = nn.Linear(intermediate_dim, input_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
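
The gate path (SiLU of gate_proj) modulates the parallel up-projection elementwise before the down-projection maps back to the model width. A quick shape check with hypothetical dimensions:

expert = SwiGLUExpert(input_dim=512, intermediate_dim=1408)
out = expert(torch.randn(16, 512))
print(out.shape)                           # torch.Size([16, 512]) -- width restored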

Integrated MoE Module

The complete MoE layer manages distributed expert allocation, executes forward passes for routed tokens, aggregates weighted outputs, and optionally incorporates shared dense pathways.

class SparseMoELayer(nn.Module):
    def __init__(self, config, device_rank: int, world_size: int):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.num_total_experts = config.num_experts
        self.world_size = world_size
        self.local_count = config.num_experts // world_size
        self.active_count = config.active_experts

        self.start_idx = device_rank * self.local_count
        self.end_idx = self.start_idx + self.local_count

        self.router = ExpertRouter(
            hidden_dim=config.hidden_dim,
            total_experts=config.num_experts,
            active_k=config.active_experts,
            num_groups=config.num_groups,
            group_limit=config.group_limit,
            scale=config.routing_scale
        )

        # Initialize local experts only; remote experts handled via dispatch
        self.local_experts = nn.ModuleList([
            SwiGLUExpert(config.hidden_dim, config.moe_inter_dim)
            for _ in range(self.local_count)
        ])

        self.shared_pathway = nn.Sequential(
            nn.Linear(config.hidden_dim, config.moe_inter_dim, bias=False),
            nn.SiLU(),
            nn.Linear(config.moe_inter_dim, config.hidden_dim, bias=False)
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        orig_shape = inputs.shape
        flat_tokens = inputs.view(-1, self.hidden_dim)

        routing_weights, expert_indices = self.router(flat_tokens)

        # Buffer for aggregated outputs
        aggregated_output = torch.zeros_like(flat_tokens)
        token_counts = torch.bincount(expert_indices.view(-1), minlength=self.num_total_experts)

        for local_rank in range(self.start_idx, self.end_idx):
            if token_counts[local_rank] == 0:
                continue

            expert_module = self.local_experts[local_rank - self.start_idx]
            mask_rows, mask_cols = torch.where(expert_indices == local_rank)

            expert_out = expert_module(flat_tokens[mask_rows])
            weighted_out = expert_out * routing_weights[mask_rows, mask_cols, None]
            aggregated_output.index_add_(0, mask_rows, weighted_out)

        # Shared dense pathway contribution
        shared_out = self.shared_pathway(flat_tokens)

        # Cross-node synchronization if distributed
        if self.world_size > 1:
            import torch.distributed as dist
            dist.all_reduce(aggregated_output)

        return (aggregated_output + shared_out).view(orig_shape)
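
A single-process sketch of the full layer (world_size=1, so the all-reduce never fires); the config fields simply mirror what the constructor reads, with hypothetical values:

from types import SimpleNamespace

cfg = SimpleNamespace(
    hidden_dim=512, num_experts=8, active_experts=2,
    num_groups=1, group_limit=0, routing_scale=1.0,
    moe_inter_dim=1408,
)
layer = SparseMoELayer(cfg, device_rank=0, world_size=1)
out = layer(torch.randn(2, 10, 512))       # (batch, seq, hidden)
print(out.shape)                           # torch.Size([2, 10, 512])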

Expert Load Balancing

Sparse routing is prone to routing collapse, where tokens disproportionately favor a few experts and the rest receive too little signal to train well. To maintain a uniform computational distribution, auxiliary loss terms are introduced during training:

  1. Importance Variance Loss: Measures the squared coefficient of variation across expert importance scores. Minimizing this encourages uniform routing probability distribution.
  2. Actual Load Loss: Evaluates the real token assignment counts per expert across a batch. Because hard counts are not differentiable, stochastic noise is injected into the routing scores and a smooth estimate of each expert's load is used instead, which also prevents deterministic collapse. The combined objective weights both variance terms equally, as sketched after this list.
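
As a concrete illustration, here is a minimal sketch of the squared coefficient of variation applied to both metrics. Note that the hard assignment counts shown for the load term are not differentiable; published formulations substitute a smooth estimate derived from noisy gating, which this sketch omits:

def cv_squared(values: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    # Squared coefficient of variation: var / mean^2, zero when perfectly uniform
    values = values.float()
    return values.var() / (values.mean() ** 2 + eps)

router_probs = F.softmax(torch.randn(64, 8), dim=-1)      # (tokens, experts), dummy scores
importance = router_probs.sum(dim=0)                      # per-expert total routing probability
load = torch.bincount(router_probs.argmax(dim=-1), minlength=8).float()  # hard top-1 counts

aux_loss = cv_squared(importance) + cv_squared(load)      # equal weighting of both terms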

Alternatively, capacity constraints enforce a hard ceiling on tokens per expert. The threshold is typically defined as: (tokens_in_batch / num_experts) × capacity_factor. A factor between 1.0 and 1.25 provides flexibility while preventing severe overflow. Excess tokens bypass the MoE pathway via residual connections.
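
In code, the capacity computation and overflow rule reduce to a few lines (illustrative values):

tokens_in_batch, num_experts, capacity_factor = 4096, 8, 1.25
expert_capacity = int(tokens_in_batch / num_experts * capacity_factor)   # 640 slots per expert

# Tokens routed to an expert that is already full are dropped from the MoE path;
# the residual connection carries their hidden states forward unchanged.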

Key-Value Caching in Autoregressive Inference

Decoder-only large language models generate sequences token-by-token in an autoregressive manner. At step t, the model attends over the entire prefix x_1, ..., x_t to predict x_{t+1}. Without optimization, each generation step recomputes attention over all preceding tokens, costing O(t²) at step t and O(N²) by the final step of a length-N sequence.

The key observation is that at position i, the query vector Q_i interacts with all prior keys and values, and those historical K and V matrices never change once computed. Storing them eliminates redundant projections and matrix multiplications, trading GPU memory for a substantial latency reduction.

Example implementation of a lightweight caching mechanism:

class AutoregressiveCache:
    def __init__(self, max_seq_len: int = 2048, device: torch.device = torch.device("cpu")):
        self.device = device
        self.cache_capacity = max_seq_len  # advisory in this sketch; a production cache would pre-allocate
        self.k_buffer: torch.Tensor | None = None
        self.v_buffer: torch.Tensor | None = None

    def append(self, k_states: torch.Tensor, v_states: torch.Tensor):
        # Move incoming states to the cache device for consistent concatenation
        k_states = k_states.to(self.device)
        v_states = v_states.to(self.device)
        if self.k_buffer is None:
            self.k_buffer = k_states
            self.v_buffer = v_states
        else:
            # Grow along the sequence dimension (batch_first layout)
            self.k_buffer = torch.cat([self.k_buffer, k_states], dim=1)
            self.v_buffer = torch.cat([self.v_buffer, v_states], dim=1)

    def fetch(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.k_buffer, self.v_buffer

    def clear(self):
        self.k_buffer = self.v_buffer = None

class CachedTransformerLayer(nn.Module):
    def __init__(self, embed_dim: int, n_heads: int, vocab_size: int):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, embed_dim)
        self.mha = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.proj_head = nn.Linear(embed_dim, vocab_size)
        self.cache = AutoregressiveCache()

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embedded = self.tok_embed(input_ids)

        # Append the current states before attending so the new token can see
        # itself as well as the cached history. For simplicity this caches
        # pre-projection embeddings; nn.MultiheadAttention re-projects them on
        # every call. Assumes one new token per call after the prompt, so no
        # causal mask is needed.
        self.cache.append(embedded, embedded)
        cached_k, cached_v = self.cache.fetch()

        attn_out, _ = self.mha(embedded, cached_k, cached_v)

        residual = embedded + attn_out
        return self.proj_head(residual)
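
A minimal greedy decoding loop over this toy layer (hypothetical vocabulary size; one token per step, so each call only processes the newest position):

layer = CachedTransformerLayer(embed_dim=64, n_heads=4, vocab_size=100)
layer.eval()

tokens = torch.tensor([[5]])               # start token, shape (batch=1, seq=1)
with torch.no_grad():
    for _ in range(10):
        logits = layer(tokens[:, -1:])     # feed only the newest token
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_id], dim=1)

layer.cache.clear()                        # reset between sequences
print(tokens)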

Modern frameworks abstract this mechanism behind unified configuration flags. For instance, enabling use_cache=True in Hugging Face Transformers automatically allocates and manages the buffer during generation:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

repo_id = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.float16).to("cuda")

prompt_ids = tokenizer("Analyzing sequence efficiency:", return_tensors="pt").to("cuda")
# Cache-enabled generation
gen_out_cached = model.generate(**prompt_ids, max_new_tokens=50, use_cache=True, do_sample=False)
print(tokenizer.decode(gen_out_cached[0], skip_special_tokens=True))

# Generation without caching (recomputes KV at every step)
gen_out_fresh = model.generate(**prompt_ids, max_new_tokens=50, use_cache=False, do_sample=False)
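
To observe the difference empirically, a rough wall-clock comparison (absolute timings depend entirely on hardware, prompt length, and generation length):

import time

def timed_generate(use_cache: bool) -> float:
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**prompt_ids, max_new_tokens=50, use_cache=use_cache, do_sample=False)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print(f"with cache:    {timed_generate(True):.2f}s")
print(f"without cache: {timed_generate(False):.2f}s")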

