I made a kernel 2.2x faster. It made my training loop 3x slower

(kyrieblunders.bearblog.dev)

15 points | by vishal-padia 2 days ago

1 comment

  • vishal-padia 2 days ago
    Quick context on what's in the post:

    1. From-scratch Dr. GRPO implementation in ~300 lines of PyTorch (Qwen2.5-0.5B on GSM8K, on an A10G).
    2. Profiling deep dive on the training loop. Generation takes 90% of step time. Pre-allocating the KV cache via StaticCache took GPU utilization from 26% to 86%, the biggest single win in the project.
    3. Wrote a fused decode-attention kernel in CuteDSL (RoPE + KV cache write + attention in one launch). It benchmarks 2.2x faster than the SDPA path it replaces at the relevant scale.
    4. Plugged it into HF generate, and the decode step got 3x slower.

    The post is mostly about why this happened, what it took to figure out, and what would actually close the gap.
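    For context on item 1: per the paper that introduced it ("Understanding R1-Zero-Like Training"), Dr. GRPO differs from GRPO mainly in two places. The advantage is the reward minus the group mean, with no division by the group std, and the loss sums per-token terms and divides by a constant instead of averaging each completion over its own length. Below is a minimal pure-Python sketch of just those two pieces, not the author's PyTorch code. The KL term is omitted, and the constant normalizer (group size × `max_len`) and `clip_eps=0.2` are assumptions for illustration.

```python
import math
from statistics import mean


def dr_grpo_advantages(rewards: list[float]) -> list[float]:
    # Dr. GRPO: subtract the group-mean reward, but do NOT divide by the group std.
    baseline = mean(rewards)
    return [r - baseline for r in rewards]


def dr_grpo_loss(
    logp_new: list[list[float]],  # per-token log-probs under the current policy, one list per completion
    logp_old: list[list[float]],  # per-token log-probs under the policy that sampled the completions
    advantages: list[float],      # one advantage per completion
    max_len: int,                 # assumed constant normalizer: max generation length
    clip_eps: float = 0.2,        # assumed PPO-style clip range
) -> float:
    total = 0.0
    for new_seq, old_seq, adv in zip(logp_new, logp_old, advantages):
        for lp_new, lp_old in zip(new_seq, old_seq):
            ratio = math.exp(lp_new - lp_old)
            clipped = max(min(ratio, 1.0 + clip_eps), 1.0 - clip_eps)
            # Clipped surrogate, negated because the optimizer minimizes.
            total += -min(ratio * adv, clipped * adv)
    # Sum over all tokens, then divide by a constant rather than by each |o_i|,
    # so long and short completions are not reweighted by their length.
    return total / (len(advantages) * max_len)
```

    With a group whose rewards are `[1.0, 0.0, 0.0, 1.0]`, the advantages come out as `[0.5, -0.5, -0.5, 0.5]`; vanilla GRPO would additionally divide these by the group std.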
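    To make item 3 concrete, here is a pure-Python reference for what one fused decode-attention step computes: RoPE on the new query and key, a write of the new key/value into a preallocated cache slot, then single-query softmax attention over positions `0..pos`. This is a readability sketch, not the author's CuteDSL kernel. It assumes one head and one sequence, the "rotate-half" RoPE layout used by HF Llama/Qwen-style models, and a default base of 10000 (real model configs may set a different `rope_theta`). All function names are hypothetical.

```python
import math


def rope(x: list[float], pos: int, base: float = 10000.0) -> list[float]:
    # Rotary embedding, rotate-half layout: element i is paired with element i + d/2.
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out


def decode_attention_step(
    q: list[float],
    k: list[float],
    v: list[float],
    pos: int,
    k_cache: list[list[float]],  # preallocated [max_len][head_dim], like a static KV cache
    v_cache: list[list[float]],
    base: float = 10000.0,
) -> list[float]:
    # 1) RoPE on the new query and key at position `pos`.
    q_rot = rope(q, pos, base)
    k_rot = rope(k, pos, base)
    # 2) Write the new key/value into their slot in the preallocated cache.
    k_cache[pos] = k_rot
    v_cache[pos] = list(v)
    # 3) Single-query attention over positions 0..pos (causal by construction).
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q_rot, k_cache[t])) for t in range(pos + 1)]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(weights[t] * v_cache[t][j] for t in range(pos + 1)) / z for j in range(len(v))]
```

    A fused kernel does all three stages in one launch, while an unfused path typically runs RoPE, the cache update, and attention as separate operations, each reading and writing its inputs from GPU memory.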
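    Because generation is ~90% of step time, a change in decode speed passes almost directly through to the whole training step. A rough Amdahl-style check, assuming (the thread does not state this) that the entire generate phase slows by the same 3x as the decode step while everything else is unchanged:

```python
def step_time_factor(frac: float, factor: float) -> float:
    # Relative step time when a phase taking `frac` of the step is scaled by
    # `factor` (>1 means slower) and the remaining (1 - frac) is unchanged.
    return (1.0 - frac) + frac * factor


slowdown = step_time_factor(0.9, 3.0)  # generate slowed 3x
best_case_outside = 1.0 / step_time_factor(0.1, 0.0)  # non-generate work made free
print(f"step slowdown: {slowdown:.2f}x")
print(f"max speedup from optimizing everything outside generate: {best_case_outside:.2f}x")
```

    Under that assumption the step ends up about 2.8x slower, and the same arithmetic shows why optimizing anything outside generate could buy at most about 1.11x.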

    Happy to answer questions.