ThunderKittens: A Minimal CUDA DSL for 30% H100 Performance Gain Over FlashAttention-2
AI’s rapid advancement brings massive computational demands, driving the need to shrink its compute footprint and squeeze more efficiency out of existing hardware. Stanford researchers addressed this challenge by developing ThunderKittens, a compact CUDA-embedded DSL for writing high-performance deep learning kernels.
H100 SXM GPUs feature 80 GB HBM3 (3 TB/s), 50 MB L2 cache split into dual 25 MB banks (12 TB/s via crossbar), and 132 streaming multiprocessors (SMs). Extracting peak performance from these SMs requires navigating several hardware quirks:
- WGMMA (Warp Group Matrix Multiply Accumulate) instructions are mandatory for peak performance, but they are asynchronous and complex, requiring coordination across the 128 threads of a warp group.
- Shared memory has a nominal latency of about 30 cycles, but its 32 banks cause severe conflicts when data isn’t properly interleaved (see the swizzle sketch after this list).
- Address generation for memory access consumes significant chip resources, but NVIDIA’s Tensor Memory Accelerator (TMA) mitigates this by handling multidimensional tensor layouts.
- Occupancy remains useful, but it is less critical than managing register pressure.
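For intuition on the bank-conflict point, here is a minimal, hedged sketch of the XOR-style swizzling idea that ThunderKittens’ shared tiles bake into their layouts; the exact swizzle the library uses differs and is chosen to match tensor-core and TMA layouts:

    // Hedged sketch: XOR-swizzled indexing for a 64-wide bf16 tile in shared
    // memory. Illustrates the general idea only; not ThunderKittens' actual code.
    __device__ __forceinline__ int swizzled_index(int row, int col) {
        // The 32 banks each serve 4 bytes per cycle. XORing the 16-byte column
        // group with the low row bits permutes which bank a thread hits, so a
        // warp walking down a column no longer piles onto a single bank.
        int group = col / 8;                      // 8 bf16 elements = 16 bytes
        int swizzled_group = group ^ (row % 8);   // permute group by row
        return row * 64 + swizzled_group * 8 + (col % 8);
    }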
ThunderKittens simplifies kernel development through four core tile/vector abstractions:
- Register Tiles/Vectors: 2D/1D tensors in register files, parameterized by dimensions and layout.
- Shared Tiles/Vectors: 2D/1D tensors in shared memory, designed with bank-conflict-avoiding interleaving.
Key operations include initialization (zeroing, filling with negative infinity), unary ops (exp), binary ops (mul), and reductions (row_max, row_sum).
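As a hedged illustration of how these primitives compose (identifiers follow this article’s tk naming; the shipping library may spell them differently), a numerically stable row-wise softmax over a register tile is just a handful of calls:

    // Sketch: row-wise softmax built from the operations listed above.
    __device__ void softmax_rows(rt_fl<1, 4>& tile) {
        typename rt_fl<1, 4>::col_vec row_maxes, row_sums;
        tk::neg_infinity(row_maxes);              // running max starts at -inf
        tk::row_max(row_maxes, tile, row_maxes);  // per-row maximum
        tk::sub_row(tile, tile, row_maxes);       // subtract max for stability
        tk::exp(tile, tile);                      // elementwise exponential
        tk::zero(row_sums);
        tk::row_sum(row_sums, tile, row_sums);    // per-row normalizer
        tk::div_row(tile, tile, row_sums);        // normalize each row
    }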
RTX 4090 FlashAttention Kernel (58 Lines)
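The kernel streams K/V tiles past each query tile and folds every tile into running softmax statistics. The rescaling it performs is the standard online-softmax recurrence, sketched here for reference, where m is the running row max, \ell the running row sum, o the running output, and S_i = Q K_i^T for the i-th tile:

    m_i = \max(m_{i-1}, \mathrm{rowmax}(S_i))
    \ell_i = e^{m_{i-1} - m_i} \, \ell_{i-1} + \mathrm{rowsum}(e^{S_i - m_i})
    o_i = \frac{e^{m_{i-1} - m_i} \, \ell_{i-1}}{\ell_i} \, o_{i-1} + \frac{e^{S_i - m_i}}{\ell_i} \, V_i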
#include "tk.cuh"  // ThunderKittens header; exact filename assumed, the article never shows the include
#define WORKER_COUNT 16
using namespace tk;

__global__ void tk_flash_attend_64(int seq_len,
                                   const bf16* __restrict__ q_ptr, const bf16* __restrict__ k_ptr,
                                   const bf16* __restrict__ v_ptr, bf16* __restrict__ o_ptr) {
    auto warp_id = tk::get_warp_id();
    auto block_start = blockIdx.x * seq_len * 64;  // one (batch, head) slice per block; head dim = 64
    const bf16* q = q_ptr + block_start;
    const bf16* k = k_ptr + block_start;
    const bf16* v = v_ptr + block_start;
    bf16* o = o_ptr + block_start;

    // Carve the K and V tile arrays out of dynamic shared memory.
    extern __shared__ alignment_pad __shm[];
    tk_shared_allocator al(reinterpret_cast<int*>(&__shm[0]));
    st_bf<1, 4, tk_layout::swizzle> (&k_shared)[WORKER_COUNT] = al.allocate<st_bf<1, 4, tk_layout::swizzle>, WORKER_COUNT>();
    st_bf<1, 4, tk_layout::swizzle> (&v_shared)[WORKER_COUNT] = al.allocate<st_bf<1, 4, tk_layout::swizzle>, WORKER_COUNT>();

    rt_bf<1, 4> q_reg, k_reg, v_reg;
    rt_fl<1, 1> att_tile;   // attention scores, accumulated in fp32
    rt_bf<1, 1> att_mma;    // scores downcast to bf16 for the MMA
    rt_fl<1, 4> o_reg;      // output accumulator
    typename rt_fl<1, 1>::col_vec max_prev, max_curr;    // running row-wise max
    typename rt_fl<1, 1>::col_vec norm_prev, norm_curr;  // running row-wise sum

    int qo_block_count = seq_len / (q_reg.rows * WORKER_COUNT);
    int kv_block_count = qo_block_count;

    for (int q_blk = 0; q_blk < qo_block_count; ++q_blk) {
        tk::load(q_reg, q + (q_blk * WORKER_COUNT + warp_id) * q_reg.num_elements, q_reg.cols);
        tk::mul(q_reg, q_reg, __float2bfloat16(0.125f));  // pre-scale Q by 1/sqrt(64)
        tk::neg_infinity(max_curr);
        tk::zero(norm_curr);
        tk::zero(o_reg);
        for (int kv_idx = 0; kv_idx < kv_block_count; ++kv_idx) {
            // Each warp stages one K tile and one V tile into shared memory.
            tk::load(v_shared[warp_id], v + (kv_idx * WORKER_COUNT + warp_id) * q_reg.num_elements, q_reg.cols);
            tk::load(k_shared[warp_id], k + (kv_idx * WORKER_COUNT + warp_id) * q_reg.num_elements, q_reg.cols);
            __syncthreads();
            for (int subtile = 0; subtile < WORKER_COUNT; ++subtile) {
                tk::load(k_reg, k_shared[subtile]);
                tk::zero(att_tile);
                tk::mma_ABt(att_tile, q_reg, k_reg, att_tile);  // S = Q K^T
                // Online softmax: fold this tile into the running statistics.
                tk::copy(norm_prev, norm_curr);
                tk::copy(max_prev, max_curr);
                tk::row_max(max_curr, att_tile, max_curr);
                tk::sub_row(att_tile, att_tile, max_curr);
                tk::exp(att_tile, att_tile);
                tk::sub(max_prev, max_prev, max_curr);
                tk::exp(max_prev, max_prev);             // max_prev now holds e^(m_prev - m_curr)
                tk::mul(norm_curr, norm_curr, max_prev);
                tk::row_sum(norm_curr, att_tile, norm_curr);
                tk::div_row(att_tile, att_tile, norm_curr);
                // Turn norm_prev into the rescale factor for the old output:
                // norm_prev * e^(m_prev - m_curr) / norm_curr.
                tk::mul(norm_prev, norm_prev, max_prev);
                tk::div(norm_prev, norm_prev, norm_curr);
                tk::copy(att_mma, att_tile);             // downcast P to bf16
                tk::load(v_reg, v_shared[subtile]);      // load V before swapping its layout
                rt_bf<1, 4, tk_layout::col>& v_col = tk::swap_layout_inplace(v_reg);  // MMA expects V in column layout
                tk::mul_row(o_reg, o_reg, norm_prev);    // rescale previously accumulated output
                tk::mma_AB(o_reg, att_mma, v_col, o_reg);  // O += P V
            }
            __syncthreads();
        }
        tk::store(o + (q_blk * WORKER_COUNT + warp_id) * q_reg.num_elements, o_reg, q_reg.cols);
    }
}
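For completeness, a hedged sketch of the host-side launch this kernel implies; the grid shape, shared-memory sizing, and helper name are assumptions for illustration, not the article’s measured configuration:

    // Hedged sketch: host-side launch, assuming one block per (batch, head) pair,
    // WORKER_COUNT warps per block, and dynamic shared memory sized for the
    // K and V tile arrays. Helper name and sizing are illustrative assumptions.
    void launch_tk_flash_attend_64(int batch_heads, int seq_len,
                                   const bf16* q, const bf16* k,
                                   const bf16* v, bf16* o) {
        constexpr int THREADS = WORKER_COUNT * 32;  // 16 warps of 32 threads
        int smem_bytes = int(2 * WORKER_COUNT * sizeof(st_bf<1, 4, tk_layout::swizzle>));
        // Dynamic shared memory above 48 KB must be opted into explicitly.
        cudaFuncSetAttribute(tk_flash_attend_64,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
        tk_flash_attend_64<<<batch_heads, THREADS, smem_bytes>>>(seq_len, q, k, v, o);
    }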