FlashAttention (FA1–FA4) in PyTorch - educational implementations focused on algorithmic differences [P]

Reddit r/MachineLearning / 4/12/2026


Key Points

  • The author updated a FlashAttention-PyTorch repository to include educational (non-optimized, plain-PyTorch) implementations of FlashAttention variants FA1 through FA4.
  • The repo is designed to help readers understand how the orchestration of the same attention math evolves from FA1’s tiled online softmax baseline to FA4’s phased scheduler (main/softmax/correction) with selective rescaling.
  • Each version highlights specific algorithmic changes, including query-tile ownership and deferred normalization (FA2), staged pipelining with ping-pong buffers and a simplified educational FP8 forward path (FA3), and conditional rescaling mechanics (FA4).
  • The implementations explicitly avoid being hardware-faithful reproductions of official CUDA/Hopper/Blackwell kernels, focusing instead on exposing design ideas via readable code.
  • The author invites feedback on whether the code makes version-to-version differences intuitive, and links to the GitHub repo for review.
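Here is a minimal sketch of the "tiled online softmax baseline". It is not code from the repo, and the function name and block sizes are hypothetical. It assumes a single head, no batch dimension and no masking. It follows the FA1 loop order: K/V tiles in the outer loop and Q tiles in the inner loop. Full-size `o`, `m` and `l` buffers stand in for HBM, and the output is renormalized after every tile.

```python
import torch

def fa1_style_attention(q, k, v, block_q=16, block_kv=16):
    """FA1-style tiled online softmax for one head (q, k: (N, d), v: (N, d_v)).

    Hypothetical illustration: outer loop over K/V tiles, inner loop over Q tiles.
    o, m, l are full-size buffers (stand-ins for HBM) and o is renormalized
    after every K/V tile.
    """
    n, d = q.shape
    scale = d ** -0.5
    o = torch.zeros((n, v.shape[1]), dtype=q.dtype)
    m = torch.full((n, 1), float("-inf"), dtype=q.dtype)  # running row max
    l = torch.zeros((n, 1), dtype=q.dtype)                # running row sum of exp
    for j in range(0, k.shape[0], block_kv):
        kj, vj = k[j:j + block_kv], v[j:j + block_kv]
        for i in range(0, n, block_q):
            rows = slice(i, i + block_q)
            s = (q[rows] @ kj.T) * scale                  # scores for this (Q, K/V) tile pair
            m_tile = s.max(dim=-1, keepdim=True).values
            p = torch.exp(s - m_tile)                     # tile-local exponentials
            l_tile = p.sum(dim=-1, keepdim=True)
            m_new = torch.maximum(m[rows], m_tile)
            alpha = torch.exp(m[rows] - m_new)            # rescales the old statistics
            beta = torch.exp(m_tile - m_new)              # rescales this tile's statistics
            l_new = alpha * l[rows] + beta * l_tile
            # Undo the old normalization, add this tile, renormalize right away.
            o[rows] = (alpha * l[rows] * o[rows] + beta * (p @ vj)) / l_new
            m[rows], l[rows] = m_new, l_new
    return o
```

Compared with `torch.softmax((q @ k.T) * d ** -0.5, dim=-1) @ v` via `torch.allclose`, the result should agree up to floating-point rounding.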

I recently updated my FlashAttention-PyTorch repo so it now includes educational implementations of FA1, FA2, FA3, and FA4 in plain PyTorch.

The main goal is to make the progression across versions easier to understand from code.

This is not meant to be an optimized kernel repo, and it is not a hardware-faithful recreation of the official implementations. The point is to expose the algorithmic ideas and design changes without immediately going deep into CUDA/Hopper/Blackwell-specific details.

Roughly, the repo now shows:

  • FA1: tiled online softmax baseline
  • FA2: split-Q / query-tile ownership, deferred normalization
  • FA3: explicit staged pipeline with ping-pong tile buffers, plus a simplified educational FP8 forward path
  • FA4: explicit scheduler with main / softmax / correction phases, and conditional/selective rescaling
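Here is a minimal sketch of the FA2 bullet. It uses hypothetical names, a single head and no masking, and it is not the repo's code. The loop order flips: each outer iteration owns one Q tile and streams over every K/V tile. The running max `m_i`, running sum `l_i` and an unnormalized accumulator stay local, and the division by `l_i` happens only once at the end.

```python
import torch

def fa2_style_attention(q, k, v, block_q=16, block_kv=16):
    """FA2-style tiled attention for one head (q, k: (N, d), v: (N, d_v)).

    Hypothetical illustration: outer loop over Q tiles (each tile is "owned" by
    one iteration, i.e. one worker in a real kernel), inner loop over K/V tiles.
    m, l and the accumulator stay local and are never divided mid-loop.
    """
    n, d = q.shape
    scale = d ** -0.5
    out = torch.empty((n, v.shape[1]), dtype=q.dtype)
    for i in range(0, n, block_q):
        qi = q[i:i + block_q]
        m_i = torch.full((qi.shape[0], 1), float("-inf"), dtype=q.dtype)
        l_i = torch.zeros((qi.shape[0], 1), dtype=q.dtype)
        acc = torch.zeros((qi.shape[0], v.shape[1]), dtype=q.dtype)  # unnormalized output
        for j in range(0, k.shape[0], block_kv):
            s = (qi @ k[j:j + block_kv].T) * scale
            m_new = torch.maximum(m_i, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)
            alpha = torch.exp(m_i - m_new)          # rescale factor for the old state
            l_i = alpha * l_i + p.sum(dim=-1, keepdim=True)
            acc = alpha * acc + p @ v[j:j + block_kv]
            m_i = m_new
        out[i:i + block_q] = acc / l_i              # deferred normalization: once per Q tile
    return out
```

Owning a Q tile is what lets FA2 kernels parallelize the forward pass over the sequence dimension. Deferring the division removes the per-tile division by the running sum.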

So the attention math stays exactly the same, but the orchestration changes from version to version.
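Because only the orchestration changes, a staging change like FA3's ping-pong buffering can be shown without touching the math. The sketch below reads "ping-pong tile buffers" as double-buffered K/V tiles. That reading is an assumption: the official FA3 kernels also use ping-pong scheduling between warpgroups, which is not modeled here. Sequential PyTorch overlaps nothing, so the sketch only illustrates the buffer bookkeeping. Names are hypothetical, and FP8 is not covered.

```python
import torch

def fa3_style_attention(q, k, v, block_q=16, block_kv=16):
    """Online-softmax attention with an explicit two-buffer ("ping-pong") K/V staging loop.

    Hypothetical illustration: while tile t is consumed from one buffer, tile
    t+1 is loaded into the other. Nothing actually runs concurrently here.
    """
    n, d = q.shape
    n_kv, d_v = k.shape[0], v.shape[1]
    scale = d ** -0.5
    out = torch.empty((n, d_v), dtype=q.dtype)
    # Hypothetical stand-ins for two on-chip tile buffers per operand.
    k_buf = [torch.empty((block_kv, d), dtype=k.dtype) for _ in range(2)]
    v_buf = [torch.empty((block_kv, d_v), dtype=v.dtype) for _ in range(2)]
    filled = [0, 0]                                 # valid rows currently held by each buffer
    starts = list(range(0, n_kv, block_kv))

    def load(buf, j):                               # "producer" stage: copy one K/V tile into a buffer
        r = min(block_kv, n_kv - j)
        k_buf[buf][:r] = k[j:j + r]
        v_buf[buf][:r] = v[j:j + r]
        filled[buf] = r

    for i in range(0, n, block_q):
        qi = q[i:i + block_q]
        m_i = torch.full((qi.shape[0], 1), float("-inf"), dtype=q.dtype)
        l_i = torch.zeros((qi.shape[0], 1), dtype=q.dtype)
        acc = torch.zeros((qi.shape[0], d_v), dtype=q.dtype)
        load(0, starts[0])                          # prologue: fill the "ping" buffer
        for t in range(len(starts)):
            cur = t % 2
            if t + 1 < len(starts):
                load(1 - cur, starts[t + 1])        # prefetch the next tile into the idle buffer
            r = filled[cur]
            kj, vj = k_buf[cur][:r], v_buf[cur][:r]
            # "Consumer" stage: standard online-softmax update with deferred normalization.
            s = (qi @ kj.T) * scale
            m_new = torch.maximum(m_i, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)
            alpha = torch.exp(m_i - m_new)
            l_i = alpha * l_i + p.sum(dim=-1, keepdim=True)
            acc = alpha * acc + p @ vj
            m_i = m_new
        out[i:i + block_q] = acc / l_i
    return out
```

The prefetch always writes into the buffer whose tile was consumed in the previous step, so the tile being read is never overwritten. On real hardware, that invariant is what lets the load and the compute overlap.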

I wrote it for people who want to understand:

"What actually changed from FA1 → FA2 → FA3 → FA4?"

without having to start from highly optimized CUDA kernels.
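Here is a minimal sketch of one concrete answer for the FA4 step: conditional rescaling. Each row keeps a possibly stale reference max. It moves that reference, and rescales the accumulator, only when a tile's max exceeds it by more than a threshold. The threshold value, the names, and the mapping of the rescale branch to a "correction" phase are illustrative assumptions, not the repo's or the official kernel's code.

```python
import torch

def fa4_style_attention(q, k, v, block_q=16, block_kv=16, threshold=8.0):
    """Online softmax with conditional (selective) rescaling, one head, no masking.

    Hypothetical illustration: m_ref is the exponent offset per row. It only
    moves when a tile's row max exceeds it by more than `threshold`; otherwise
    the stale reference is kept and exp(s - m_ref) can reach exp(threshold).
    acc and l share the same reference, so acc / l is unchanged.
    """
    n, d = q.shape
    scale = d ** -0.5
    out = torch.empty((n, v.shape[1]), dtype=q.dtype)
    for i in range(0, n, block_q):
        qi = q[i:i + block_q]
        m_ref = torch.full((qi.shape[0], 1), float("-inf"), dtype=q.dtype)
        l_i = torch.zeros((qi.shape[0], 1), dtype=q.dtype)
        acc = torch.zeros((qi.shape[0], v.shape[1]), dtype=q.dtype)
        for j in range(0, k.shape[0], block_kv):
            s = (qi @ k[j:j + block_kv].T) * scale
            m_tile = s.max(dim=-1, keepdim=True).values
            grew = m_tile > m_ref + threshold       # rows whose max jumped too far
            if grew.any():
                # "Correction" step: move the reference only for rows that need it.
                m_new = torch.where(grew, m_tile, m_ref)
                alpha = torch.exp(m_ref - m_new)    # 1.0 for rows that keep their reference
                acc, l_i, m_ref = alpha * acc, alpha * l_i, m_new
            p = torch.exp(s - m_ref)                # at most exp(threshold) for threshold >= 0
            l_i = l_i + p.sum(dim=-1, keepdim=True)
            acc = acc + p @ v[j:j + block_kv]
        out[i:i + block_q] = acc / l_i              # deferred normalization
    return out
```

Skipping a rescale is safe because the stale exponent offset scales `acc` and `l_i` equally and cancels in the final division. The cost is that intermediate values can grow up to `exp(threshold)`, which limits how large the threshold can safely be. With `threshold=0.0`, this reduces to ordinary per-tile rescaling.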

Repo: https://github.com/shreyansh26/FlashAttention-PyTorch

Would be interested in feedback on whether the code makes the version-to-version differences intuitive.

submitted by /u/shreyansh26