LLMs Deep Dive
Chapter 05 · Part II · Pretraining & Scale

Optimization & Training Dynamics


The big picture: from a loss to a trained model

Pretraining hands us exactly one number to push down: the cross-entropy loss $\mathcal{L}(\theta)$, the model's average surprise at the next token, measured over billions of tokens. What it does not hand us is any instruction for how to move the hundreds of billions of parameters $\theta$ to make that number smaller, nor how to do it without the run exploding halfway through. That “how” is the subject of this chapter: optimization and training dynamics.

The thing we are descending is not a tidy bowl. The loss of a large transformer is a high-dimensional, non-convex, bumpy surface, and we never see it directly: we only ever probe it through small random mini-batches of data, so every gradient we compute is a noisy estimate of the true downhill direction. The optimizer's job is to turn that stream of noisy gradients into a sequence of parameter updates that descends reliably, quickly, and above all stably. The stakes are real: a frontier run can cost millions of dollars and burn weeks of wall-clock time, and a single divergence at step $40\text{k}$ can throw all of it away. By the end of this chapter you will be able to read a loss curve like a practitioner and answer the eternal questions: is this spike normal, and what knob do I turn?

We build the machinery in dependency order: plain gradient descent and why it struggles, then the two fixes (momentum and adaptive step sizes) that combine into Adam and AdamW; the learning-rate schedule (warmup, cosine, and WSD); batch size and gradient accumulation; gradient clipping and weight decay; numerical precision (bf16/fp16/fp8 and loss scaling); the instabilities that bite real runs and how to fight them; and finally a glimpse of the newer Muon optimizer.

Gradient descent, and why plain SGD struggles

Write the loss as a function of the full parameter vector $\theta\in\mathbb{R}^{P}$. Its gradient $g=\nabla_\theta\mathcal{L}$ is the vector of partial derivatives; it points in the direction of steepest increase, i.e. straight uphill. So the simplest possible learning rule, gradient descent, steps the opposite way, downhill:

$$\theta_{t+1}=\theta_t-\eta\, g_t ,$$

where $\eta>0$ is the learning rate, the step size, the single most important number you will set. The gradient tells you which way to go; $\eta$ tells you how far to trust that direction before you look again.
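To make the role of $\eta$ concrete, here is a minimal sketch of the update rule on a toy two-parameter quadratic. The loss, the starting point, and both learning rates are illustrative assumptions rather than anything from a real run; the point is only that the same rule converges with a small step and blows up with a slightly larger one.

```python
def grad(theta):
    """Gradient of the toy loss L(theta) = 0.5 * (theta[0]**2 + 10 * theta[1]**2)."""
    return [theta[0], 10.0 * theta[1]]


def loss(theta):
    """Toy quadratic loss: shallow along axis 0, steep along axis 1."""
    return 0.5 * (theta[0] ** 2 + 10.0 * theta[1] ** 2)


def gradient_descent(theta, lr, steps):
    """Apply theta <- theta - lr * grad(theta) for a fixed number of steps."""
    for _ in range(steps):
        g = grad(theta)
        theta = [p - lr * gi for p, gi in zip(theta, g)]
    return theta


start = [1.0, 1.0]
stable = gradient_descent(start, lr=0.1, steps=50)     # shrinks toward [0, 0]
unstable = gradient_descent(start, lr=0.25, steps=50)  # overshoots the steep axis every step

print(loss(start), loss(stable), loss(unstable))
```

On this toy loss the steep axis has curvature 10, so any step size above 0.2 makes that coordinate grow in magnitude at every step, even though the shallow axis would tolerate a much larger one.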

We cannot afford the gradient over the entire corpus for every step, so we estimate it from a random mini-batch: $\hat g_t$. Stepping with this estimate is stochastic gradient descent (SGD). The estimate is unbiased (on average it equals the true gradient) but noisy (any single batch wobbles around the truth). That noise is not purely a nuisance, since a little jitter helps the optimizer rattle out of sharp, overfit minima, but it does mean the raw step direction shakes from step to step.

Plain SGD has two deeper problems that motivate everything that follows.


Fix 1 — Momentum: average, then step

Momentum attacks the zig-zag by stepping along a running average of past gradients instead of the single latest one. With a decay coefficient $\beta\in[0,1)$,

$$m_t=\beta\,m_{t-1}+(1-\beta)\,\hat g_t,\qquad \theta_{t+1}=\theta_t-\eta\,m_t .$$

The vector $m_t$ is an exponential moving average (EMA) of the gradients: each new gradient is mixed in with weight $1-\beta$, and older ones fade geometrically. The EMA remembers roughly the last $1/(1-\beta)$ gradients, so $\beta=0.9$ averages over about $10$ steps and $\beta=0.95$ over about $20$.

Why this helps: directions in which the gradient consistently points the same way (down the valley floor) accumulate and accelerate, like a ball gathering speed; directions that flip sign every step (bouncing across the walls) cancel out in the average and damp down. Same data, far smoother trajectory. This is the first key idea — average, then step.
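The accumulate-versus-cancel behavior is easy to check numerically. The sketch below feeds the EMA two made-up gradient streams for a single coordinate, one that always points the same way and one that flips sign every step; the streams and the step count are assumptions chosen for illustration.

```python
def ema_of_gradients(grads, beta):
    """Return the momentum buffer m_t after each gradient in `grads`."""
    m = 0.0
    history = []
    for g in grads:
        m = beta * m + (1.0 - beta) * g
        history.append(m)
    return history


steps = 200
consistent = [1.0] * steps                                      # valley floor: same sign every step
flipping = [1.0 if t % 2 == 0 else -1.0 for t in range(steps)]  # valley walls: sign flips every step

m_consistent = ema_of_gradients(consistent, beta=0.9)[-1]
m_flipping = ema_of_gradients(flipping, beta=0.9)[-1]

print(m_consistent, m_flipping)
```

With $\beta=0.9$ the consistent stream settles at essentially the full gradient, while the flipping stream is damped to roughly a twentieth of its raw magnitude.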

Fix 2 — Adaptive step sizes: RMSProp

Momentum still uses one global $\eta$ for every parameter. But different weights in a transformer see gradients of wildly different magnitude: an embedding row for a rare token, a gain in a norm layer, and a weight deep in an FFN are not on the same scale at all. A single $\eta$ is a blunt instrument: too big for the loud parameters, too small for the quiet ones.

RMSProp gives every parameter its own effective step size by tracking, per coordinate, an EMA of the squared gradient, a running estimate of how big that coordinate's gradients typically are:

$$v_t=\beta_2\,v_{t-1}+(1-\beta_2)\,\hat g_t^{2},\qquad \theta_{t+1}=\theta_t-\eta\,\frac{\hat g_t}{\sqrt{v_t}+\varepsilon} .$$

Dividing the gradient by $\sqrt{v_t}$ (its own typical magnitude) means a coordinate that habitually sees huge gradients gets its step shrunk, and a coordinate that sees tiny gradients gets its step amplified, so both land near the same size. This is precisely the cure for the steep-vs-shallow curvature mismatch from the intuition box. The $\varepsilon\approx 10^{-8}$ only prevents division by zero.
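A small sketch of that equalizing effect, for one coordinate at a time. The constant gradient streams, the value $\beta_2=0.99$, and the step count are assumptions for illustration only; the text above does not prescribe them.

```python
import math


def rmsprop_step_size(grad_scale, lr=1e-3, beta2=0.99, eps=1e-8, steps=500):
    """Return |update| for one coordinate whose gradient is constantly `grad_scale`."""
    v = 0.0
    update = 0.0
    for _ in range(steps):
        g = grad_scale
        v = beta2 * v + (1.0 - beta2) * g * g   # EMA of the squared gradient
        update = lr * g / (math.sqrt(v) + eps)  # gradient divided by its typical size
    return abs(update)


loud = rmsprop_step_size(grad_scale=100.0)   # a "loud" parameter
quiet = rmsprop_step_size(grad_scale=0.001)  # a "quiet" parameter

print(loud, quiet)
```

The two gradients differ by five orders of magnitude, yet both updates end up within a fraction of a percent of `lr`.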

Putting them together: Adam

Adam (Adaptive Moment Estimation) is, almost literally, momentum and RMSProp bolted together. It keeps both EMAs at once: the first moment $m_t$ (smoothed gradient, the “which way” from momentum) and the second moment $v_t$ (smoothed squared gradient, the “how big” from RMSProp):

$$\begin{aligned} m_t &= \beta_1\,m_{t-1}+(1-\beta_1)\,\hat g_t , \\ v_t &= \beta_2\,v_{t-1}+(1-\beta_2)\,\hat g_t^{2} , \end{aligned}$$

with the squaring element-wise and the now-standard LLM defaults $\beta_1=0.9$, $\beta_2=0.95$ (early papers used $\beta_2=0.999$; modern large runs prefer the shorter $0.95$ horizon). The update divides the smoothed gradient by the root of the smoothed squared gradient:

$$\theta_{t+1}=\theta_t-\eta\,\frac{\hat m_t}{\sqrt{\hat v_t}+\varepsilon} .$$

The hats denote bias correction (next paragraph). The magic is in the ratio $\hat m_t/\sqrt{\hat v_t}$: it is dimensionless, and for a roughly steady gradient it has a root-mean-square magnitude near $1$. To see why: if a coordinate's gradient is stable, then $\hat m\approx g$ and $\sqrt{\hat v}\approx |g|$, so the ratio is $\approx\pm 1$ and the actual step is $\approx\pm\eta$ regardless of how large or small $g$ is. Adam thus auto-scales every parameter to a step of size $\approx\eta$, which is exactly why a single global learning rate can govern a wildly heterogeneous network.

Bias correction. The EMAs are initialized at $m_0=v_0=0$, so early on they are biased toward zero; they have not yet “warmed up.” At step $1$ with $\beta_1=0.9$ and $\hat g=1$, the raw $m_1=0.1$: ten times too small. Adam undoes this by dividing by $1-\beta^{t}$:

$$\hat m_t=\frac{m_t}{1-\beta_1^{t}},\qquad \hat v_t=\frac{v_t}{1-\beta_2^{t}} .$$

At step $1$, $1-\beta_1^1=0.1$, so $\hat m_1=m_1/0.1=1$: corrected. Because $\beta_2$ is close to $1$, the $\hat v$ correction stays meaningful for many steps; without it, $v$ is badly underestimated early, making $\sqrt{v}$ too small and the first few steps far too large (by a factor of $1/\sqrt{1-\beta_2}$ at step $1$ when only $\hat m$ is corrected), a classic source of immediate divergence.
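The two claims above (the step is $\approx\eta$ whatever the gradient scale, and dropping the $\hat v$ correction inflates the first step) can be checked with a scalar sketch of the Adam equations. The `correct_v` flag is a hypothetical switch added purely for this demonstration, and the constant gradient streams are illustrative assumptions.

```python
import math


def adam_updates(grads, lr=1e-3, beta1=0.9, beta2=0.95, eps=1e-8, correct_v=True):
    """Return the signed update Adam applies to one scalar parameter for each gradient."""
    m, v = 0.0, 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1.0 - beta1) * g      # first moment: smoothed gradient
        v = beta2 * v + (1.0 - beta2) * g * g  # second moment: smoothed squared gradient
        m_hat = m / (1.0 - beta1 ** t)         # bias-corrected first moment
        v_hat = v / (1.0 - beta2 ** t) if correct_v else v
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates


big = adam_updates([1000.0] * 5)    # huge, steady gradient
small = adam_updates([0.001] * 5)   # tiny, steady gradient
first_without_v_fix = adam_updates([1.0], correct_v=False)[0]

print(big[0], small[0], first_without_v_fix)
```

Both steady streams produce steps of about `-lr`, while the uncorrected second moment makes the very first step about $1/\sqrt{0.05}\approx 4.5$ times larger than intended.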


AdamW: why “decoupled” weight decay matters

We almost always want to gently pull parameters toward zero: weight decay, a regularizer that keeps weights from growing without bound and that, in practice, measurably helps generalization. The old way (“L2 regularization”) folds it into the gradient as $g+\lambda\theta$. With plain SGD that is fine. With Adam it quietly breaks: that extra $\lambda\theta$ then passes through the $\sqrt{\hat v}$ normalization, so a parameter with large gradients gets its decay shrunk and one with tiny gradients gets its decay amplified. The regularization strength becomes entangled with gradient magnitude, which is exactly what we did not want.

AdamW fixes this by decoupling the decay: apply the adaptive update as usual, then subtract a clean $\eta\lambda\theta$ as a separate term that never touches $\sqrt{\hat v}$:

$$\theta_{t+1}=\theta_t-\eta\Big(\frac{\hat m_t}{\sqrt{\hat v_t}+\varepsilon}+\lambda\,\theta_t\Big).$$

“Decoupled” is the whole point: every parameter now decays by the same relative fraction $\eta\lambda$ per step, independent of its gradient history. This clean separation is why AdamW is the default optimizer for essentially every large LLM, with a strikingly stable recipe that has barely moved from Llama through DeepSeek: $\beta_1=0.9$, $\beta_2=0.95$, gradient-norm clip $1.0$, weight decay $\lambda=0.1$ (a few very large runs drop it to $0.01$).
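The entanglement is visible even on a single scalar weight. In the sketch below the "data" gradient is pure zero-mean noise (it alternates between `+grad_scale` and `-grad_scale`), so any systematic shrinkage of the weight comes from the decay term alone. The alternating gradient, the function name `train`, and the step count are assumptions made for this demonstration.

```python
import math


def train(grad_scale, decoupled, steps=1000, lr=1e-3, lam=0.1,
          beta1=0.9, beta2=0.95, eps=1e-8):
    """Run Adam on one scalar weight starting at 1.0; return its final value."""
    theta, m, v = 1.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_scale if t % 2 else -grad_scale  # zero-mean stand-in for gradient noise
        if not decoupled:
            g += lam * theta                      # L2 style: decay folded into the gradient
        m = beta1 * m + (1.0 - beta1) * g
        v = beta2 * v + (1.0 - beta2) * g * g
        ratio = (m / (1.0 - beta1 ** t)) / (math.sqrt(v / (1.0 - beta2 ** t)) + eps)
        decay = lam * theta if decoupled else 0.0  # AdamW: decay bypasses the normalization
        theta -= lr * (ratio + decay)
    return theta


adamw_loud = train(grad_scale=10.0, decoupled=True)
adamw_quiet = train(grad_scale=0.1, decoupled=True)
l2_loud = train(grad_scale=10.0, decoupled=False)
l2_quiet = train(grad_scale=0.1, decoupled=False)

print(adamw_loud, adamw_quiet, l2_loud, l2_quiet)
```

With decoupled decay the loud and quiet weights shrink by the same amount; with L2-style decay the loud weight barely decays and the quiet one decays far more than intended.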

One cost to remember: AdamW stores $m$ and $v$ alongside the weights, so the optimizer state is $2\times$ the parameter count. Kept in fp32, that is often the single largest memory term in a run, which is why distributed sharding (Topic 6) exists largely to spread it out.

AdamW's update is (smoothed gradient) / (root smoothed squared gradient) + decoupled decay. The ratio self-normalizes to RMS $\approx 1$, so one global learning rate gives every parameter a sensibly sized step. Decoupling the decay keeps regularization independent of gradient magnitude, and bias correction matters most early, for the slow ($\beta_2$) second moment.
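As a back-of-the-envelope check on that memory cost, the sketch below counts bytes for the two moment buffers. The 7-billion-parameter figure is just an example size, and the comparison against bf16 weights assumes 2 bytes per weight.

```python
def adamw_state_gib(num_params, bytes_per_value=4):
    """Memory for AdamW's two moment buffers (m and v), in GiB."""
    return 2 * num_params * bytes_per_value / 2**30


num_params = 7e9                               # example model size
state_gib = adamw_state_gib(num_params)        # two fp32 values per parameter
bf16_weights_gib = num_params * 2 / 2**30      # one 2-byte value per parameter

print(round(state_gib, 1), round(bf16_weights_gib, 1))
```

For this example the fp32 moments take about four times as much memory as a bf16 copy of the weights themselves.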

The learning-rate schedule: warmup, then decay

A constant $\eta$ is rarely best, because the right step size changes as training proceeds: large early steps help while you are far from any good solution, but the same steps cause instability near convergence. So we vary $\eta$ on a fixed schedule with two or three phases.

Warmup. For the first few thousand steps, ramp $\eta$ linearly from $0$ up to its peak. This is not superstition: it directly counteracts Adam's cold start. Early on, $\hat v$ has seen only a handful of noisy samples, so it is an unreliable, high-variance estimate; taking a full-size step while dividing by an untrustworthy $\sqrt{\hat v}$ is a recipe for an instant loss spike. Warmup lets the second moment accumulate enough samples to be trustworthy before we trust it with a big step; we tiptoe in. A common default is a fixed $\sim 2000$ warmup steps regardless of model size; for short runs, people instead use $1$ to $5\%$ of total steps.

Cosine decay. After the peak, anneal $\eta$ smoothly down to a small floor following a half-cosine:

$$\eta_t=\eta_{\min}+\tfrac12(\eta_{\max}-\eta_{\min}) \left(1+\cos\frac{\pi(t-t_{\text{warm}})}{t_{\text{total}}-t_{\text{warm}}}\right).$$

Large early steps explore; small late steps settle into a minimum and quiet the gradient-noise jitter near convergence. A typical $7$B pretrain warms up over $\sim 2000$ steps to a peak $\eta_{\max}\approx 3\times10^{-4}$, then cosine-decays to about $10\%$ of peak.
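The warmup and cosine phases combine into one small function. The sketch below uses the example numbers from the text (2000 warmup steps, peak $3\times10^{-4}$, floor at $10\%$ of peak); the total of 100,000 steps and the function name are assumptions for illustration.

```python
import math


def warmup_cosine_lr(step, total_steps, warmup_steps=2000, peak_lr=3e-4, min_lr=3e-5):
    """Linear warmup from 0 to peak_lr, then half-cosine decay down to min_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold the floor if training runs past total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


total = 100_000  # hypothetical run length
for step in (0, 1000, 2000, 51_000, 100_000):
    print(step, warmup_cosine_lr(step, total))
```

Note that `total_steps` appears inside the cosine, so the whole curve depends on knowing the run length before training starts.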

Warmup-Stable-Decay (WSD). Cosine has one annoying flaw: its curve must be fitted to the total step count up front. If you later get more compute and want to keep training, or you run scaling-law experiments at several token budgets, the cosine shape is wrong and you would have to restart. WSD avoids this. After warmup it holds $\eta$ constant at the peak for most of training (the “stable” phase), then sharply decays only in the final $10$ to $20\%$ of tokens. The constant phase has no baked-in endpoint, so you can train as long as you like and trigger the decay whenever you decide to stop, and you can branch off a quick decayed checkpoint at any point to read off “progress so far.” Studies report that WSD matches cosine's final loss while being far more practical; it has become a popular modern choice (e.g. SmolLM3). DeepSeek's Multi-Step schedule is a cousin: hold constant, then drop $\eta$ in a couple of discrete steps (e.g. at $80\%$ and $90\%$).
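A WSD schedule can be sketched the same way. The text does not fix the shape of the final decay, so the linear ramp used here is an assumption, as are the floor `min_lr`, the $10\%$ decay fraction, and the 100,000-step run length.

```python
def wsd_lr(step, total_steps, warmup_steps=2000, peak_lr=3e-4, min_lr=3e-5, decay_frac=0.1):
    """Warmup, then a constant plateau, then a decay over the last decay_frac of steps."""
    decay_start = total_steps - int(decay_frac * total_steps)
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    if step < decay_start:
        return peak_lr  # stable phase: no dependence on where training will end
    progress = min((step - decay_start) / (total_steps - decay_start), 1.0)
    return peak_lr + (min_lr - peak_lr) * progress  # linear decay (an assumed shape)


total = 100_000  # hypothetical run length
for step in (1000, 50_000, 89_999, 95_000, 100_000):
    print(step, wsd_lr(step, total))
```

Only the last branch depends on `total_steps`; extending the run just means calling the function with a larger total, which leaves every earlier plateau value unchanged.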


Batch size, gradient accumulation, and the critical batch size

The batch size is how many tokens we average a gradient over before each update. It is a throughput lever: bigger batches keep more GPUs busy. It also changes the quality of each gradient — averaging over more samples cancels more noise, so a larger batch is a sharper estimate of the true gradient, which means you can safely take a bigger step.

Often the batch you want will not fit in memory. Gradient accumulation is the trick: run several smaller “micro-batches” through forward/backward, sum their gradients without updating, and only step the optimizer after the last one. If each micro-batch loss is a mean over its own samples, scale each contribution by $1/k$ so the sum is still a mean. Accumulating $k$ micro-batches of size $b$ is then mathematically equivalent to one batch of size $kb$, up to floating-point rounding: you trade time for memory and recover any effective batch size you want on fixed hardware.
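The equivalence can be verified on a toy one-parameter regression. The model $y\approx wx$, the data, and the micro-batch size below are made up for illustration; the only point is the $1/k$ scaling of each micro-batch gradient.

```python
def mean_grad(w, xs, ys):
    """Gradient with respect to w of the mean of 0.5 * (w * x - y)**2 over the batch."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1]
w = 0.5

full_batch_grad = mean_grad(w, xs, ys)  # one batch of size k * b = 6

micro_size = 2                          # b
k = len(xs) // micro_size               # number of micro-batches
accumulated_grad = 0.0
for i in range(k):
    lo = i * micro_size
    g = mean_grad(w, xs[lo:lo + micro_size], ys[lo:lo + micro_size])
    accumulated_grad += g / k           # scale so the running sum is a mean, not k means

print(full_batch_grad, accumulated_grad)
```

Dropping the division by `k` would silently multiply the effective learning rate by `k`, which is a common accumulation bug.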


Gradient clipping: a seatbelt against cliffs

Even with a good schedule, an unlucky batch (a rare token pattern, a corrupted data shard) can produce a freakishly large gradient that would hurl the weights off a cliff in a single step. Global-norm gradient clipping is the seatbelt. Compute the norm of the entire gradient vector (all parameters concatenated), $\|g\|=\sqrt{\sum_i g_i^2}$, and if it exceeds a threshold $c$, rescale the whole gradient by $c/\|g\|$ so its norm becomes exactly $c$:

$$g \leftarrow g\cdot\min\!\Big(1,\ \frac{c}{\|g\|}\Big).$$

Crucially it rescales globally: the direction is untouched, only the overall magnitude is capped, so the relative sizes of all components are preserved. Example: with $c=1.0$ and a measured $\|g\|=4.0$, the scale factor is $0.25$, so a component equal to $3.0$ becomes $0.75$. The near-universal default is $c=1.0$, and clipping is the single most common first line of defense against loss spikes.
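The worked example above translates directly into code. In this sketch a "model" is just a list of per-parameter gradient lists (a stand-in for real tensors), and the particular numbers are chosen so the joint norm is exactly $4.0$.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale all gradients by one shared factor so their joint L2 norm is <= max_norm."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / total_norm) if total_norm > 0.0 else 1.0
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, total_norm


# Two hypothetical parameter tensors whose concatenated gradient has norm sqrt(16) = 4.
grads = [[3.0, 1.0], [2.0, 1.0, 1.0]]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for tensor in clipped for g in tensor))

print(norm_before, clipped[0][0], norm_after)  # 4.0, 0.75, 1.0
```

Returning the pre-clip norm is a deliberate choice: it is the quantity worth logging, since a norm that sits far above the threshold for long stretches means clipping is doing the learning rate's job.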


Numerical precision: bf16, fp16, fp8, and where things break

Modern accelerators are dramatically faster in low precision: an H100's Tensor Cores deliver many times more FLOPs in bf16 than its plain fp32 units, and roughly twice that again in fp8, so large runs do the heavy matrix multiplies in $16$- or $8$-bit. But fewer bits means a coarser, narrower set of representable numbers, and that is where things break. A floating-point format splits its bits into an exponent (which sets the dynamic range, i.e. how big and how small a number it can reach) and a mantissa (which sets the precision, i.e. how finely it resolves values in between). The formats that matter:

  • fp32 (1 sign / 8 exp / 23 mantissa): the reference. Wide range, fine precision.
  • tf32: a Tensor-Core mode that keeps fp32's 8 exponent bits but truncates the mantissa to $10$ bits, giving fp32's range with fp16-level precision at higher speed; PyTorch can use it for fp32 matmuls and convolutions when the corresponding flags allow it.
  • fp16 (1 / 5 / 10): ten mantissa bits (fine precision) but only five exponent bits, so its dynamic range is narrow: small gradients underflow to zero and large values overflow to inf easily.
  • bf16 (1 / 8 / 7): keeps fp32's full $8$-bit exponent (same wide range) and pays for it with only $7$ mantissa bits (coarser precision).
  • fp8: two variants. e4m3 ($4$ exp / $3$ mantissa, max $\approx\pm 448$) for the forward pass, where a little extra precision helps; e5m2 ($5$ exp / $2$ mantissa, max $\approx\pm 57\,344$) for gradients, where the wider range matters more.

bf16 beats fp16 for training because range matters more than precision. Deep nets produce gradients spanning many orders of magnitude; bf16's full fp32 exponent represents them all without under/overflow, at the cost of a coarser mantissa — a trade that training tolerates well. fp16's narrow range is the opposite trade and the reason it is fragile. The community has essentially stopped training in fp16; bf16 is the default.
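The range-versus-precision trade can be seen without any accelerator. The sketch below emulates both formats with Python's `struct` module: format code `e` is IEEE half precision (fp16), and bf16 is emulated by rounding an fp32 bit pattern to its top 16 bits. The helper names are hypothetical, and the bf16 helper is a simplification that does not special-case NaN or infinity.

```python
import struct


def to_fp16(x):
    """Round a Python float to fp16 and back; values beyond fp16's range become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def to_bf16(x):
    """Round a Python float to bfloat16 (top 16 bits of fp32, round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


for value in (1e-8, 70000.0, 1.001):
    print(value, to_fp16(value), to_bf16(value))
```

A gradient of `1e-8` vanishes in fp16 but survives in bf16, and `70000.0` overflows fp16 but only gets rounded in bf16; conversely `1.001` keeps its fractional part in fp16 and collapses to `1.0` in bf16, which is the precision bf16 gives up.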


Two safeguards keep low precision honest. Loss scaling (needed for fp16, rarely for bf16): multiply the loss by a big constant $S$ before backprop so all gradients shift up into the representable range, then divide them by $S$ before the optimizer step; if an overflow is detected, skip that step and halve $S$. bf16's wide range usually makes this unnecessary. And critically, keep a master copy of the weights and the optimizer state in fp32. The slow weight updates ($\eta\times$ a tiny ratio) and the $m,v$ accumulators are precisely the quantities that lose all their information if rounded to bf16; demote them to save memory and you invite a slow, subtle destabilization. The fp8 rule (DeepSeek-V3 style) is the same in spirit: do the bulk matmuls in fp8 for throughput, but keep accumulations, master weights, optimizer moments, and sensitive ops (norms, softmax, residual adds) in higher precision.
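The loss-scaling logic described above can be sketched with emulated fp16 gradients. Everything here is a simplified stand-in: `to_fp16` emulates half precision via the `struct` module, `scaled_step` is a hypothetical helper rather than any library's API, and multiplying the true gradients by $S$ stands in for backpropagating a scaled loss. Production scalers also grow $S$ again after a run of clean steps, which is omitted here.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to fp16 and back; values beyond fp16's range become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def scaled_step(true_grads, scale):
    """Emulate fp16 backprop of (scale * loss). Returns (unscaled grads or None, next scale)."""
    fp16_grads = [to_fp16(g * scale) for g in true_grads]
    if any(math.isinf(g) or math.isnan(g) for g in fp16_grads):
        return None, scale / 2.0                     # overflow: skip this step, halve S
    return [g / scale for g in fp16_grads], scale    # divide by S before the optimizer step


tiny_grads = [1e-8, -3e-9]

unscaled, _ = scaled_step(tiny_grads, scale=1.0)         # both gradients underflow to zero
recovered, _ = scaled_step(tiny_grads, scale=65536.0)    # shifted into fp16's range
skipped, next_scale = scaled_step([2.0], scale=65536.0)  # 2 * 65536 overflows fp16

print(unscaled, recovered, skipped, next_scale)
```

The division by `scale` happens in full precision here, mirroring the advice to keep the optimizer-side arithmetic in fp32.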

Training instabilities: reading and taming the loss curve

Sooner or later a real run will misbehave. The most visible symptom is a loss spike: a sudden jump in the loss curve, sometimes a small blip, sometimes a vertical cliff that never recovers. The art is reading the shape.


A spike that recovers on its own within a few steps (as drawn) is usually benign — clipping absorbed a bad batch and the run healed itself; let it ride. A spike that persists or keeps growing signals real instability, and the response escalates in severity:

  1. Skip the offending micro-batch and continue — cheapest, often enough.
  2. Roll back to the last good checkpoint and skip the data shards that triggered it.
  3. Lower the peak LR (or tighten the clip threshold) and resume from the rollback.

A separate warning sign is a gradient norm that climbs slowly over thousands of steps even while the loss still looks fine; watch it alongside activation and parameter norms, because it often precedes a spike.

Beyond clipping and LR, a few targeted tricks address specific failure modes:

  • z-loss adds a tiny penalty $\lambda\,(\log Z)^2$ on the log of the softmax normalizer $Z=\sum_i e^{\text{logit}_i}$ (here $\lambda$ is a small loss coefficient, unrelated to the weight-decay $\lambda$), which keeps the output logits from drifting to large magnitudes, exactly the drift that causes softmax/precision blowups. Cheap insurance, used in PaLM, OLMo-2, Qwen3.
  • QK-norm normalizes the query and key vectors before their dot product, capping the attention logits so the attention softmax cannot saturate or overflow — a common stabilizer for the largest models (with a known downside on very long context).
  • Removing weight decay from the embeddings lowers embedding norms and improves stability in some recipes (OLMo-2), without hurting quality.
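Of the tricks in the list above, z-loss is the simplest to write down. The sketch below computes $\log Z$ with the usual max-subtraction for numerical stability and applies the squared penalty; the coefficient `1e-4` and the example logits are assumptions for illustration.

```python
import math


def log_z(logits):
    """Numerically stable log of the softmax normalizer Z = sum_i exp(logit_i)."""
    peak = max(logits)
    return peak + math.log(sum(math.exp(x - peak) for x in logits))


def z_loss(logits, coeff=1e-4):
    """Auxiliary penalty coeff * (log Z)**2, added to the cross-entropy loss."""
    return coeff * log_z(logits) ** 2


centered = [0.0, 0.0, 0.0, 0.0]
drifted = [50.0, 50.0, 50.0, 50.0]  # same softmax probabilities, much larger logits

print(z_loss(centered), z_loss(drifted))
```

Both logit vectors give identical softmax outputs, so cross-entropy alone cannot tell them apart; the z-loss term is what penalizes the drifted one.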

A recent development: the Muon optimizer

AdamW treats every parameter independently: it is blind to the fact that a weight matrix $W\in\mathbb{R}^{m\times n}$ has 2D geometric structure. Muon exploits that structure for the matrix parameters. Its idea, in three lines: keep a plain momentum matrix $B$ of the gradients, then before stepping, orthogonalize it, i.e. replace $B$ with the nearest matrix whose singular values are all $1$ (the “matrix sign,” $UV^\top$ from the SVD $B=U\Sigma V^\top$). Computing an SVD every step is far too expensive, so Muon approximates the orthogonalization with about $5$ iterations of a Newton–Schulz recursion, a fixed quintic polynomial in the matrix, $X\leftarrow aX+bX(X^\top X)+cX(X^\top X)^2$ with tuned coefficients $(a,b,c)\approx (3.44,-4.78,2.03)$, which drives all the singular values toward $1$ (in practice into a band around $1$) without ever factorizing.
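A pure-Python sketch of that recursion follows, using nested lists as matrices. Two details are assumptions not stated in the text: the coefficients are written to four decimals (values that round to the figures quoted above), and the input is first divided by its Frobenius norm so that every singular value starts at or below $1$. The diagonal test matrix is chosen because the recursion then acts on each singular value separately, which makes the result easy to inspect.

```python
import math

A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315  # assumed digits; round to (3.44, -4.78, 2.03)


def matmul(p, q):
    """Plain nested-list matrix product."""
    return [[sum(p[i][k] * q[k][j] for k in range(len(q))) for j in range(len(q[0]))]
            for i in range(len(p))]


def transpose(p):
    return [list(row) for row in zip(*p)]


def newton_schulz(b, iterations=5):
    """Approximately orthogonalize `b` via X <- aX + bX(X^T X) + cX(X^T X)^2."""
    norm = math.sqrt(sum(v * v for row in b for v in row))
    x = [[v / norm for v in row] for row in b]  # assumed pre-scaling by the Frobenius norm
    for _ in range(iterations):
        gram = matmul(transpose(x), x)           # X^T X
        gram2 = matmul(gram, gram)               # (X^T X)^2
        poly = [[B_COEF * g1 + C_COEF * g2 for g1, g2 in zip(r1, r2)]
                for r1, r2 in zip(gram, gram2)]
        x_poly = matmul(x, poly)                 # b X (X^T X) + c X (X^T X)^2
        x = [[A_COEF * xv + pv for xv, pv in zip(xr, pr)] for xr, pr in zip(x, x_poly)]
    return x


momentum = [[3.0, 0.0], [0.0, 0.5]]  # singular values 3.0 and 0.5
orthogonalized = newton_schulz(momentum)
print(orthogonalized)
```

The two singular values start a factor of six apart and both end up of order one after five iterations; they do not converge to exactly $1$, because these coefficients trade exactness for speed.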


What to watch for

  • Most early instabilities are warmup or precision problems. A fresh run that diverges in the first few thousand steps almost always means (in order): too-short warmup, too-high peak LR, missing bias correction, or optimizer state / master weights demoted out of fp32.
  • Read the spike, not just the loss. A self-healing spike is usually fine; a persistent or growing one, or a steadily climbing gradient norm, calls for skip $\to$ rollback $\to$ lower-LR, in that order.
  • The optimizer state is a dominant memory term. AdamW's $m$ and $v$ add two fp32 values per parameter; this, not just the weights, is what FSDP/ZeRO sharding (Topic 6) exists to distribute.
  • Non-determinism is normal at scale. Two “identical” runs on different GPU counts can diverge slightly because the order and precision of the all-reduce summations differ; small loss gaps are expected, not a bug.
  • A safe starting point for a $\sim 7$B pretrain: AdamW with $(\beta_1,\beta_2)=(0.9,0.95)$, clip $1.0$, weight decay $0.1$, $\sim 2000$ warmup steps, peak $\eta\approx 3\times10^{-4}$ decayed (cosine or WSD) to $\sim 10\%$, all in bf16 with fp32 master weights and optimizer state. Sanity-check it cheaply on a small proxy model before committing the full run.