Minimising memory reads and writes between slow high bandwidth memory (HBM) and fast on-chip Static Random Access Memory (SRAM).
Standard attention is memory-bound (limited by memory bandwidth rather than compute power).
FlashAttention reduces the number of HBM accesses to O(N²d²/M), where N is sequence length, d is head dimension, and M is SRAM size.
On GPUs, performance is often limited not by how many FLOPs the hardware can execute, but by how often data must be read from or written to slower memory. In other words, the bottleneck is usually memory access rather than computation.
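A back-of-envelope roofline check makes this concrete. The sketch below counts the FLOPs needed to form the attention score matrix S = QKᵀ against the HBM traffic incurred if S is materialised in fp32 (written once, read back for the softmax). The peak FLOP/s and bandwidth figures are assumed A100-class numbers, used only for illustration:

```python
# Roofline back-of-envelope for standard attention.
# Hardware figures below are assumed A100-class specs, for illustration only.
N, d = 4096, 64                    # sequence length, head dimension

# FLOPs to compute the score matrix S = Q @ K^T
flops = 2 * N * N * d

# HBM traffic if S is materialised in fp32: one write, one read-back
bytes_moved = 2 * (N * N * 4)

intensity = flops / bytes_moved    # achieved FLOPs per byte = d / 4

peak_flops = 312e12                # assumed peak tensor FLOP/s
peak_bandwidth = 2.0e12            # assumed HBM bandwidth in bytes/s
ridge = peak_flops / peak_bandwidth  # FLOPs/byte needed to be compute-bound

print(f"arithmetic intensity: {intensity} FLOPs/byte")
print(f"ridge point:          {ridge:.0f} FLOPs/byte")
print(f"memory-bound:         {intensity < ridge}")
```

With d = 64, the achieved intensity is only 16 FLOPs per byte, far below the ridge point of roughly 156, so the kernel sits squarely in the memory-bound regime; this is precisely the gap FlashAttention closes by keeping S in SRAM instead of round-tripping it through HBM.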
Three components to optimise:
Compute: Time spent by the GPU computing floating-point operations (FLOPs).
Memory: Time spent transferring tensors within a GPU.
Overhead: All other operations (Python interpreter, PyTorch dispatch,…).
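The memory component is what FlashAttention attacks: by processing K and V in tiles and maintaining a running softmax, the full N×N score matrix never has to be written to HBM. The following is a minimal NumPy sketch of that online-softmax tiling (function names are illustrative; a real kernel fuses these steps in CUDA and keeps each tile in SRAM):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention: materialises the full N x N score matrix."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V

def tiled_attention(Q, K, V, block=64):
    """Online-softmax tiling: only a (N, block) tile of scores exists at once."""
    N, d = Q.shape
    O = np.zeros((N, d))          # running (unnormalised) output
    m = np.full(N, -np.inf)       # running row-wise max of scores
    l = np.zeros(N)               # running softmax denominator
    for j in range(0, N, block):
        Kj, Vj = K[j : j + block], V[j : j + block]
        S = Q @ Kj.T                           # scores for this tile only
        m_new = np.maximum(m, S.max(axis=1))   # updated row max
        p = np.exp(S - m_new[:, None])         # tile probabilities (unnormalised)
        scale = np.exp(m - m_new)              # rescale previous accumulators
        l = l * scale + p.sum(axis=1)
        O = O * scale[:, None] + p @ Vj
        m = m_new
    return O / l[:, None]
```

Both functions compute the same result; the tiled version just never holds more than an N×block slice of the scores, which is what lets the real kernel keep its working set inside SRAM.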
Figure 2: Example of Mixture of Experts architecture for large models.