Attention Efficiency & Long Context
8 practice sets · 5 coding problems
Topic 1 built the transformer and noted, almost in passing, two facts that this chapter is entirely about. First, self-attention costs $O(n^2)$ in the sequence length $n$: every token looks at every other token, so doubling the context quadruples the attention work. Second, generating text keeps a steadily growing KV cache of every past token's keys and values. For a short chat those two facts are footnotes. For a model asked to read a long contract, a million-line codebase, or an hour of transcribed audio, they become the entire bill: in compute, in memory, and in the time you wait for the first word. “Attention efficiency and long context” is the engineering discipline of making that bill payable while keeping the answers good. This chapter assumes only Topic 1 (you know what $Q$, $K$, $V$, heads, softmax, RoPE, and the KV cache are). It builds a cost model first (where exactly the money goes) and then walks each family of fixes, so that every detailed question in the topic reads as a move in one budget.
Why long context is genuinely hard
It helps to separate the two pressures, because they are fought with completely different tools.
The first is compute. Attention compares each of $n$ queries against each of $n$ keys, producing an $n \times n$ matrix of scores. That matrix has $n^2$ entries; at $n = 1{,}000$ it is a million, at $n = 100{,}000$ it is ten billion, per head, per layer. The arithmetic grows with the square of the context, so the part of the network that felt free at chat length quietly comes to dominate at document length.
The second is memory. To generate token $t$, the model needs the keys and values of all $t - 1$ earlier tokens (that is what attention attends to). Recomputing them every step would be madness, so we store them; that store, the KV cache, grows by one entry per token, per layer, forever. At long context the cache, not the model's weights, is what overflows the GPU. This is the KV-cache memory wall.
Topic 1 already introduced the standard notebook-shrinkers: MQA (all query heads share one key/value head), GQA (the $h$ query heads share $g < h$ key/value heads in groups; the modern default), and MLA (compress key/value into a small latent vector that is cached and re-expanded per head). We will not re-derive them; this chapter shows where in the budget each one bites, alongside the other levers.
Two phases: prefill and decode
Generation runs in two phases with opposite cost profiles, and almost every optimization targets one or the other.
Prefill processes the whole prompt at once. All $n$ prompt tokens go through the network in parallel; attention sees a full $n \times n$ interaction; the work is one big matrix-multiply that keeps the hardware's arithmetic units busy. Prefill is a parallel sprint, and it is what sets your time-to-first-token.
Decode then emits new tokens one at a time. Each step runs the network for a single new token, produces one logit vector, samples, and appends. Decode is a long line of tiny, almost-serial steps — and it is where the KV cache is read over and over.
The bridge between the phases is that cache. During prefill we fill it once; during decode each step computes $q$, $k$, $v$ for only the one new token, appends its $k$ and $v$, and lets the new query attend over the whole cache. At position $t$ that turns per-step work from $O(t^2)$ (recompute everything) into $O(t)$ (attend over $t$ cached entries): the cache buys back the quadratic, at the price of growing without bound.
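The decode loop can be sketched in pure Python. This sketch makes three simplifying assumptions. It uses a single attention head in a single layer. The projection matrices are random. The helper names (`KVCache`, `decode_step`, `recompute_step`) are hypothetical and come from no library. The sketch also feeds the prompt one token at a time instead of as one parallel prefill. What it checks is the claim above: attending over the cache gives exactly the same output as re-projecting every past token on every step.

```python
import math
import random


def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Scaled softmax attention of one query over lists of key/value vectors."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)                      # subtract the max for numerical stability
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z
            for j in range(len(values[0]))]


class KVCache:
    """Hypothetical one-layer, one-head cache: one (k, v) pair per processed token."""

    def __init__(self):
        self.keys, self.values = [], []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)


def decode_step(x, Wq, Wk, Wv, cache):
    # Project ONLY the new token, append its k and v, then attend over the cache: O(t).
    q, k, v = matvec(Wq, x), matvec(Wk, x), matvec(Wv, x)
    cache.append(k, v)
    return attend(q, cache.keys, cache.values)


def recompute_step(xs, Wq, Wk, Wv):
    # No cache: re-project every token seen so far on every step.
    keys = [matvec(Wk, x) for x in xs]
    values = [matvec(Wv, x) for x in xs]
    return attend(matvec(Wq, xs[-1]), keys, values)


rng = random.Random(0)
d = 4


def rand_matrix(rows, cols):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


Wq, Wk, Wv = rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
tokens = rand_matrix(6, d)

cache = KVCache()
outputs = []
for t in range(len(tokens)):
    cached = decode_step(tokens[t], Wq, Wk, Wv, cache)
    fresh = recompute_step(tokens[: t + 1], Wq, Wk, Wv)
    assert all(abs(a - b) < 1e-12 for a, b in zip(cached, fresh))
    outputs.append(cached)
print(len(cache.keys))  # 6: one cached entry per token
```

The cached path does only one new projection per step. The uncached path redoes $t$ projections per step, and that repeated work is what the cache saves.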
The KV-cache budget — a worked number
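It pays to size the cache exactly, because it drives nearly every long-context decision. The cache stores, per token, one $k$ vector and one $v$ vector in every layer. Multiply out the factors and the total in bytes is
$$\text{KV bytes} = 2 \times n_{\text{layers}} \times n_{\text{kv}} \times d_{\text{head}} \times s \times b,$$
Here the 2 counts $K$ and $V$. $n_{\text{layers}}$ is the number of layers, $n_{\text{kv}}$ the number of key/value heads, $d_{\text{head}}$ the head dimension, $s$ the sequence length in tokens, and $b$ the bytes per stored element (2 for fp16). When serving several sequences at once, multiply by the batch size.
Every symbol is a knob, and every cache-shrinking trick in this chapter is an attack on one of them: GQA/MQA/MLA shrink $n_{\text{kv}} \times d_{\text{head}}$, quantization shrinks $b$, sliding windows and eviction cap $s$.
At short context the weights dominate memory. At long context the KV cache does, because it grows linearly with sequence length, layers, and KV-heads while the weights stay fixed. The single inequality “cache > weights” is why this whole topic exists.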
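A worked number makes the formula concrete. The configuration below is an assumption, loosely shaped like a 7B-class model: 32 layers, head dimension 128, and fp16 storage. We size the cache at a 32k context three ways. The first uses full multi-head attention (32 KV heads). The second uses GQA with 8 KV heads. The third uses the same GQA cache quantized to 1 byte per element. `kv_cache_bytes` is a hypothetical helper that simply multiplies out 2 × layers × KV heads × head dim × tokens × bytes per element × batch.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem, batch=1):
    """Bytes needed to cache K and V for every token, KV head, and layer."""
    # The leading 2 counts one K vector and one V vector.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem * batch


GIB = 2 ** 30
MIB = 2 ** 20

# Hypothetical 7B-class configuration (an assumption): 32 layers, head_dim 128.
per_token = kv_cache_bytes(32, n_kv_heads=32, head_dim=128, seq_len=1, bytes_per_elem=2)
mha = kv_cache_bytes(32, n_kv_heads=32, head_dim=128, seq_len=32_768, bytes_per_elem=2)
gqa = kv_cache_bytes(32, n_kv_heads=8, head_dim=128, seq_len=32_768, bytes_per_elem=2)
gqa_int8 = kv_cache_bytes(32, n_kv_heads=8, head_dim=128, seq_len=32_768, bytes_per_elem=1)

print(per_token / MIB, mha / GIB, gqa / GIB, gqa_int8 / GIB)  # 0.5 16.0 4.0 2.0
```

For comparison, 7 billion fp16 parameters occupy about 14 GB (roughly 13 GiB). At 32k tokens the full multi-head cache of this hypothetical model already outweighs the weights. A batch of eight such sequences (`batch=8`) would need 128 GiB of cache alone.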
Why memory, not FLOPs, is the bottleneck: the memory hierarchy
Here is a fact that surprises newcomers: at decode the GPU's powerful matrix units mostly sit idle. Why? Because moving data is far slower than computing on it, and decode moves a lot of data per unit of arithmetic.
A GPU has a steep memory hierarchy. At the bottom is HBM (high-bandwidth memory) — the tens of gigabytes where weights and the KV cache live. It is large but, relative to the compute units, slow: a few terabytes per second. At the top is SRAM — a tiny on-chip scratchpad (kilobytes per processor) that is roughly an order of magnitude faster but far too small to hold a whole tensor. Every number a compute unit touches must be pulled up from HBM into SRAM, used, and the result pushed back down.
The relevant ratio is arithmetic intensity, $I = \text{FLOPs} / \text{bytes moved}$: how many floating-point operations you do per byte you move from HBM. The roofline model says attainable throughput is $\min(\text{peak FLOP/s},\; I \times \text{HBM bandwidth})$. Below a hardware “ridge point” ($I = \text{peak FLOP/s} / \text{bandwidth}$) you are memory-bound, and the expensive matrix units wait on data. Above it you are compute-bound, and you finally saturate them. Decode attention sits far on the memory-bound side. Each cached entry is read once and used for just a couple of multiply-adds, so intensity is roughly $O(1)$ FLOP per byte, and wall-clock time is set by how fast you can stream the cache, not by FLOPs. This is why a smaller cache is the prize: fewer bytes to stream means a faster token.
This also explains a lever that looks like free money: batching raises intensity. The weights are read once from HBM but reused across all sequences in the batch, so more FLOPs ride on the same bytes — batching pushes decode rightward toward the compute roof and lifts throughput, right up until the per-sequence caches stop fitting. The tell-tale that you have hit the KV wall is that decode throughput stops scaling with batch size and flatlines: you are now bandwidth-bound on the cache, and the cure is to make the cache smaller, not to add compute.
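Both paragraphs reduce to a few lines of arithmetic. The accelerator numbers below (`PEAK`, `BW`) are assumed round figures, not any specific GPU's specification. The two intensity helpers count FLOPs as 2 per multiply-add and assume fp16 storage. Decode attention reads each cached $K$ and $V$ element once for one multiply-add, which lands at about 1 FLOP per byte. A weight matmul reuses each weight across the whole batch, so its intensity grows with batch size.

```python
def attainable_flops(intensity, peak_flops, bandwidth):
    """Roofline model: throughput is capped by compute or by memory traffic."""
    return min(peak_flops, intensity * bandwidth)


# Assumed, order-of-magnitude accelerator numbers (not any specific GPU's spec).
PEAK = 1.0e15          # FLOP/s
BW = 3.0e12            # HBM bytes/s
RIDGE = PEAK / BW      # intensity where the memory roof meets the compute roof


def decode_attention_intensity(seq_len, head_dim, bytes_per_elem=2):
    """One new query against seq_len cached keys and values (one head)."""
    flops = 2 * seq_len * head_dim + 2 * seq_len * head_dim   # q . K^T, then p . V
    bytes_moved = 2 * seq_len * head_dim * bytes_per_elem      # stream K and V once
    return flops / bytes_moved


def weight_matmul_intensity(batch, d_model, bytes_per_elem=2):
    """batch token vectors times one d_model x d_model weight matrix read once."""
    flops = 2 * batch * d_model * d_model
    bytes_moved = d_model * d_model * bytes_per_elem   # activation traffic ignored
    return flops / bytes_moved


print("ridge point:", RIDGE, "FLOP/byte")
print("decode attention:", decode_attention_intensity(32_768, 128), "FLOP/byte")
for batch in (1, 8, 64, 512):
    i = weight_matmul_intensity(batch, 4096)
    print(f"batch {batch:>3}: intensity {i:.0f}, "
          f"{attainable_flops(i, PEAK, BW) / PEAK:.1%} of peak")
```

The decode-attention helper has no batch term. Each sequence streams its own $K$ and $V$, so batching pushes the weight matmuls toward the ridge while cache reads stay near 1 FLOP per byte. That is the flatline described above.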
FlashAttention: exact attention without the score matrix
The other half of the bill is the $n \times n$ score matrix in prefill. Materializing it is fatal at scale. At $n = 100{,}000$, one head's score matrix is $10^{10}$ entries $\times$ 2 bytes (fp16) $= 20$ GB, per head, per layer. The textbook algorithm writes that whole matrix to HBM, reads it back to apply softmax, then reads it a third time to multiply by $V$: three round-trips of an $O(n^2)$ object through slow memory.
FlashAttention computes the exact same answer without ever writing the $n \times n$ matrix to HBM. Two things need to be crystal clear, because they are common exam traps. First, it is not an approximation: the output equals standard attention up to floating-point rounding. Second, it does not cut FLOPs: it still does $O(n^2 d)$ multiply-adds. What it slashes is memory I/O. It is “IO-aware,” and on memory-bound hardware that is exactly the bottleneck that matters.
The mechanism is tiling plus online softmax. Tiling works like this: load a block of queries and a block of keys/values into fast SRAM, compute that block's partial scores and partial output there, accumulate into a running result, then move to the next block. The big matrix is born and consumed on-chip and never lands in HBM. The obstacle is that softmax needs a global normalizer (the sum over the whole row), which you cannot know until you have seen every key. Online softmax fixes this by carrying, per query, two running scalars and one running output vector, and rescaling them as each new block arrives.
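Concretely, recall that numerically stable softmax subtracts the row max $m$ before exponentiating, so $e^{s_j - m}$ never overflows. FlashAttention keeps three running quantities per query with scores $s_j$. The first is a running max $m$. The second is a running denominator $\ell = \sum_j e^{s_j - m}$. The third is a running output accumulator $o = \sum_j e^{s_j - m} v_j$. When a new block reveals a larger max, every previously accumulated quantity was scaled by the old max and must be corrected. Suppose the new block's max is $\tilde m$. Set $m_{\text{new}} = \max(m, \tilde m)$ and multiply the old $\ell$ and old $o$ by the correction factor $e^{m - m_{\text{new}}}$. Then add the new block's freshly scaled terms $e^{s_j - m_{\text{new}}}$ and $e^{s_j - m_{\text{new}}} v_j$. After the last block, the output is $o / \ell$. This is algebraically identical to softmaxing the full row, but computed in $O(n)$ memory instead of $O(n^2)$.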
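The recurrence can be checked in a few lines of pure Python. This is a sketch for a single query, with the $1/\sqrt{d}$ scale omitted because it only multiplies the scores. Real FlashAttention also tiles over queries, runs inside GPU SRAM, and fuses everything into one kernel; none of that is modeled here. The function names are illustrative, not an API.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def naive_attention(q, K, V):
    """Reference: materialize the whole score row, softmax it, weight V."""
    scores = [dot(q, k) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))]


def online_attention(q, K, V, block=4):
    """Visit K/V block by block, keeping only a running max, denominator, and output."""
    m = -math.inf                  # running max of the scores seen so far
    ell = 0.0                      # running denominator, scaled by exp(-m)
    o = [0.0] * len(V[0])          # running unnormalized output, scaled by exp(-m)
    for start in range(0, len(K), block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = [dot(q, k) for k in Kb]
        m_new = max(m, max(s))
        corr = math.exp(m - m_new)              # exp(-inf) = 0.0 on the first block
        p = [math.exp(si - m_new) for si in s]  # this block's terms under the new max
        ell = ell * corr + sum(p)
        o = [oj * corr + sum(pi * v[j] for pi, v in zip(p, Vb)) for j, oj in enumerate(o)]
        m = m_new
    return [oj / ell for oj in o]


rng = random.Random(1)
d, n = 8, 37
q = [rng.gauss(0, 1) for _ in range(d)]
K = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
V = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]

exact = naive_attention(q, K, V)
for block in (1, 5, 16, 64):
    tiled = online_attention(q, K, V, block)
    print(block, max(abs(a - b) for a, b in zip(exact, tiled)))  # rounding-level only
```

The block size changes only the order of floating-point operations, not the result. That is the sense in which the method is exact.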
Successive versions refined only the engineering, never the result. FlashAttention-2 re-partitioned the GPU work for far higher utilization. FlashAttention-3 exploits the asynchrony of Hopper-class GPUs by overlapping data movement with matmul and softmax via warp specialization, and it adds FP8 support. It reaches roughly 75% of the H100's peak (about 740 TFLOP/s in FP16), up from roughly 35% for FlashAttention-2. All three are exact.
Bounding the quadratic: windows, sinks, sparsity
FlashAttention makes the $O(n^2)$ cheaper to run; it does not make it go away. To break quadratic scaling outright you must compute attention over fewer pairs.
The simplest cut is sliding-window (local) attention, popularized at scale by Mistral: let each token attend only to the previous $w$ tokens. Cost drops to $O(n w)$, linear in $n$, and the cache caps at $w$ entries instead of growing forever. The bet is that language is mostly local: most of what you need to predict the next word sits nearby. Stack many layers, though, and information still propagates far, because a token's window overlaps its neighbor's. Influence travels up to one window per layer, so after $n_{\text{layers}}$ layers it can reach about $n_{\text{layers}} \times w$ tokens back, much like a convolution's receptive field.
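A small sketch makes the cost and reach concrete. It assumes one convention: token $i$ may attend to itself and the previous $w$ tokens ($0 \le i - j \le w$). Other implementations count the window slightly differently. The helper names are hypothetical.

```python
def sliding_window_mask(n, w):
    """Causal local mask: token i may attend to itself and the previous w tokens."""
    return [[0 <= i - j <= w for j in range(n)] for i in range(n)]


def attended_pairs(mask):
    """Number of (query, key) pairs the mask allows."""
    return sum(row.count(True) for row in mask)


def max_lookback(n_layers, w):
    # Each layer moves information back at most w positions, and layers compound.
    return n_layers * w


n = 256
full = attended_pairs(sliding_window_mask(n, n))    # ordinary causal attention
local = attended_pairs(sliding_window_mask(n, 16))  # each row holds at most w + 1 entries
print(full, local)             # about n^2 / 2 pairs versus about n * w pairs
print(max_lookback(32, 4096))  # e.g. 32 layers with a 4096-token window
```

With 32 layers and a 4096-token window, the stacked reach is about 131k tokens, even though each layer only looks 4096 back.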
Pure windows have a sharp failure mode: drop the oldest tokens and you lose access to the start of the sequence (the system prompt, the question being answered). Worse, naively evicting the first tokens tends to crash quality outright. The reason is a real and slightly weird phenomenon: the first few tokens act as an attention sink. Because softmax weights must sum to one, a query that finds nothing especially relevant still has to put its probability mass somewhere; heads learn to dump that excess onto the always-visible, low-content opening tokens. The sink is a pressure-release valve. StreamingLLM turns this into a method: keep a few initial sink tokens plus a sliding window, evict the middle, and a model can stream indefinitely without the collapse that removing the sinks causes.
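The eviction policy itself is tiny. This sketch shows only which cache positions a StreamingLLM-style policy keeps: the first `n_sink` tokens plus the most recent `window` tokens. The function name and defaults are illustrative assumptions. The sink count of 4 is a small value of the kind the method relies on, and the window of 8 is a toy size.

```python
def streaming_keep(cache_len, n_sink=4, window=8):
    """Cache positions kept by a StreamingLLM-style policy: sinks plus a recent window."""
    if cache_len <= n_sink + window:
        return list(range(cache_len))                    # nothing to evict yet
    sinks = list(range(n_sink))                          # always keep the opening tokens
    recent = list(range(cache_len - window, cache_len))  # rolling window of recent tokens
    return sinks + recent                                # everything in the middle is evicted


print(streaming_keep(10))
print(streaming_keep(20))
```

The cache size stays fixed at `n_sink + window` no matter how long the stream runs. The sinks stay in place, so softmax always has a place to park its excess mass.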
More general sparse attention keeps the softmax but lets each token attend to a structured subset of positions — typically dense for nearby tokens and sparse for distant ones (some strided or block-strided pattern), often plus a handful of global tokens that every position can see and that can see everything (Longformer-style). The intuition is that genuinely dense long-range dependencies are rare, so “dense locally, sparse globally” loses little while cutting cost below quadratic. The catch is that the pattern is hand-designed and can miss the one long-range link a task happens to need; choosing it well is the whole art.
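The sketch below builds one hypothetical “dense locally, sparse globally” pattern. It has three parts: a dense band of width `local`, strided links every `stride` positions, and a set of global tokens that see and are seen by every position. It is bidirectional (encoder-style), as in Longformer. The specific combination and parameter values are illustrative, not taken from any one paper.

```python
def sparse_mask(n, local=1, stride=4, global_tokens=(0,)):
    """Hypothetical bidirectional pattern: local band + strided links + global tokens."""
    g = set(global_tokens)

    def allowed(i, j):
        if i in g or j in g:           # global tokens see, and are seen by, every position
            return True
        if abs(i - j) <= local:        # dense band of nearby tokens
            return True
        return (i - j) % stride == 0   # sparse strided links to distant tokens

    return [[allowed(i, j) for j in range(n)] for i in range(n)]


m = sparse_mask(16)
kept = sum(map(sum, m))
print(kept, "of", 16 * 16, "pairs")
```

The risk named above is visible here. A position like 7 is simply unreachable from 5 in one layer, because it is neither adjacent, on the stride, nor global.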
Linear / kernelized attention: trading exactness for $O(n)$
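A bolder move removes the softmax entirely. Standard attention is forced to compute $QK^\top$ first, an $n \times n$ object, because softmax acts on it nonlinearly. Linear attention replaces softmax with a similarity that factorizes, $\text{sim}(q, k) = \phi(q)^\top \phi(k)$ for some feature map $\phi$. Then the numerator rearranges by simple associativity:
$$\big(\phi(Q)\,\phi(K)^\top\big)\,V = \phi(Q)\,\big(\phi(K)^\top V\big).$$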
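The left side builds the $n \times n$ matrix first ($O(n^2 d)$). The right side computes the small $d \times d$ matrix $\phi(K)^\top V$ first, then multiplies by $\phi(Q)$ ($O(n d^2)$). The normalizer $\phi(q_i)^\top \sum_j \phi(k_j)$ factorizes the same way. For long sequences ($n \gg d$) that is the difference between quadratic and linear in $n$.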
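Associativity is easy to confirm numerically. This sketch makes two assumptions. It uses the feature map $\phi(x) = \text{elu}(x) + 1$, which is one common positive choice. It is also the non-causal version. A causal version would update the state `S` and the normalizer `z` as running prefix sums, and that fixed-size running state is exactly what L37's next paragraph calls a lossy summary. Both functions compute the same normalized output; only the order of operations differs.

```python
import math
import random


def phi(x):
    """Positive feature map elu(x) + 1, one common choice (an assumption here)."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]


def quadratic_order(Q, K, V):
    """(phi(Q) phi(K)^T) V: build all n x n similarities first, O(n^2 d)."""
    fK = [phi(k) for k in K]
    out = []
    for q in Q:
        fq = phi(q)
        w = [sum(a * b for a, b in zip(fq, fk)) for fk in fK]
        z = sum(w)
        out.append([sum(wi * v[c] for wi, v in zip(w, V)) / z for c in range(len(V[0]))])
    return out


def linear_order(Q, K, V):
    """phi(Q) (phi(K)^T V): summarize keys/values into a d x d_v state first, O(n d^2)."""
    d, dv = len(K[0]), len(V[0])
    S = [[0.0] * dv for _ in range(d)]   # sum_j phi(k_j) v_j^T
    z = [0.0] * d                        # sum_j phi(k_j), for the normalizer
    for k, v in zip(K, V):
        fk = phi(k)
        for a in range(d):
            z[a] += fk[a]
            for c in range(dv):
                S[a][c] += fk[a] * v[c]
    out = []
    for q in Q:
        fq = phi(q)
        num = [sum(fq[a] * S[a][c] for a in range(d)) for c in range(dv)]
        den = sum(fq[a] * z[a] for a in range(d))
        out.append([x / den for x in num])
    return out


rng = random.Random(2)
n, d = 50, 6


def rand_matrix(rows, cols):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


Q, K, V = rand_matrix(n, d), rand_matrix(n, d), rand_matrix(n, d)
slow, fast = quadratic_order(Q, K, V), linear_order(Q, K, V)
print(max(abs(x - y) for ra, rb in zip(slow, fast) for x, y in zip(ra, rb)))  # rounding only
```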
The catch is quality. A fixed-size state cannot losslessly store arbitrarily long history — it is a lossy summary — so linear-attention and state-space models (Mamba-style) can stumble on exact long-range retrieval and copying, the tasks where you must reproduce a specific distant token verbatim. That is precisely why hybrid designs are popular: make most layers cheap (linear or windowed) and sprinkle in a few full-attention layers where verbatim recall matters — cheap state for the bulk, exact attention for the spots that need it. As a rule of thumb: for retrieval-heavy work, exact methods win (FlashAttention, plus ring attention, which shards one long sequence across devices and passes K/V blocks around a ring so no single device holds it all); for workloads that tolerate lossy memory, approximate methods win on cost.
Teaching a short-context model to use long context: RoPE extension
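Suppose your model was pretrained with an $L$-token window and you now feed it $L' \gg L$ tokens. It breaks, but, perhaps surprisingly, not because of the cache (a bigger GPU fixes that). It breaks because of positional encoding. Recall that RoPE rotates each query and key by an angle proportional to its position. At position $m$, the $i$-th pair of dimensions is rotated by $m\,\theta_i$, with per-dimension frequencies
$$\theta_i = b^{-2i/d}, \qquad i = 0, 1, \dots, d/2 - 1, \qquad b = 10{,}000,$$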
so the attention score depends only on the relative offset between two tokens. Feed positions far past anything seen in training and the rotation angles wind into a regime the model has never encountered — the positions are out-of-distribution, and output quality collapses. The fixes all reshape those angles so they stay in a familiar range.
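The relative-offset property can be checked directly. This sketch assumes the interleaved convention, which rotates dimension pairs $(2i, 2i+1)$. Some implementations instead rotate the first half of the vector against the second half, but the property holds either way. It uses $\theta_i = 10000^{-2i/d}$, and the helper names are illustrative.

```python
import math


def rope_freqs(dim, base=10000.0):
    """theta_i = base^(-2i/dim) for each of the dim/2 rotated pairs."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


def rope(x, pos, freqs):
    """Rotate each pair (x[2i], x[2i+1]) by the angle pos * theta_i."""
    out = list(x)
    for i, theta in enumerate(freqs):
        a, b = x[2 * i], x[2 * i + 1]
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def dot(u, v):
    return sum(p * r for p, r in zip(u, v))


freqs = rope_freqs(8)
q = [0.3, -1.2, 0.5, 0.8, -0.1, 0.4, 1.0, -0.7]
k = [1.1, 0.2, -0.6, 0.9, 0.3, -0.5, 0.2, 0.4]
near = dot(rope(q, 10, freqs), rope(k, 7, freqs))        # positions 10 and 7: offset 3
far = dot(rope(q, 1003, freqs), rope(k, 1000, freqs))    # positions 1003 and 1000: offset 3
print(near, far)  # equal up to floating-point rounding
```

The score is unchanged when both positions shift by 1000. But the absolute angles at position 1003 are far larger than at 10, and angles a model never saw in training are what break it when $L' \gg L$.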
Position Interpolation (PI) is the bluntest. To go from trained length $L$ to target $L'$, define the scale $s = L'/L$ and treat position $m$ as $m/s$. This squeezes all new positions back into the trained range, so every angle is in-distribution. It works, but it compresses everything uniformly, blurring fine distinctions between nearby tokens (the model can less easily tell “one apart” from “two apart”).
NTK-aware scaling is subtler. Instead of squeezing all frequencies equally, it scales the RoPE base,
$$b' = b \cdot s^{d/(d-2)},$$
which interpolates the slow, low-frequency dimensions (the ones that track coarse, long-range position) while leaving the fast, high-frequency dimensions (which distinguish adjacent tokens) nearly untouched. You keep local resolution and only stretch the global scale — exactly where there is slack.
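The difference between the two schemes shows up in a per-dimension slowdown factor. This sketch assumes standard RoPE frequencies $\theta_i = 10000^{-2i/d}$. It models PI as dividing every $\theta_i$ by $s$ (equivalent to mapping position $m$ to $m/s$), and NTK-aware scaling as replacing the base $b$ with $b \cdot s^{d/(d-2)}$. The helper names are hypothetical.

```python
def rope_freqs(dim, base=10000.0):
    """theta_i = base^(-2i/dim), from fastest (i = 0) to slowest rotation."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


def pi_freqs(dim, scale, base=10000.0):
    # Position Interpolation: position m -> m / scale, so every theta_i shrinks by scale.
    return [t / scale for t in rope_freqs(dim, base)]


def ntk_freqs(dim, scale, base=10000.0):
    # NTK-aware scaling: keep positions, enlarge the base to base * scale^(dim / (dim - 2)).
    return rope_freqs(dim, base * scale ** (dim / (dim - 2)))


dim, scale = 128, 8.0
orig, pi, ntk = rope_freqs(dim), pi_freqs(dim, scale), ntk_freqs(dim, scale)
for i in (0, 16, 32, 48, 63):
    # Factor by which each method slows dimension pair i (1.0 = untouched).
    print(i, round(orig[i] / pi[i], 3), round(orig[i] / ntk[i], 3))
```

PI slows every pair by the same factor of 8. NTK-aware scaling leaves the fastest pair untouched and slows the slowest pair by the full factor, grading smoothly in between (the slowdown for pair $i$ is $s^{2i/(d-2)}$).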
YaRN combines the ideas and adds one more. It partitions RoPE dimensions by how many full cycles each completes over the training length. High-frequency dims (many cycles, local detail) are left to extrapolate, low-frequency dims (few cycles, global position) are interpolated, and a middle band is smoothly ramped between the two; this is the “NTK-by-parts” interpolation. On top, YaRN rescales the attention logits by a small temperature factor to counteract the entropy drift that longer contexts otherwise induce in the softmax. The payoff is striking: YaRN can stretch a model's context window many times over by fine-tuning on only a tiny fraction of the original pretraining tokens.
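The by-parts rule and the temperature can be sketched as follows. The formulas follow our reading of the YaRN paper, and they should be treated as assumptions to check against it. For each pair, the number of cycles over the training length is $r = L\,\theta_i / 2\pi$. A ramp $\gamma(r)$ runs from 0 below `alpha` to 1 above `beta`. The new frequency is $(1-\gamma)\,\theta_i/s + \gamma\,\theta_i$. The logits are multiplied by $(0.1 \ln s + 1)^2$. The values `alpha=1`, `beta=32` are the ones we understand the paper to suggest for LLaMA-family models.

```python
import math


def rope_freqs(dim, base=10000.0):
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


def yarn_freqs(dim, scale, train_len, alpha=1.0, beta=32.0, base=10000.0):
    """NTK-by-parts: interpolate slow dims, keep fast dims, ramp the band in between."""
    out = []
    for theta in rope_freqs(dim, base):
        cycles = train_len * theta / (2 * math.pi)   # full rotations over the training length
        if cycles < alpha:
            gamma = 0.0                              # few cycles: global position, interpolate
        elif cycles > beta:
            gamma = 1.0                              # many cycles: local detail, leave as is
        else:
            gamma = (cycles - alpha) / (beta - alpha)  # linear ramp across the middle band
        out.append((1.0 - gamma) * theta / scale + gamma * theta)
    return out


def yarn_logit_scale(scale):
    """Multiplier on attention logits: 1/t with sqrt(1/t) = 0.1 * ln(scale) + 1."""
    return (0.1 * math.log(scale) + 1.0) ** 2


dim, scale, train_len = 128, 16.0, 4096
orig = rope_freqs(dim)
yarn = yarn_freqs(dim, scale, train_len)
ratios = [o / y for o, y in zip(orig, yarn)]   # 1.0 = untouched, scale = fully interpolated
print(ratios[0], ratios[-1], round(yarn_logit_scale(scale), 3))
```

Unlike PI, the fastest dimensions keep their exact training-time frequencies. Only the slowest, global-position dimensions absorb the full stretch.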
Lost in the middle: attention is not retrieval
One last reality check before the questions: a long context window is not the same as using it well. Two pitfalls recur.
First, perplexity is not the test. A context-extended model can post fine perplexity yet fail needle-in-a-haystack retrieval past some length. Perplexity averages over mostly-easy nearby tokens; retrieval probes the single hard long-range dependency that perplexity barely weights. So always validate long context with a retrieval probe — and re-check that short-context scores did not quietly regress, since extension can degrade them.
Second, lost-in-the-middle. Models reliably use facts placed at the start and end of a long prompt but neglect the middle — accuracy as a function of where the answer sits is U-shaped. This is a position bias (from training-data structure and attention dynamics, reinforced by the attention sinks at the front), not a hard capacity limit, which is why it shows up even when the model technically “saw” the middle. It is also why retrieval-augmented pipelines bother to rank and place the most relevant chunks at the edges of the context rather than dumping everything in arbitrary order.
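One simple placement heuristic puts the strongest chunks at the edges. The helper below is hypothetical, not any particular framework's API. It takes chunks already sorted most-relevant-first and alternates them between the front and the back of the context, so the least relevant chunks end up in the middle, where the U-shaped curve says attention is weakest.

```python
def edge_first_order(chunks_by_relevance):
    """Place the most relevant chunks at the start and end, the least relevant in the middle.

    Input is sorted most-relevant-first; ranks alternate between front and back.
    """
    front, back = [], []
    for rank, chunk in enumerate(chunks_by_relevance):
        (front if rank % 2 == 0 else back).append(chunk)
    return front + back[::-1]   # reverse the back half so rank 1 lands at the very end


print(edge_first_order(["A", "B", "C", "D", "E"]))  # ['A', 'C', 'E', 'D', 'B']
```

The top chunk opens the context, the second closes it, and the weakest (`"E"`) sits in the middle.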
What to watch for
The practical loop for serving long context is now fairly standard, and it is just the budget applied in order. Pick an exact kernel (FlashAttention) so attention is never the memory bottleneck. Shrink the cache in order of cheapness: GQA first, then KV quantization (8- or 4-bit; cheap, but watch long-range copying, since a slightly corrupted key flips which token is retrieved — degraded needle-retrieval with unchanged MMLU is its signature), then sliding-window or heavy-hitter eviction, then MLA if you control the architecture. Use chunked prefill (process a giant prompt in slices) to cap time-to-first-token. Extend positions with YaRN-style scaling, and validate with retrieval, not perplexity, re-checking short context for regressions. Hold the two anchors — the KV-bytes formula and the roofline — in your head, and every question in this topic, from tile-size derivations to MLA's decoupled-RoPE trick to attention-sink eviction, becomes a single, recognizable move in the same budget.
