Ragged Paged Attention: A High-Performance and Flexible LLM Inference Kernel for TPU

arXiv cs.AI / 4/20/2026


Key Points

  • The paper introduces Ragged Paged Attention (RPA), a TPU-focused attention kernel designed for efficient LLM inference under the dynamic, "ragged" request patterns common in modern serving.
  • RPA improves performance and flexibility using fine-grained tiling for efficient dynamic slicing, a fused pipeline that combines KV-cache updates with attention computation, and a compilation strategy that generates specialized kernels for decode, prefill, and mixed workloads.
  • Experiments on Llama 3 8B running on TPU7x show strong utilization metrics, reaching up to 86% memory bandwidth utilization during decode and 73% model FLOPs utilization during prefill.
  • The work is implemented with Pallas and Mosaic and has been integrated as the primary TPU backend in vLLM and SGLang, aiming to provide a production-ready foundation for TPU inference kernel design.
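To make the "ragged" and "paged" terminology concrete, here is a minimal NumPy sketch of decode-time attention over a paged KV cache, where each sequence's keys and values live in fixed-size physical pages indexed by a per-sequence page table. All names and shapes are illustrative assumptions, not the paper's actual Pallas/Mosaic implementation:

```python
import numpy as np

def gather_kv(kv_pages, page_table, seq_len, page_size):
    """Assemble one sequence's logical KV tensor from physical pages.

    kv_pages:   [num_pages, page_size, head_dim] physical pool
    page_table: 1-D array of page ids for this sequence (logical order)
    seq_len:    number of valid tokens (may end mid-page: the "ragged" part)
    """
    n_pages = -(-seq_len // page_size)  # ceil division
    flat = kv_pages[page_table[:n_pages]].reshape(-1, kv_pages.shape[-1])
    return flat[:seq_len]

def paged_decode_attention(q, k_pages, v_pages, page_table, seq_len, page_size):
    """Single-query (decode-step) attention over a paged, ragged KV cache."""
    k = gather_kv(k_pages, page_table, seq_len, page_size)
    v = gather_kv(v_pages, page_table, seq_len, page_size)
    scores = (k @ q) / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max())  # numerically stable softmax
    w /= w.sum()
    return w @ v
```

A real kernel never materializes the gathered sequence in slow memory; the point of RPA's fine-grained tiling is to slice these pages dynamically on-chip. The sketch only shows the indexing contract such a kernel must honor.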

Abstract

Large Language Model (LLM) deployment is increasingly shifting to cost-efficient accelerators like Google's Tensor Processing Units (TPUs), prioritizing both performance and total cost of ownership (TCO). However, existing LLM inference kernels and serving systems remain largely GPU-centric, and there is no well-established approach for efficiently mapping LLM workloads onto TPU architectures, particularly under the dynamic and ragged execution patterns common in modern serving. In this paper, we present Ragged Paged Attention (RPA), a high-performance and flexible attention kernel for TPUs, implemented using Pallas and Mosaic. RPA addresses these challenges through three key techniques: (1) fine-grained tiling to enable efficient dynamic slicing over ragged memory, (2) a custom software pipeline that fuses KV cache updates with attention computation, and (3) a distribution-aware compilation strategy that generates specialized kernels for decode, prefill, and mixed workloads. Evaluated on Llama 3 8B on TPU7x, RPA achieves up to 86% memory bandwidth utilization (MBU) in decode and 73% model FLOPs utilization (MFU) in prefill. Integrated as the primary TPU backend in vLLM and SGLang, RPA provides a production-grade foundation for efficient TPU inference and offers practical insights into kernel design.
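The abstract's first technique, fine-grained tiling over ragged memory, pairs naturally with streaming the softmax page by page so that no tile-sized slice ever needs to be re-read. The following NumPy sketch shows the standard online-softmax recurrence applied over KV pages; it illustrates the general flash-attention-style page loop a TPU kernel might use, not RPA's actual pipeline, and all identifiers are assumptions:

```python
import numpy as np

def streamed_paged_decode(q, k_pages, v_pages, page_table, seq_len, page_size):
    """One query token attends over its KV pages without materializing
    the full sequence, using a running (online) softmax per page tile."""
    d = q.shape[-1]
    m = -np.inf           # running max of attention scores
    l = 0.0               # running sum of exp(score - m)
    acc = np.zeros(d)     # running weighted sum of values
    n_pages = -(-seq_len // page_size)  # ceil division
    for i in range(n_pages):
        pid = page_table[i]
        valid = min(page_size, seq_len - i * page_size)  # last page may be partial
        k = k_pages[pid, :valid]
        v = v_pages[pid, :valid]
        s = (k @ q) / np.sqrt(d)
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)   # rescale previous partials to the new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v
        m = m_new
    return acc / l
```

Because each page is touched exactly once, this loop structure is what makes fusing the KV-cache write for the newest token with the attention read plausible: the freshly written page is simply the last tile the loop visits.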