LLMs Deep Dive
Chapter 06 · Part II · Pretraining & Scale

Infrastructure, Distributed Training & Scaling

8 practice sets · 4 coding problems

Training a frontier language model is, before it is anything else, an exercise in moving numbers between chips fast enough to keep thousands of very expensive processors busy. A single modern GPU can do hundreds of teraFLOPs per second, but a frontier model has tens to hundreds of billions of parameters and is trained on trillions of tokens; no one chip holds the model, no one chip holds enough data, and no one chip is remotely fast enough on its own. The entire subject of this topic is how to spread one training run across many GPUs — splitting the model, the data, and the work — while paying as little as possible for the communication that splitting forces on you. This mini-chapter builds the vocabulary from the ground up, assuming only that you have seen a neural network train with gradient descent. By the end you should know what fills a GPU's memory and why it overflows, the handful of ways to cut a model into pieces, the small set of collective operations that glue the pieces back together, and the single number — utilization — that tells you whether your million-dollar cluster is actually doing useful work or mostly waiting around.

The big picture: why one GPU is never enough

It is tempting to think the problem with training big models is that they are slow — too many FLOPs to chew through. That is the second problem. The first problem, the one you hit before you ever think about speed, is that the model and its training scaffolding will not physically fit in a single GPU's memory. Two distinct walls force you to use many GPUs, and it is worth separating them because each parallelism strategy you will meet is aimed at one or the other.

The first is the memory wall. A GPU's fast working memory is its HBM (high-bandwidth memory: the DRAM stacked right next to the compute die, e.g. 80 GB on an NVIDIA A100 or H100). It must hold four different things simultaneously: the model weights, the optimizer's bookkeeping, the gradients, and the activations (the intermediate vectors produced at every layer on the forward pass, which have to be kept around so the backward pass can use them). When the sum of those four buckets exceeds HBM capacity, the run dies immediately with an OOM (out-of-memory) error. The second wall is the throughput wall: even if the model fit, grinding through trillions of tokens on a single GPU would take years. You need to process many batches in parallel.

To make the memory wall concrete we have to count bytes carefully, because each parallelism trick targets a specific term. The modern recipe trains in mixed precision: most math runs in a low-precision 16-bit format (fp16 or bf16, 2 bytes per number) for speed, but a high-precision fp32 “master” copy of the weights and the optimizer's statistics are kept in 32-bit (4 bytes) so that tiny weight updates do not vanish into rounding error. Pair that with AdamW, the near-universal optimizer, which keeps two running averages per parameter, a first moment $m$ (the smoothed gradient) and a second moment $v$ (the smoothed squared gradient), and the per-parameter storage adds up fast. The next box does the arithmetic.
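
The same bookkeeping fits in a few lines of Python. This is a minimal calculator that assumes exactly the recipe just described (bf16 weights and gradients, plus an fp32 master copy and AdamW's $m$ and $v$). The model sizes are illustrative, and activations are deliberately left out because they form the separate fourth bucket.

```python
# Per-parameter storage for mixed-precision AdamW (activations excluded).
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam m (first moment)": 4,
    "fp32 Adam v (second moment)": 4,
}

def model_state_bytes(num_params: int) -> int:
    """Weights + gradients + optimizer state for a dense model."""
    return num_params * sum(BYTES_PER_PARAM.values())

HBM_BYTES = 80 * 10**9  # one 80 GB A100/H100

for billions in (1, 7, 70):
    total = model_state_bytes(billions * 10**9)
    print(f"{billions:>3}B params: {total / 1e9:,.0f} GB "
          f"(~{total / HBM_BYTES:.1f}x one 80 GB GPU)")
```

Even a 7B model needs 112 GB for model state alone, so it overflows a single 80 GB GPU before the first activation is stored.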

The fourth bucket, activation memory, behaves completely differently from the first three. It does not scale with the parameter count $P$. Instead, it scales with how much data you push through at once (batch size times sequence length) and with the number of layers, because every layer's inputs must be stashed for backprop. A useful per-layer estimate for a transformer block (the Megatron-LM result) is on the order of $s_b\, s\, h\,(34 + 5as/h)$ bytes, where $s_b$ is the micro-batch size, $s$ the sequence length, $h$ the hidden size, and $a$ the number of attention heads. The key feature is the $5as/h$ term. Multiplied by the leading $s_b\, s\, h$, it contributes $5\,a\,s^2 s_b$ bytes, which grows with the square of sequence length because it comes from the $s\times s$ attention scores. The upshot: at short context, weights dominate; at long context, activations can dwarf everything else and become the thing that OOMs you. Different walls, different cures.
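
You can see the crossover by evaluating the Megatron-LM estimate directly. This minimal sketch uses only the per-layer formula quoted above, with no recomputation and the attention scores stored, as that formula assumes. The model shape (hidden size 4096, 32 heads, 32 layers, micro-batch 1) is a hypothetical illustration, not any particular model.

```python
def activation_bytes_per_layer(s_b: int, s: int, h: int, a: int) -> float:
    """Megatron-LM estimate: s_b * s * h * (34 + 5*a*s/h) bytes per block."""
    return s_b * s * h * (34 + 5 * a * s / h)

# Hypothetical shape: hidden size 4096, 32 heads, 32 layers, micro-batch 1.
H, A, LAYERS, MICRO_BATCH = 4096, 32, 32, 1

for seq_len in (2048, 8192, 32768):
    per_layer = activation_bytes_per_layer(MICRO_BATCH, seq_len, H, A)
    total_gb = per_layer * LAYERS / 1e9
    linear_part = MICRO_BATCH * seq_len * H * 34   # the "34" term, linear in s
    share = 1 - linear_part / per_layer            # fraction from the s x s term
    print(f"s={seq_len:>6}: {total_gb:8.1f} GB of activations, "
          f"{share:.0%} from the quadratic attention term")
```

Under these assumptions the estimate grows from roughly 31 GB at $s = 2048$ to several terabytes at $s = 32768$, far past the model-state bucket of a comparably sized model.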

Training memory is four buckets: weights, gradients, optimizer state, and activations. Mixed-precision AdamW costs ≈16 bytes/param for the first three combined (2 weight + 2 grad + 12 optimizer/master). Activations are a separate fourth bucket that grows with batch × sequence length, not with $P$. Every parallelism strategy you will meet exists to split one or more of these four buckets across GPUs.

Collectives: how GPUs talk

Before we can split a model, we need the communication primitives, because splitting always creates a need to recombine. GPUs in a training job coordinate through collective operations: structured exchanges in which all participating GPUs act together, as opposed to a point-to-point send/receive between just two of them. The library that implements these on NVIDIA hardware is NCCL (the NVIDIA Collective Communications Library, pronounced “nickel”); it is the plumbing under PyTorch's distributed backend. Each GPU is called a rank. Four collectives appear over and over:

  • Broadcast: one rank holds a value; afterward every rank has a copy of it. (Used to hand out initial weights.)
  • All-reduce: every rank starts with its own vector of the same shape; afterward every rank holds the element-wise sum (or average) of all of them. This is how data-parallel gradients get combined.
  • Reduce-scatter: like all-reduce, but each rank keeps only its own slice of the summed result instead of the whole thing.
  • All-gather: the complement of reduce-scatter. Every rank holds one slice; afterward every rank holds the full concatenation of all slices.

Two facts to carry forward. First, a reduce-scatter followed by an all-gather equals an all-reduce. Second, a fifth collective, all-to-all, in which every rank sends a distinct chunk to every other rank, performs a full transpose of the data; it is the workhorse of expert parallelism for mixture-of-experts models. The picture below shows the four basic shapes on three ranks.

[Diagram: broadcast, all-reduce, reduce-scatter, and all-gather on three ranks]
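
A tiny single-process simulation pins down what these collectives compute. It models their semantics only, not their performance. The sketch represents each rank's data as a Python list and assumes sum reductions. The function names are hypothetical helpers, not NCCL or `torch.distributed` calls.

```python
from typing import List

Vec = List[float]

def broadcast(value: Vec, n: int) -> List[Vec]:
    """Every one of n ranks ends with a copy of the root's value."""
    return [list(value) for _ in range(n)]

def all_reduce(per_rank: List[Vec]) -> List[Vec]:
    """Every rank ends with the element-wise sum of all ranks' vectors."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]

def reduce_scatter(per_rank: List[Vec]) -> List[Vec]:
    """Rank r keeps only chunk r of the element-wise sum."""
    n = len(per_rank)
    total = [sum(vals) for vals in zip(*per_rank)]
    chunk = len(total) // n  # assumes the length divides evenly by n
    return [total[r * chunk:(r + 1) * chunk] for r in range(n)]

def all_gather(per_rank: List[Vec]) -> List[Vec]:
    """Every rank ends with the concatenation of all ranks' slices."""
    full = [x for piece in per_rank for x in piece]
    return [list(full) for _ in per_rank]

def all_to_all(per_rank: List[List[Vec]]) -> List[List[Vec]]:
    """per_rank[src][dst] is the chunk src sends to dst; the result is the transpose."""
    n = len(per_rank)
    return [[per_rank[src][dst] for src in range(n)] for dst in range(n)]

grads = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0], [100.0, 200.0, 300.0]]
# The identity from the text: reduce-scatter then all-gather == all-reduce.
assert all_gather(reduce_scatter(grads)) == all_reduce(grads)
print(all_reduce(grads)[0])  # [111.0, 222.0, 333.0]
```
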

Why does any of this cost so much? Because the cost of a collective depends on where the GPUs are. A node is one physical server holding (typically) 8 GPUs wired together by NVLink, an ultra-fast intra-node interconnect (hundreds of GB/s to multiple TB/s). Crossing between nodes uses a slower fabric, InfiniBand or Ethernet, often with several times less bandwidth and far higher latency. The smol-training-playbook's measurements make this vivid: an all-reduce that sustains ≈480 GB/s of bus bandwidth inside one node drops to ≈320–350 GB/s across nodes, and per-operation latency jumps from ~13 μs within a node to ~55 μs across two. This bandwidth-and-latency cliff at the node boundary is the single most important fact in parallelism design: you place the chattiest communication inside a node and the sparsest across nodes.
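
A common back-of-envelope model treats one all-reduce as a fixed latency plus traffic divided by bandwidth. This is an assumption made here, not something the playbook prescribes. Plugging in the intra- and inter-node figures quoted above shows two regimes: latency dominates small messages and bandwidth dominates large ones. The traffic term uses the nccl-tests bus-bandwidth convention, in which a ring all-reduce moves $2(N-1)/N$ times the payload per rank.

```python
def allreduce_seconds(payload_bytes: float, n_ranks: int,
                      bus_bw_bytes_per_s: float, latency_s: float) -> float:
    """Latency + bandwidth model of one ring all-reduce (a simplification)."""
    traffic = 2 * (n_ranks - 1) / n_ranks * payload_bytes  # bytes moved per rank
    return latency_s + traffic / bus_bw_bytes_per_s

# Figures quoted in the text (smol-training-playbook measurements).
INTRA = dict(bus_bw_bytes_per_s=480e9, latency_s=13e-6)  # inside one 8-GPU node
INTER = dict(bus_bw_bytes_per_s=320e9, latency_s=55e-6)  # across two nodes (low end)

for size in (64 * 1024, 1e9):  # a 64 KiB message vs a 1 GB gradient bucket
    t_in = allreduce_seconds(size, 8, **INTRA)
    t_out = allreduce_seconds(size, 16, **INTER)
    print(f"{size:>14,.0f} B: intra-node {t_in * 1e6:9.1f} us, "
          f"inter-node {t_out * 1e6:9.1f} us")
```

In this model the small message is about 4× slower across nodes, almost entirely because of latency, while the large bucket is only about 1.6× slower. This is one reason frameworks bucket many small gradients into fewer large all-reduces.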

The standard efficient all-reduce is the ring all-reduce. It is worth understanding once, because its scaling property is why training on thousands of GPUs is possible at all. Arrange the $N$ GPUs in a logical ring and cut the $M$-byte gradient vector into $N$ chunks. In a reduce-scatter phase, $N-1$ steps pass partial sums around the ring so each GPU ends up owning one fully summed chunk. An all-gather phase then circulates those finished chunks for $N-1$ more steps so everyone has all of them. Each GPU sends (and receives) $\tfrac{N-1}{N}M$ bytes per phase, $\approx 2M$ in total. Crucially, that is essentially independent of $N$ for large $N$. Doubling the cluster does not double each GPU's communication bill; that flat per-GPU cost is what makes the ring scale.
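
The two phases can be simulated step by step to check both the result and the per-GPU byte count. This single-process sketch assumes $N$ equal chunks and uses one number to stand in for each chunk of $M/N$ bytes. Real NCCL rings also pipeline and overlap these steps, which the sketch ignores.

```python
from typing import List, Tuple

def ring_all_reduce(data: List[List[float]]) -> Tuple[List[List[float]], List[int]]:
    """Simulate a ring all-reduce over N ranks whose vectors are cut into N chunks.

    data[r][c] stands in for chunk c on rank r (one number per chunk).
    Returns every rank's final buffer and how many chunks each rank sent.
    """
    n = len(data)
    buf = [list(row) for row in data]
    sent = [0] * n
    # Reduce-scatter: at step t, rank r passes its partial sum of chunk (r - t) % n
    # to its right neighbour, which adds it to its own copy.
    for t in range(n - 1):
        msgs = [(r, (r - t) % n, buf[r][(r - t) % n]) for r in range(n)]
        for r, c, val in msgs:
            buf[(r + 1) % n][c] += val
            sent[r] += 1
    # Rank r now owns the fully summed chunk (r + 1) % n.
    # All-gather: at step t, rank r forwards finished chunk (r + 1 - t) % n.
    for t in range(n - 1):
        msgs = [(r, (r + 1 - t) % n, buf[r][(r + 1 - t) % n]) for r in range(n)]
        for r, c, val in msgs:
            buf[(r + 1) % n][c] = val
            sent[r] += 1
    return buf, sent

N = 4
data = [[float(10 * r + c) for c in range(N)] for r in range(N)]
result, sent = ring_all_reduce(data)
expected = [sum(data[r][c] for r in range(N)) for c in range(N)]
assert all(row == expected for row in result)
M = 1.0  # normalise the whole gradient vector to 1 unit of bytes
print([s * M / N for s in sent])  # [1.5, 1.5, 1.5, 1.5], i.e. 2 * (N - 1) / N * M each
```

Each rank sends $2(N-1)$ chunks of size $M/N$, so the per-rank total approaches $2M$ no matter how large the ring grows.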

The four (plus one) ways to split a model

There are four standard axes along which you can cut up a training run, plus one more for long context. Big runs combine several of them at once into a so-called 3D (or 4D) parallelism grid. Each axis splits a different bucket and demands a different collective; the trick is matching the chatty ones to NVLink and the quiet ones to InfiniBand.

Data parallelism (DP) is the simplest. Replicate the entire model on every GPU; give each replica a different slice of the batch; let each compute gradients on its own slice; then all-reduce the gradients so every replica averages them and they all take the identical optimizer step, staying bit-for-bit in lockstep. DP scales throughput beautifully (twice the GPUs, roughly twice the tokens per second), but it does nothing for the memory wall, because every GPU still holds a full $16P$-byte copy of the weights, gradients, and optimizer state. DP attacks the throughput wall only.
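
A toy problem shows why averaging is the right way to combine gradients. This minimal sketch assumes a one-parameter least-squares model and equal-sized shards. Under those assumptions, the mean of the per-replica gradients equals the full-batch gradient, so every replica takes exactly the step a single GPU would have taken.

```python
def grad_mse(w: float, xs, ys) -> float:
    """Gradient of mean((w*x - y)**2) with respect to w over one shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
w, lr, replicas = 0.5, 0.01, 4

shard = len(xs) // replicas  # every replica gets an equal slice of the batch
local = [grad_mse(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
         for i in range(replicas)]
averaged = sum(local) / replicas   # what an averaging all-reduce hands every replica
full = grad_mse(w, xs, ys)         # the single-GPU, full-batch gradient
assert abs(averaged - full) < 1e-9
w_next = w - lr * averaged         # every replica applies this identical step
```

The equal-shard assumption matters. With unequal shards, a plain mean of per-shard means would over-weight the smaller shards, so each shard would need to be weighted by its size.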

Tensor parallelism (TP) splits individual weight matrices across GPUs. A single big matrix multiply $Y = XW$ is cut column-wise or row-wise so each GPU computes a partial product, and an all-reduce (or all-gather) inside the layer stitches the pieces back into the correct output. TP genuinely shrinks the weights, gradients, and optimizer state on each GPU, but it pays for it with an all-reduce inside every layer, on every forward and backward pass. That is a torrent of traffic, which is exactly why TP is almost always confined to a single node so it can ride NVLink; stretch it across the slow inter-node fabric and the all-reduces stall the whole thing.
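
Both ways of cutting the matrix can be checked numerically. This pure-Python sketch uses a small hypothetical 2-way split. Slicing $W$ by columns gives each GPU a slice of $Y$'s columns, which are stitched together with an all-gather. Slicing $W$ by rows, with $X$ sliced to match, gives each GPU a full-size partial sum, and those are combined with an all-reduce.

```python
def matmul(A, B):
    """Plain (rows x k) @ (k x cols) product on nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

X = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]          # 2 x 4 activations
W = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0],
     [2.0, 1.0, 0.0, 1.0], [1.0, 2.0, 1.0, 0.0]]          # 4 x 4 weights
Y = matmul(X, W)                                           # reference, one GPU

# Column-parallel: GPU g holds W[:, 2g:2g+2]; outputs are concatenated (all-gather).
W_cols = [[row[2 * g:2 * g + 2] for row in W] for g in range(2)]
Y_parts = [matmul(X, Wg) for Wg in W_cols]
Y_col = [Y_parts[0][i] + Y_parts[1][i] for i in range(len(X))]

# Row-parallel: GPU g holds W[2g:2g+2, :] and X[:, 2g:2g+2]; partials are summed (all-reduce).
partials = [matmul([row[2 * g:2 * g + 2] for row in X], W[2 * g:2 * g + 2])
            for g in range(2)]
Y_row = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(*partials)]

assert Y_col == Y == Y_row
```

Megatron-LM's MLP chains the two: a column-parallel first matrix feeds a row-parallel second one, so the intermediate activation never has to be gathered and the forward pass needs only one all-reduce per MLP.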

Pipeline parallelism (PP) splits the model by depth: GPU 0 holds layers 1–8, GPU 1 holds 9–16, and so on, like stations on an assembly line. A batch flows through the stages in order. It needs only cheap point-to-point sends of activations between adjacent stages (no all-reduce), so it is happy to span nodes. Its weakness is the bubble, which we devote a picture to next.

Expert parallelism (EP) is specific to mixture-of-experts models, where each layer has many “expert” FFNs but each token uses only a few. EP spreads the experts across GPUs, and an all-to-all routes each token to whichever GPU holds its chosen expert — then a second all-to-all routes the results back. Because that is an all-to-all over the entire batch, on a slow fabric it can dominate step time, which is one reason MoE models report lower MFU.
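
In effect, the two all-to-alls are a dispatch followed by a combine. This single-process sketch makes several hypothetical assumptions: 2 ranks each hold 2 experts, top-1 routing has already been decided, and each "expert" just scales its input. The point it shows is that tokens leave their home rank grouped by expert and must be put back in their original positions.

```python
from collections import defaultdict

EXPERTS_PER_RANK = 2

def run_expert(e: int, x: float) -> float:
    """Hypothetical expert: expert e multiplies its input by (e + 1)."""
    return x * (e + 1)

def moe_layer(tokens, expert_ids, n_ranks):
    """tokens[src] and expert_ids[src] are the tokens and chosen experts on rank src."""
    # Dispatch all-to-all: bucket each token by the rank that owns its expert.
    inbox = defaultdict(list)  # dst rank -> list of (src, position, expert, value)
    for src in range(n_ranks):
        for pos, (x, e) in enumerate(zip(tokens[src], expert_ids[src])):
            inbox[e // EXPERTS_PER_RANK].append((src, pos, e, x))
    # Each rank runs its local experts; the combine all-to-all sends results home.
    out = [[None] * len(tokens[src]) for src in range(n_ranks)]
    for dst in range(n_ranks):
        for src, pos, e, x in inbox[dst]:
            out[src][pos] = run_expert(e, x)
    return out

tokens = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
expert_ids = [[0, 3, 1], [2, 2, 0]]  # experts 0-1 live on rank 0, experts 2-3 on rank 1
print(moe_layer(tokens, expert_ids, n_ranks=2))  # [[1.0, 8.0, 6.0], [12.0, 15.0, 6.0]]
```
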

The “plus one” is sequence (or context) parallelism (SP), which splits along the sequence-length dimension. It divides up the activation memory and the work of the non-TP regions (LayerNorm, dropout, residual adds) that plain tensor parallelism leaves replicated on every GPU. SP is what you reach for when activations, not weights, are the thing that OOMs you, i.e. at very long context.

The pipeline bubble, and how micro-batching pops it

Pipeline parallelism's flaw deserves its own diagram because the fix, micro-batching, is the prototype for a recurring idea: keep every device busy by always having more small units of work in flight. When a single batch enters a $p$-stage pipeline, stage 1 works on it first while stages $2,\dots,p$ sit idle with nothing to do. The batch then moves on to stage 2, leaving stage 1 idle, and so on down the line (and back up again for the backward pass), so only one stage is busy at any moment. The idle time that accumulates while the pipeline fills at the start and drains at the end is the bubble.

The cure is to chop the batch into $m$ small micro-batches and stream them in back-to-back. While stage 1 starts micro-batch 2, stage 2 is already working on micro-batch 1, and so on; in the steady state every stage is busy on a different micro-batch. For pipeline depth $p$ and $m$ micro-batches, the fraction of time wasted in the bubble is

$$\text{bubble fraction} \;=\; \frac{p-1}{m+p-1},$$

which shrinks toward zero as $m$ grows large relative to $p$. With $p=4,\,m=4$ the bubble is $\tfrac{3}{7}\approx 43\%$ wasted, which is terrible. With $p=8,\,m=32$ it is $\tfrac{7}{39}\approx 18\%$. Smarter schedules refine this. 1F1B (one-forward-one-backward) interleaves each micro-batch's backward pass with later forward passes; it keeps the same bubble but caps how many micro-batches' activations must be held at once. Interleaved pipelines, which give each GPU several non-contiguous layer chunks, do shrink the bubble further at the same $m$, at the cost of more point-to-point communication.

[Diagram: pipeline schedule of stages versus time steps, with idle bubble cells shaded gray]

The gray cells are the bubble: the lower-left fill triangle (stages waiting for their first micro-batch) and the upper-right drain triangle (stages that have finished). More micro-batches make the busy middle wider relative to those fixed triangles, so a larger $m$ amortizes the bubble away, exactly as the $\tfrac{p-1}{m+p-1}$ formula says.
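
A direct count of idle cells confirms the formula. The sketch assumes the simple fill-then-drain schedule in the diagram, counting forward passes only, with one time unit per stage per micro-batch. Stage $k$ starts micro-batch $j$ at time step $j + k$, so the schedule spans $m + p - 1$ steps.

```python
from fractions import Fraction

def bubble_fraction(p: int, m: int) -> Fraction:
    """Idle fraction (p - 1) / (m + p - 1) for p stages and m micro-batches."""
    return Fraction(p - 1, m + p - 1)

def simulated_bubble(p: int, m: int) -> Fraction:
    """Count idle (stage, time) cells in a fill / steady-state / drain schedule."""
    steps = m + p - 1
    busy = {(k, j + k) for k in range(p) for j in range(m)}  # stage k runs mb j at j+k
    idle = p * steps - len(busy)
    return Fraction(idle, p * steps)

for p, m in [(4, 4), (8, 32)]:
    assert simulated_bubble(p, m) == bubble_fraction(p, m)
    print(p, m, bubble_fraction(p, m), f"{float(bubble_fraction(p, m)):.0%}")
```

Each stage sits idle for exactly $p - 1$ of the $m + p - 1$ steps, which is where the formula comes from. Setting $m = 1$ recovers the single-batch case, in which each stage is idle for $(p-1)/p$ of the time.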

ZeRO / FSDP: sharding the data-parallel replicas

Plain DP is wasteful: it replicates all $16P$ bytes on every GPU even though, at any instant, a GPU is only using one layer's worth. ZeRO (the Zero Redundancy Optimizer, from DeepSpeed) and its PyTorch-native equivalent FSDP (Fully Sharded Data Parallelism) fix this by sharding the model state across the $N$ data-parallel ranks instead of replicating it, in three escalating stages:

  • Stage 1 shards the optimizer state (that fat 12-byte-per-param part: $m$, $v$, and the fp32 master weight). Each rank keeps only $1/N$ of it. Roughly a 4× memory cut, at the same communication volume as plain DP.
  • Stage 2 additionally shards the gradients. Roughly an 8× cut, still at plain-DP communication volume.
  • Stage 3 (full FSDP) additionally shards the weights themselves, so each rank permanently stores only $1/N$ of every bucket. Memory falls roughly linearly in $N$ (split across 64 GPUs, ≈64× less per-GPU state). The price is about 1.5× the communication of plain DP.
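
The stage-by-stage savings follow directly from the 2 + 2 + 12 split. This minimal calculator assumes the 16-bytes-per-parameter recipe from earlier and ignores activations, temporary buffers, and fragmentation. The 1.0× and 1.5× communication factors are the relative volumes stated in the list.

```python
def zero_bytes_per_param(stage: int, n: int) -> float:
    """Per-GPU model-state bytes per parameter under ZeRO stage 0-3 with n DP ranks."""
    weights, grads, optim = 2.0, 2.0, 12.0
    if stage >= 1:
        optim /= n      # stage 1: shard fp32 master + Adam m, v
    if stage >= 2:
        grads /= n      # stage 2: also shard gradients
    if stage >= 3:
        weights /= n    # stage 3: also shard the bf16 weights
    return weights + grads + optim

COMM_VS_DP = {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.5}  # relative volumes quoted above

N = 64
for stage in range(4):
    b = zero_bytes_per_param(stage, N)
    print(f"stage {stage}: {b:6.3f} B/param, {16 / b:5.1f}x less memory, "
          f"{COMM_VS_DP[stage]}x DP communication")
```

The 4× and 8× figures are large-$N$ limits. At $N = 64$ the stage 1 and stage 2 cuts come out to about 3.8× and 7.2×, while stage 3 reaches the full 64×.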

The mechanism that makes Stage 3 work is the pair of collectives from earlier. Just before a layer runs, an all-gather reconstructs that layer's full weights on the fly from the shards; the layer computes; the full weights are immediately discarded; and on the backward pass a reduce-scatter sends each gradient slice home to the rank that owns it. So a layer's full weights exist only momentarily, during that layer's compute. The reason this does not cripple throughput is overlap: while the current layer computes, the framework is already all-gathering the next layer's weights in the background, hiding the communication behind compute. When the overlap is clean, you get close to plain-DP speed at a fraction of the memory — which is why FSDP/ZeRO-3 is the default way to train models that would never fit otherwise.
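
A simulation of the gather, compute, scatter cycle shows that sharding changes where state lives but not the math. This single-process sketch rests on several assumptions: $N = 2$ hypothetical ranks, one flat parameter vector split into equal shards, plain SGD instead of AdamW to keep it short, and made-up per-rank gradients standing in for a real backward pass.

```python
N, LR = 2, 0.1
full_weights = [1.0, 2.0, 3.0, 4.0]
shard_len = len(full_weights) // N
shards = [full_weights[r * shard_len:(r + 1) * shard_len] for r in range(N)]

def local_grad(rank, weights):
    """Hypothetical stand-in for backprop on this rank's micro-batch."""
    return [(rank + 1) * w for w in weights]

# Forward/backward: all-gather the shards so each rank briefly sees full weights.
gathered = [x for s in shards for x in s]
grads = [local_grad(r, gathered) for r in range(N)]
del gathered  # the full weights are discarded right after use

# Reduce-scatter: rank r ends up with only its slice of the averaged gradient.
avg = [sum(g) / N for g in zip(*grads)]
grad_shards = [avg[r * shard_len:(r + 1) * shard_len] for r in range(N)]

# Each rank updates only the shard it owns.
shards = [[w - LR * g for w, g in zip(s, gs)] for s, gs in zip(shards, grad_shards)]

# Reference: plain data parallelism with full replicas reaches the same weights.
dp = [w - LR * sum(local_grad(r, full_weights)[i] for r in range(N)) / N
      for i, w in enumerate(full_weights)]
assert [x for s in shards for x in s] == dp
```
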

Activation checkpointing: trading compute for memory

Even with weights sharded, the activation bucket can still OOM you, especially at long context. Activation checkpointing (a.k.a. gradient checkpointing or recomputation) is the standard escape valve. Instead of storing every layer's activations for the backward pass, you store only a few “checkpoint” activations and recompute the rest by re-running that segment's forward pass during backprop. It trades compute for memory: you pay roughly one extra forward pass (~33% more FLOPs overall, since backward already costs about twice forward) in exchange for a large drop in activation memory. It is one of the most-used knobs at scale. Over-applying it, however, by recomputing layers that were cheap to just store, wastes FLOPs and silently drags down utilization, a classic “why did my throughput drop?” bug.
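
The ~33% figure is simple arithmetic on forward and backward costs, and the memory side can be sketched too. The cost model below rests on loudly stated assumptions. A forward pass through one layer costs 1 unit and backward costs 2. Each layer's stored activations cost 1 memory unit. Checkpoints are placed every $k$ layers. The choice $k \approx \sqrt{L}$ is a common heuristic assumed here, not something the text prescribes.

```python
import math

def step_cost(layers: int, checkpointing: bool) -> int:
    """Compute units per step: forward = 1, backward = 2 per layer, plus recompute."""
    base = layers * (1 + 2)
    return base + (layers if checkpointing else 0)  # roughly one extra forward pass

def activation_units(layers: int, k: int) -> int:
    """Peak stored activations with a checkpoint every k layers (k=1: store all)."""
    if k == 1:
        return layers
    return math.ceil(layers / k) + k  # kept checkpoints + one recomputed segment

L = 64
print(f"extra compute: {step_cost(L, True) / step_cost(L, False) - 1:.0%}")  # 33%
for k in (1, 4, int(math.sqrt(L)), 16):
    print(f"checkpoint every {k:>2} layers: {activation_units(L, k)} units")
```

In this toy model, memory drops from 64 units to 16 at $k = 8$, while spacing the checkpoints either closer or farther apart costs more memory.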

Is the cluster actually working? MFU and the roofline

You now have many ways to spend GPUs. How do you know whether they are spent well? The headline metric is MFU (Model FLOPs Utilization). It is the ratio of the useful FLOPs the model's math requires to the FLOPs the hardware could theoretically deliver in the same wall-clock time. The useful FLOPs are computed from first principles; the standard transformer estimate is $\approx 6N$ FLOPs per token for an $N$-parameter dense model. An MFU of 30–50% is considered good for large transformer pretraining: Meta reported 38–41% for Llama 3 405B, DeepSeek-V3 landed around 20–30% (its MoE all-to-alls strain the inter-node fabric), and SmolLM3 targeted ~30%. Anything well below that range means GPUs are stalling on communication, on data loading, on pipeline bubbles, or on a single slow straggler GPU that holds up every all-reduce (a collective only completes when its slowest participant does). Note that even well-tuned kernels reach only ~70–77% of peak on real matmuls, so that matmul ceiling, not 100%, is what end-to-end MFU is really chasing.
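
MFU can be computed directly from a training log. This sketch applies the $6N$ FLOPs-per-token rule from above. The throughput, GPU count, and per-GPU peak are hypothetical placeholders, not measurements from any of the runs cited. Note that $6N$ leaves out the attention-score FLOPs, which grow with context length, so at long context it understates the useful work.

```python
def mfu(params: float, tokens_per_second: float, n_gpus: int,
        peak_flops_per_gpu: float) -> float:
    """Model FLOPs Utilization for a dense transformer (6 * N FLOPs per token)."""
    achieved = 6 * params * tokens_per_second           # useful FLOP/s
    return achieved / (n_gpus * peak_flops_per_gpu)     # divided by theoretical FLOP/s

# Hypothetical run: 70B dense model, 512 GPUs at an assumed 1e15 FLOP/s peak each.
u = mfu(params=70e9, tokens_per_second=500_000, n_gpus=512, peak_flops_per_gpu=1e15)
print(f"MFU = {u:.1%}")
```
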

Why is high MFU so hard? The mental model is the roofline. Any computation is limited either by how fast the chip can do arithmetic (compute-bound) or by how fast it can move operands from HBM (memory-bound). The deciding quantity is arithmetic intensity: FLOPs performed per byte read from memory. Below a threshold (the “ridge,” at peak FLOPs ÷ HBM bandwidth) you are memory-bound: the compute units sit idle waiting for data. Above it you are compute-bound and can actually approach peak. This one idea explains a recurring asymmetry. Training and prefill (digesting a long prompt) do big, dense matrix multiplies with high intensity and are compute-bound, so they can reach high MFU. Single-token decode during generation re-reads big weight matrices to do tiny matrix-vector products, has terrible intensity, is firmly memory-bound, and leaves most of the compute idle. Distribution, kernel fusion, and bigger batches all aim to push operations rightward, past the ridge, into the compute-bound regime.

[Diagram: roofline plot of attainable FLOP/s versus arithmetic intensity, with decode and prefill/train marked]

The roofline has two segments: a rising line on the left where throughput is capped by memory bandwidth, and a flat roof on the right where it is capped by peak FLOPs. “Decode” sits on the rising part (idle compute); “prefill/train” sits up on the flat roof. Pushing a workload right — by batching more sequences together so each weight read is reused across more tokens — is how you raise intensity and climb toward the roof.
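
The roofline itself takes only a few lines of arithmetic. This sketch uses assumed round numbers for the accelerator (1e15 FLOP/s of peak compute and 3e12 bytes/s of HBM bandwidth), not the specs of any particular GPU. It also simplifies operand traffic to reading $A$ and $B$ and writing $C$ once, in 2-byte precision, with a hypothetical hidden size of 8192.

```python
PEAK_FLOPS = 1e15             # assumed peak compute, FLOP/s
HBM_BW = 3e12                 # assumed memory bandwidth, bytes/s
RIDGE = PEAK_FLOPS / HBM_BW   # FLOPs/byte where the two roofs meet

def attainable_flops(intensity: float) -> float:
    """Roofline: capped by bandwidth on the left, by peak compute on the right."""
    return min(PEAK_FLOPS, intensity * HBM_BW)

def matmul_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for (m x k) @ (k x n), reading A and B and writing C once."""
    flops = 2 * m * k * n
    traffic = bytes_per_elem * (m * k + k * n + m * n)
    return flops / traffic

d = 8192  # hypothetical hidden size; tokens in flight = rows of the activation matrix
for label, tokens in [("decode, batch 1", 1), ("decode, batch 64", 64),
                      ("prefill/train, 4096 tokens", 4096)]:
    ai = matmul_intensity(tokens, d, d)
    bound = "compute" if ai >= RIDGE else "memory"
    util = attainable_flops(ai) / PEAK_FLOPS
    print(f"{label:28s} intensity {ai:7.1f} FLOP/B -> {bound}-bound, "
          f"<= {util:.1%} of peak")
```

Batching 64 decode sequences raises intensity roughly 64-fold because the same weight read is reused across 64 tokens, yet under these assumptions it still sits below the ridge of about 333 FLOPs/byte.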

Scaling laws and fault tolerance: the run as a whole

Two ideas frame why all this machinery is worth building. Scaling laws are the empirical finding that test loss falls as a smooth, predictable power law in model size, dataset size, and compute. That is precisely why it pays to build ever-larger clusters: bigger reliably means better, and predictably so, so you can forecast the loss before you spend the money. The Chinchilla result sharpened this into a compute-optimal recipe, scaling parameters and training tokens together (roughly 20 tokens per parameter), which turns “how big a model should I train on my GPU budget?” into arithmetic.
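
That arithmetic goes like this. With the $6N$ FLOPs-per-token rule, total training compute is roughly $C \approx 6ND$ for $N$ parameters and $D$ tokens. Setting $D \approx 20N$ (the rule of thumb above) gives $N \approx \sqrt{C/120}$. The sketch below applies this; the compute budget is a hypothetical example.

```python
import math

TOKENS_PER_PARAM = 20  # Chinchilla rule of thumb

def compute_optimal(flops_budget: float):
    """Solve C = 6 * N * D with D = 20 * N for parameters N and tokens D."""
    n_params = math.sqrt(flops_budget / (6 * TOKENS_PER_PARAM))
    return n_params, TOKENS_PER_PARAM * n_params

n, d = compute_optimal(1e24)  # hypothetical budget of 1e24 FLOPs
print(f"~{n / 1e9:.0f}B params trained on ~{d / 1e12:.1f}T tokens")
```
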

Finally, a frontier run lasts weeks on thousands of GPUs, and at that scale hardware fails: the more chips, the shorter the mean time between failures (MTBF) for some chip somewhere. So training must be fault-tolerant. The basic tool is the checkpoint: periodically snapshot the full model and optimizer state to storage, so a crash costs only the work done since the last checkpoint rather than the entire run. Checkpoint too rarely and a failure wastes hours of compute; too often and you waste time writing terabytes to disk. There is an optimal cadence balancing those costs against the cluster's failure rate, with hot-spare nodes ready to swap in — one of the quintessential “infrastructure” calculations this topic asks you to actually do.
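
One standard answer is Young's approximation, an outside technique named here rather than derived in the text. If writing a checkpoint takes $\delta$ and the cluster's mean time between failures is $M$, the interval that roughly minimizes write overhead plus lost work is $\tau \approx \sqrt{2\delta M}$. The sketch also assumes failures are independent across GPUs, so the cluster MTBF is the per-GPU MTBF divided by the GPU count. All numbers are hypothetical.

```python
import math

def cluster_mtbf_hours(per_gpu_mtbf_hours: float, n_gpus: int) -> float:
    """With independent failures, some GPU fails n times as often as any one GPU."""
    return per_gpu_mtbf_hours / n_gpus

def young_interval_hours(checkpoint_write_hours: float, mtbf_hours: float) -> float:
    """Young's first-order optimum: tau = sqrt(2 * delta * M)."""
    return math.sqrt(2 * checkpoint_write_hours * mtbf_hours)

def wasted_fraction(tau: float, delta: float, mtbf: float) -> float:
    """First-order overhead: writing (delta / tau) + rework (tau / 2 lost per failure)."""
    return delta / tau + tau / (2 * mtbf)

# Hypothetical cluster: 16,384 GPUs, 50,000-hour MTBF each, 3-minute checkpoint writes.
M = cluster_mtbf_hours(50_000, 16_384)
delta = 3 / 60
tau = young_interval_hours(delta, M)
print(f"cluster MTBF {M:.1f} h, checkpoint every {tau * 60:.0f} min, "
      f"~{wasted_fraction(tau, delta, M):.1%} of time lost")
```

At the optimum the two overhead terms are equal. Halving the checkpoint write time therefore cuts both the best interval and the total waste by a factor of about $\sqrt{2}$, which is why fast, asynchronous checkpoint writes pay off at scale.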

Putting it together / what to watch for

The unifying tension across this entire topic is compute versus communication. Every technique you have met does one of two things. Either it (a) makes a model fit that otherwise would not (TP, PP, EP, ZeRO/FSDP, activation checkpointing, sequence parallelism), or it (b) keeps the GPUs fed so they reach high MFU (overlapping communication with compute, choosing the right parallelism degrees, killing pipeline bubbles, and chasing down stragglers and data-loading stalls). The interconnect hierarchy dictates the layout: chatty TP (and its in-layer all-reduce) goes inside a node on NVLink; sparse DP and PP span nodes over InfiniBand. The four memory buckets (weights, gradients, optimizer state, activations) tell you which parallelism to reach for, and the box arithmetic (16 bytes/param, $1/N$ under ZeRO-3) tells you whether your plan fits. MFU and the roofline tell you, afterward, whether it actually worked.

Keep this skeleton — four buckets, five axes, four collectives, the node-boundary bandwidth cliff, and utilization — and the detailed questions ahead will all read as calculations on a machine you already understand: deriving ZeRO's communication volume at each stage, computing a pipeline bubble fraction, working out when a MoE's all-to-all overtakes a tensor-parallel all-reduce, classifying decode-versus-prefill on a roofline, estimating per-GPU memory for a given parallelism config, and finding the checkpoint cadence that minimizes expected lost work.