Fading Coder

One Final Commit for the Last Sprint


ThunderKittens: A Minimal CUDA DSL for 30% H100 Performance Gain Over FlashAttention-2


AI’s rapid advancement brings massive computational demands, driving the need to shrink its compute footprint and squeeze more efficiency out of existing hardware. Stanford researchers addressed this challenge by developing ThunderKittens, a compact DSL embedded in CUDA for writing high-performance deep learning kernels.

H100 SXM GPUs feature 80 GB HBM3 (3 TB/s), 50 MB L2 cache split into dual 25 MB banks (12 TB/s via crossbar), and 132 streaming multiprocessors (SMs). Extracting peak performance from these SMs requires navigating several hardware quirks:

  • WGMMA (Warp Group Matrix Multiply Accumulate) instructions are mandatory for peak performance but asynchronous and complex, requiring coordination across an entire 128-thread warp group (four warps).
  • Shared memory has a nominal ~30-cycle latency, but its 32 memory banks cause severe conflicts if data isn’t properly interleaved (a plain-CUDA illustration follows this list).
  • Address generation for memory access consumes significant chip resources, but NVIDIA’s Tensor Memory Accelerator (TMA) mitigates this by handling multidimensional tensor layouts.
  • Occupancy remains useful but less critical than register pressure.
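
To make the bank-conflict point concrete, here is a textbook plain-CUDA sketch (not ThunderKittens code): a shared-memory matrix transpose whose 32x32 tile is padded by one element per row, so that column-wise reads hit 32 different banks instead of the same one. ThunderKittens avoids the same conflicts with swizzled shared-tile layouts rather than padding. The kernel assumes a square matrix whose side is a multiple of 32 and a 32x32 thread block.

// Textbook bank-conflict avoidance: pad each shared-memory row by one element
// so that column accesses during the transpose fall into distinct banks.
__global__ void transpose_padded(const float* __restrict__ in, float* __restrict__ out, int width) {
    __shared__ float tile[32][32 + 1];                    // +1 column keeps column reads conflict-free

    int x = blockIdx.x * 32 + threadIdx.x;
    int y = blockIdx.y * 32 + threadIdx.y;
    tile[threadIdx.y][threadIdx.x] = in[y * width + x];   // coalesced global read
    __syncthreads();

    x = blockIdx.y * 32 + threadIdx.x;                    // swap block coordinates for the output tile
    y = blockIdx.x * 32 + threadIdx.y;
    out[y * width + x] = tile[threadIdx.x][threadIdx.y];  // conflict-free shared-memory read
}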

ThunderKittens simplifies kernel development through four core tile/vector abstractions:

  • Register Tiles/Vectors: 2D/1D tensors in register files, parameterized by dimensions and layout.
  • Shared Tiles/Vectors: 2D/1D tensors in shared memory, designed with bank-conflict-avoiding interleaving.

Key operations include initialization (zeroing, filling with negative infinity), unary maps (exp), binary maps (mul), and reductions (row_max, row_sum); a short sketch of how they compose follows.
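
The sketch below shows the flavor of these primitives: a row-wise softmax over a single register tile, which is exactly the pattern the attention kernel’s inner loop builds on. It reuses the tk:: names that appear in the kernel further down; the wrapper function and exact types/signatures are illustrative assumptions, not the library’s verbatim API.

// Illustrative only: a row-wise softmax over one 16x16 register tile, composed
// from the initialization, unary, binary, and reduction primitives named above.
__device__ void softmax_rows_sketch(rt_fl<1, 1>& scores) {
    typename rt_fl<1, 1>::col_vec row_maxes, row_sums;

    tk::neg_infinity(row_maxes);               // initialization: fill with -inf
    tk::row_max(row_maxes, scores, row_maxes); // reduction: per-row maximum
    tk::sub_row(scores, scores, row_maxes);    // binary: subtract the max from each row
    tk::exp(scores, scores);                   // unary: exponentiate in place

    tk::zero(row_sums);                        // initialization: zero the accumulator
    tk::row_sum(row_sums, scores, row_sums);   // reduction: per-row sum of exponentials
    tk::div_row(scores, scores, row_sums);     // binary: normalize each row to sum to 1
}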

RTX 4090 FlashAttention Kernel (58 Lines)

#define WORKER_COUNT 16
using namespace tk;

__global__ void tk_flash_attend_64(int seq_len, const bf16* __restrict__ q_ptr, const bf16* __restrict__ k_ptr, const bf16* __restrict__ v_ptr, bf16* __restrict__ o_ptr) {
    auto warp_id = tk::get_warp_id();
    auto block_start = blockIdx.x * seq_len * 64;
    const bf16* q = q_ptr + block_start;
    const bf16* k = k_ptr + block_start;
    const bf16* v = v_ptr + block_start;
    bf16* o = o_ptr + block_start;

    extern __shared__ alignment_pad __shm[]; // dynamic shared memory backing the tile allocator
    tk_shared_allocator al(reinterpret_cast<int*>(&__shm[0]));

    // K and V tiles live in shared memory, one tile per worker warp.
    st_bf<1, 4, tk_layout::swizzle> (&k_shared)[WORKER_COUNT] = al.allocate<st_bf<1, 4, tk_layout::swizzle>, WORKER_COUNT>();
    st_bf<1, 4, tk_layout::swizzle> (&v_shared)[WORKER_COUNT] = al.allocate<st_bf<1, 4, tk_layout::swizzle>, WORKER_COUNT>();

    rt_bf<1, 4> q_reg, k_reg, v_reg;                   // 16x64 register tiles for Q, K, V
    rt_fl<1, 1> att_tile;                              // 16x16 attention scores in fp32
    rt_bf<1, 1> att_mma;                               // bf16 copy of the scores for the mma
    rt_fl<1, 4> o_reg;                                 // output accumulator
    typename rt_fl<1, 1>::col_vec max_prev, max_curr;  // running row maxima for the online softmax
    typename rt_fl<1, 1>::col_vec norm_prev, norm_curr; // running row normalizers

    int qo_block_count = seq_len / (q_reg.rows * WORKER_COUNT);
    int kv_block_count = qo_block_count;

    for (int q_blk = 0; q_blk < qo_block_count; ++q_blk) {
        // Each warp loads its own 16x64 Q tile and pre-scales it by 1/sqrt(d) = 0.125.
        tk::load(q_reg, q + (q_blk * WORKER_COUNT + warp_id) * q_reg.num_elements, q_reg.cols);
        tk::mul(q_reg, q_reg, __float2bfloat16(0.125f));

        // Reset the online-softmax state for this block of queries.
        tk::neg_infinity(max_curr);
        tk::zero(norm_curr);
        tk::zero(o_reg);

        for (int kv_idx = 0; kv_idx < kv_block_count; ++kv_idx) {
            // Each warp stages its own K and V chunk into shared memory, then the block syncs.
            tk::load(v_shared[warp_id], v + (kv_idx * WORKER_COUNT + warp_id) * q_reg.num_elements, q_reg.cols);
            tk::load(k_shared[warp_id], k + (kv_idx * WORKER_COUNT + warp_id) * q_reg.num_elements, q_reg.cols);
            __syncthreads();

            for (int subtile = 0; subtile < WORKER_COUNT; ++subtile) {
                tk::load(k_reg, k_shared[subtile]);            // pull one K tile from shared into registers
                tk::zero(att_tile);
                tk::mma_ABt(att_tile, q_reg, k_reg, att_tile); // att = q @ k^T

                tk::copy(norm_prev, norm_curr);
                tk::copy(max_prev, max_curr);

                tk::row_max(max_curr, att_tile, max_curr); // update the running row max
                tk::sub_row(att_tile, att_tile, max_curr); // subtract it so every entry is <= 0
                tk::exp(att_tile, att_tile);               // exponentiate in place

                tk::sub(max_prev, max_prev, max_curr);    // old max minus new max...
                tk::exp(max_prev, max_prev);              // ...exponentiated: the factor that rescales the running sum
                tk::mul(norm_curr, norm_curr, max_prev);

                tk::row_sum(norm_curr, att_tile, norm_curr); // fold this block into the running normalizer
                tk::div_row(att_tile, att_tile, norm_curr);  // normalize the attention block

                tk::mul(norm_prev, norm_prev, max_prev);  // rescale the previous normalizer by the same factor...
                tk::div(norm_prev, norm_prev, norm_curr); // ...and divide by the new one: the correction applied to o_reg

                tk::copy(att_mma, att_tile); // downcast to bf16 for the tensor-core mma

                tk::load(v_reg, v_shared[subtile]);
                rt_bf<1, 4, tk_layout::col>& v_col = tk::swap_layout_inplace(v_reg); // v_reg is invalid after this call

                tk::mul_row(o_reg, o_reg, norm_prev);     // rescale the partial output...
                tk::mma_AB(o_reg, att_mma, v_col, o_reg); // ...then accumulate att @ V onto it
            }
            __syncthreads();
        }
        tk::store(o + (q_blk * WORKER_COUNT + warp_id) * q_reg.num_elements, o_reg, q_reg.cols); // write this warp's output tile
    }
}
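
For completeness, here is a hedged host-side sketch of how a kernel like this might be launched, assuming it lives in the same translation unit as the code above: one thread block per batch-head, 16 worker warps per block, and a dynamic shared-memory budget sized for the K and V tile arrays. The grid shape, the shared-memory formula, and the use of cudaFuncSetAttribute are assumptions for illustration; the actual ThunderKittens harness may differ.

// Hypothetical launch wrapper; block/grid shape and shared-memory sizing are assumptions.
void launch_tk_flash_attend_64(int batch_heads, int seq_len,
                               const bf16* q, const bf16* k, const bf16* v, bf16* o) {
    constexpr int kThreadsPerBlock = WORKER_COUNT * 32;   // one warp per worker

    // Two arrays of WORKER_COUNT 16x64 bf16 tiles (K and V) live in shared memory.
    size_t smem_bytes = 2ull * WORKER_COUNT * 16 * 64 * sizeof(bf16);  // 64 KB with these dimensions

    // Opt in to a dynamic shared-memory carveout above the default 48 KB limit.
    cudaFuncSetAttribute(tk_flash_attend_64,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         static_cast<int>(smem_bytes));

    // One block per (batch, head); each block walks the whole sequence for its head.
    tk_flash_attend_64<<<batch_heads, kThreadsPerBlock, smem_bytes>>>(seq_len, q, k, v, o);
}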
