Teju's Blog

Full stack engineer and AI architect. Notes from the work.


Gradient descent: how a language model learns anything

Gradient descent is the engine behind every neural network you have ever used. Cauchy described it in 1847. The algorithm that trains modern language models is barely different from the version he wrote down; it just runs on a few hundred billion more numbers.

The setup

A model has parameters; call them $\theta$. For a language model, $\theta$ is the full set of weights (billions of numbers). The model takes inputs, produces outputs, and the outputs get compared against the correct answers by a loss function $L(\theta)$:

$$L(\theta) = \frac{1}{N} \sum_{i=1}^{N} \ell(f_\theta(x_i), y_i)$$

where $\{(x_i, y_i)\}$ is the training data, $f_\theta$ is the model, and $\ell$ is a per-example loss like cross-entropy. The goal is to find the $\theta$ that minimises $L(\theta)$.

This is an optimisation problem. Gradient descent is the most boring, most effective way to solve it.

The basic idea

The gradient $\nabla L(\theta)$ is a vector with one component per parameter. It points in the direction of steepest ascent: the direction in which $L$ increases fastest. To minimise $L$, you take a step in the opposite direction:

$$\theta_{\text{new}} = \theta - \eta \nabla L(\theta)$$

$\eta$ is the learning rate, a small positive number. Repeat until $L$ stops decreasing. That is the entire algorithm.
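The loop is short enough to write out in plain Python. This is a minimal sketch on a toy one-parameter loss, $L(\theta) = (\theta - 3)^2$, chosen only because its gradient $2(\theta - 3)$ can be written by hand; the stopping tolerance and step cap are illustrative.

```python
def loss(theta: float) -> float:
    """Toy loss L(θ) = (θ - 3)², minimised at θ = 3 (illustrative choice)."""
    return (theta - 3.0) ** 2


def grad(theta: float) -> float:
    """Derivative dL/dθ = 2(θ - 3), worked out by hand."""
    return 2.0 * (theta - 3.0)


def gradient_descent(theta: float, lr: float, tol: float = 1e-10, max_steps: int = 10_000):
    """Repeat θ ← θ - η ∇L(θ) until the loss stops decreasing (or max_steps is hit)."""
    prev = loss(theta)
    for step in range(max_steps):
        theta = theta - lr * grad(theta)
        current = loss(theta)
        if prev - current < tol:   # loss no longer decreasing meaningfully: stop
            return theta, step + 1
        prev = current
    return theta, max_steps


theta, steps = gradient_descent(theta=0.0, lr=0.1)
print(f"θ ≈ {theta:.4f} after {steps} steps")
```

Each step moves $\theta$ against the slope, and the step shrinks on its own as the slope flattens near $\theta = 3$.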

Figure: the training loop. Current $\theta$ → forward pass (compute $L$) → backward pass (compute $\nabla L$) → update ($\theta = \theta - \eta \nabla L$) → repeat.

The forward pass computes $L$. The backward pass uses the chain rule (backpropagation, or backprop) to compute $\nabla L$ with respect to every parameter, working from the output back through every layer. Frameworks like PyTorch do this for you: you write the model, call loss.backward(), and PyTorch stores each parameter's component of $\nabla L$ in that parameter's .grad attribute.
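As a rough picture of what loss.backward() computes, here is the chain rule applied by hand to a hypothetical one-weight model $f(x) = wx + b$ with squared-error loss on a single example. The function names are illustrative; a finite-difference check confirms the hand-derived gradients.

```python
def forward(w: float, b: float, x: float, y: float):
    """Forward pass: prediction and squared-error loss for one example."""
    pred = w * x + b
    loss = (pred - y) ** 2
    return pred, loss


def backward(w: float, b: float, x: float, y: float):
    """Backward pass via the chain rule: dℓ/dpred, then dpred/dw and dpred/db."""
    pred, _ = forward(w, b, x, y)
    dloss_dpred = 2.0 * (pred - y)   # derivative of (pred - y)² w.r.t. pred
    dloss_dw = dloss_dpred * x       # dpred/dw = x
    dloss_db = dloss_dpred * 1.0     # dpred/db = 1
    return dloss_dw, dloss_db


def numerical_grad(w, b, x, y, h=1e-6):
    """Central finite differences, used only to check backward()."""
    dw = (forward(w + h, b, x, y)[1] - forward(w - h, b, x, y)[1]) / (2 * h)
    db = (forward(w, b + h, x, y)[1] - forward(w, b - h, x, y)[1]) / (2 * h)
    return dw, db


w, b, x, y = 0.5, -1.0, 2.0, 3.0
print("chain rule:       ", backward(w, b, x, y))
print("finite difference:", numerical_grad(w, b, x, y))
```

PyTorch's autograd does the same bookkeeping automatically, layer by layer, for every parameter in the model.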

A picture

In one dimension, $L(\theta)$ is a curve and the gradient is its slope. Gradient descent rolls a ball downhill, step by step, until it lands at a minimum. Click step to take one update; the dashed line is the tangent (the slope the model just used).

A simple parabolic loss. The ball starts on the left, and each step pushes it toward the bottom in proportion to the slope at its current position.

A real training loss curve (loss plotted against training step) is messier than this: it has spikes, plateaus, and occasional climbs. The trend is what you watch.

Learning rate: the most-tuned hyperparameter

The learning rate η\eta is a step size. Too small and training crawls. Too large and the updates overshoot the minimum and the loss explodes.

Three behaviours:

If $\eta$ is too small: loss decreases steadily but very slowly. You can run out of compute before training converges.

If $\eta$ is right: loss decreases steeply at first, then flattens as the model approaches a minimum. This is the classic, roughly exponential curve above.

If $\eta$ is too large: loss zigzags or grows. The model is bouncing across the loss surface, sometimes landing in worse places than where it started.

Try it. The same parabola, but drag the η slider up past 1.0 and hit play:

Drag η to see the regimes. Around 0.1-0.5 the ball converges smoothly. Around 1.0 it overshoots and oscillates. Past 1.0 it diverges entirely.
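The three regimes can be reproduced without the widget. This sketch assumes the parabola is $L(\theta) = \theta^2$ (the widget's exact curve is not given), so each update is $\theta \leftarrow \theta(1 - 2\eta)$: $|\theta|$ shrinks for $0 < \eta < 1$ (with sign flips once $\eta > 0.5$), bounces between $\pm\theta_0$ at $\eta = 1$, and grows for $\eta > 1$.

```python
def run(lr: float, theta0: float = 4.0, steps: int = 20) -> list[float]:
    """Gradient descent on L(θ) = θ² (gradient 2θ); returns the θ trajectory."""
    thetas = [theta0]
    for _ in range(steps):
        theta = thetas[-1]
        thetas.append(theta - lr * 2.0 * theta)   # θ ← θ(1 - 2η)
    return thetas


for lr in (0.01, 0.3, 0.8, 1.0, 1.1):
    traj = run(lr)
    print(f"η={lr:<4}  |θ| after 20 steps = {abs(traj[-1]):.3g}")
```

$\eta = 0.01$ crawls, $0.3$ converges smoothly, $0.8$ converges while overshooting back and forth, $1.0$ oscillates forever, and $1.1$ diverges.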

The standard practice is to start at a learning rate the literature suggests for your model size and training method (e.g. 2e-4 for LoRA, 1e-4 for full fine-tuning, often lower for pretraining) and watch the early training loss. If it barely moves, raise $\eta$; if it zigzags or spikes, lower it.

Most modern training also uses a learning-rate schedule: a warmup phase that increases $\eta$ from 0 to its peak over a few hundred steps, then a decay phase that brings it back down toward the end of training. The warmup helps with numerical stability early on. The decay shrinks the step size so the model can settle into a minimum instead of bouncing around it.
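A schedule is just a function from step number to $\eta$. The sketch below uses linear warmup followed by cosine decay, which is one common choice rather than the only one; the peak, floor, and step counts are placeholder values.

```python
import math


def lr_at(step: int, peak_lr: float = 2e-4, warmup_steps: int = 500,
          total_steps: int = 10_000, min_lr: float = 2e-5) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))   # falls from 1 to 0
    return min_lr + (peak_lr - min_lr) * cosine


for s in (0, 250, 500, 5_000, 10_000):
    print(s, f"{lr_at(s):.2e}")
```

In PyTorch, a function like this could be plugged in through torch.optim.lr_scheduler.LambdaLR, which multiplies the optimiser's base learning rate by whatever the function returns, so it would need to return a fraction of the peak rather than an absolute $\eta$.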

Mini-batch and stochastic gradient descent

The formula above defines $L(\theta) = \frac{1}{N}\sum_i \ell(\ldots)$ as an average over the entire dataset. For an LLM, $N$ runs to trillions of tokens, so computing the full gradient even once is impractical.

The workaround is to estimate the gradient from a random subset of the data, called a mini-batch:

$$\nabla L(\theta) \approx \frac{1}{B} \sum_{i \in \text{batch}} \nabla \ell(f_\theta(x_i), y_i)$$

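where $B$ is the batch size (typically anywhere from 32 examples up to a few million tokens' worth of examples). This is stochastic gradient descent (SGD); strictly, it is mini-batch SGD, since classic SGD uses a single example per step. The “stochastic” comes from the randomness of the batch.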
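Mini-batch SGD in plain Python, so the sampling is visible. The task (recovering $y = 2x + 1$ from noisy synthetic data with a squared-error loss), the batch size, and the learning rate are all illustrative assumptions. Each epoch reshuffles the data and walks through it in batches of $B = 32$; every step uses only that batch's averaged gradient.

```python
import random

random.seed(0)

# Synthetic data for illustration: y = 2x + 1 plus a little Gaussian noise.
xs = [random.uniform(-1.0, 1.0) for _ in range(1_000)]
data = [(x, 2.0 * x + 1.0 + random.gauss(0.0, 0.1)) for x in xs]


def batch_grad(w: float, b: float, batch):
    """Average gradient of (w*x + b - y)² over a mini-batch."""
    gw = gb = 0.0
    for x, y in batch:
        err = w * x + b - y
        gw += 2.0 * err * x
        gb += 2.0 * err
    return gw / len(batch), gb / len(batch)


w, b, lr, batch_size = 0.0, 0.0, 0.1, 32
for epoch in range(20):
    random.shuffle(data)                      # fresh random batches each epoch
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        gw, gb = batch_grad(w, b, batch)      # noisy estimate of ∇L
        w, b = w - lr * gw, b - lr * gb

print(f"w ≈ {w:.3f}, b ≈ {b:.3f}")
```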

The trade-off: smaller batches mean noisier gradient estimates and noisier training, but more gradient steps per unit of data. Larger batches mean smoother gradients and fewer steps per unit of data. Modern pretraining tends to use very large batches: GPUs process large matrix multiplications efficiently, and at that scale the extra gradient noise of small batches costs more than it helps.

A useful intuition: batch size and learning rate interact. Double the batch and you can usually double the learning rate without instability, up to a point.

Adam and AdamW

Vanilla SGD has a known weakness: it applies the same learning rate to every parameter. Parameters with consistently large gradients move too far; parameters with consistently small gradients barely move at all. Adaptive optimisers fix this by giving each parameter its own effective learning rate.

Adam (Kingma and Ba, 2014) is the most common. For each parameter it tracks two moving averages:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$

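where $g_t$ is the gradient at step $t$. $m_t$ is an exponential moving average of the gradient; $v_t$ is an exponential moving average of its square (the uncentered second moment, a rough proxy for variance). The update uses both: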

$$\theta_{t+1} = \theta_t - \eta \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}$$

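Parameters with high gradient variance get smaller effective steps; stable parameters get larger ones. $\beta_1$ and $\beta_2$ are usually 0.9 and 0.999, and $\epsilon$ is a tiny constant (typically $10^{-8}$) for numerical safety. Because $m_0 = v_0 = 0$, both averages are biased toward zero in the first steps, so the full algorithm uses the bias-corrected $\hat{m}_t = m_t / (1 - \beta_1^t)$ and $\hat{v}_t = v_t / (1 - \beta_2^t)$ in place of $m_t$ and $v_t$.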
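The Adam update written out in plain Python, as a minimal sketch (not PyTorch's implementation). It includes the standard bias correction from the Adam paper, dividing $m_t$ and $v_t$ by $1 - \beta_1^t$ and $1 - \beta_2^t$. The toy loss $L(x, y) = 100x^2 + 0.01y^2$ is an illustrative choice whose two gradients differ in scale by a factor of 10,000.

```python
import math


def adam(grad_fn, theta, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8, steps=500):
    """Minimal Adam: per-parameter moment estimates with bias correction."""
    m = [0.0] * len(theta)
    v = [0.0] * len(theta)
    for t in range(1, steps + 1):
        g = grad_fn(theta)
        for i in range(len(theta)):
            m[i] = beta1 * m[i] + (1 - beta1) * g[i]          # EMA of the gradient
            v[i] = beta2 * v[i] + (1 - beta2) * g[i] ** 2     # EMA of the squared gradient
            m_hat = m[i] / (1 - beta1 ** t)                   # bias correction
            v_hat = v[i] / (1 - beta2 ** t)
            theta[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta


def sgd(grad_fn, theta, lr, steps=500):
    """Plain gradient descent: one learning rate for every parameter."""
    for _ in range(steps):
        g = grad_fn(theta)
        theta = [p - lr * gi for p, gi in zip(theta, g)]
    return theta


def grad(theta):
    """Gradient of the illustrative loss L(x, y) = 100x² + 0.01y²."""
    x, y = theta
    return [200.0 * x, 0.02 * y]


gd_x, gd_y = sgd(grad, [1.0, 1.0], lr=0.009)   # lr must stay below 0.01 or x diverges
adam_x, adam_y = adam(grad, [1.0, 1.0])
print(f"plain GD: x={gd_x:.2e}, y={gd_y:.3f}")
print(f"Adam:     x={adam_x:.2e}, y={adam_y:.2e}")
```

Plain gradient descent has to keep $\eta$ below 0.01 or the steep $x$ direction diverges, and at that rate $y$ barely moves. Adam divides each gradient by its own running RMS, so multiplying a coordinate's gradient by a constant leaves its update essentially unchanged (up to $\epsilon$): here $x$ and $y$ follow the same path.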

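AdamW is Adam with weight decay decoupled from the gradient update: instead of adding an L2 penalty to the gradient, where Adam's per-parameter scaling would distort it, it shrinks each weight directly by a small fraction at every step. It is the default optimiser for most LLM training and fine-tuning runs today. If you see a config that says optimizer: adamw, this is what is running underneath.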
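The difference between putting weight decay inside the gradient (Adam with an L2 penalty) and AdamW's decoupled version shows up clearly in a stripped-down case. In this sketch a single weight starts at 1.0 with no loss gradient at all, so only the decay acts; $\lambda$ is the weight-decay coefficient (wd in the code), and the learning rate, $\lambda$, and step count are illustrative.

```python
import math


def decay_only(theta: float, steps: int = 100, lr: float = 0.01, wd: float = 0.01,
               decoupled: bool = True, beta1: float = 0.9, beta2: float = 0.999,
               eps: float = 1e-8) -> float:
    """Run Adam-style steps with zero loss gradient, so only weight decay acts."""
    m = v = 0.0
    for t in range(1, steps + 1):
        g = 0.0                              # no loss gradient at all
        if not decoupled:
            g += wd * theta                  # Adam + L2: penalty enters the gradient
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        step = m_hat / (math.sqrt(v_hat) + eps)
        if decoupled:
            step += wd * theta               # AdamW: decay bypasses the adaptive scaling
        theta -= lr * step
    return theta


adamw = decay_only(1.0, decoupled=True)
adam_l2 = decay_only(1.0, decoupled=False)
print(f"AdamW:   θ = {adamw:.4f}")
print(f"Adam+L2: θ = {adam_l2:.4f}")
```

With decoupled decay the weight shrinks by the factor $(1 - \eta\lambda)$ per step, exactly as the coefficient prescribes. With the L2 penalty, Adam normalises the penalty's gradient like any other and turns it into a step of roughly $\eta$, so the weight collapses toward zero largely regardless of $\lambda$. For reference, PyTorch's torch.optim.Adam adds weight_decay times the parameter to the gradient (the coupled form), while torch.optim.AdamW applies the decoupled form.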

The loss surface is not convex

A convex loss surface has a single basin: every local minimum is also a global minimum. Gradient descent with a small enough learning rate on a convex surface reaches it, full stop. The loss surface of a neural network is wildly non-convex. It has many minima, saddle points, valleys, and ridges.

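Figure: paths from a start point $\theta_0$ across a non-convex loss surface with a local minimum, a global minimum, and a saddle point, labelled "with η small enough", "with η too large", and "ideally".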

In practice, two things make gradient descent work despite this:

Large overparameterised networks have many minima that are roughly equivalent in loss. You do not need the global minimum; almost any low-loss minimum gives a useful model.

Stochasticity helps escape bad basins. The noise from mini-batch sampling kicks the parameters out of saddle points and shallow local minima that exact gradient descent would get stuck in.
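A small experiment makes the saddle-point claim concrete. It uses an illustrative loss $L(x, y) = x^2 + y^4/4 - y^2/2$, which has a saddle at $(0, 0)$ and minima at $(0, \pm 1)$. On the line $y = 0$ the $y$-gradient is exactly zero, so exact gradient descent started there slides straight into the saddle and stops. Adding Gaussian noise to each gradient, as a stand-in for mini-batch noise (the noise scale is an assumption), knocks it off that line and into one of the minima.

```python
import random


def grad(x: float, y: float):
    """Gradient of L(x, y) = x² + y⁴/4 - y²/2."""
    return 2.0 * x, y ** 3 - y


def descend(noise_std: float, lr: float = 0.1, steps: int = 500, seed: int = 0):
    rng = random.Random(seed)
    x, y = 1.0, 0.0                     # start on the line y = 0, where dL/dy = 0
    for _ in range(steps):
        gx, gy = grad(x, y)
        if noise_std > 0:               # noisy gradient, mimicking mini-batch sampling
            gx += rng.gauss(0.0, noise_std)
            gy += rng.gauss(0.0, noise_std)
        x, y = x - lr * gx, y - lr * gy
    return x, y


ex, ey = descend(noise_std=0.0)   # converges to the saddle at (0, 0)
nx, ny = descend(noise_std=0.1)   # ends near (0, 1) or (0, -1)
print("exact GD:", (ex, ey))
print("noisy GD:", (nx, ny))
```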

A loss surface with multiple minima looks more like this. Hit reset and try different values of η (a smaller η takes more, smaller steps from the same start), or just watch where the ball lands:

A non-convex loss surface. Gradient descent will land in the basin closest to where it started. The basin on the right is deeper, but you only get there if your starting point (or your training noise) puts you on the right side of the ridge.
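The same behaviour in one dimension, as a sketch on an illustrative double-well loss $L(\theta) = \theta^4 - 2\theta^2 - 0.5\theta$ (not the figure's actual curve). Its right-hand basin, near $\theta \approx 1.06$, is deeper than the left one near $\theta \approx -0.93$, and the ridge between them sits near $\theta \approx -0.13$. Plain gradient descent ends in whichever basin it starts in.

```python
def loss(theta: float) -> float:
    """Illustrative double-well loss; the right-hand minimum is the deeper one."""
    return theta ** 4 - 2.0 * theta ** 2 - 0.5 * theta


def grad(theta: float) -> float:
    """Derivative of the loss above."""
    return 4.0 * theta ** 3 - 4.0 * theta - 0.5


def descend(theta: float, lr: float = 0.01, steps: int = 2_000) -> float:
    """Plain gradient descent from a given starting point."""
    for _ in range(steps):
        theta -= lr * grad(theta)
    return theta


for start in (-2.0, -0.3, 0.0, 2.0):
    end = descend(start)
    print(f"start {start:+.1f} -> θ = {end:+.3f}, loss = {loss(end):.3f}")
```

Starting points $-0.3$ and $0.0$ are close together, but they sit on opposite sides of the ridge, so they end up in different basins with very different losses.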

This is part of why training large models with SGD and Adam works at all. On a surface like this, the math offers no guarantee of reaching a global minimum. Empirically, training converges anyway, and the basin it settles into produces a useful model.

What this looks like in code

A vanilla training loop, in PyTorch:

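```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Placeholder model and data so the loop runs end to end; swap in your own.
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))
dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 4, (1024,)))
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

opt = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
loss_fn = nn.CrossEntropyLoss()

for step, (inputs, targets) in enumerate(train_loader):
    logits = model(inputs)           # forward pass
    loss = loss_fn(logits, targets)  # L on this mini-batch

    opt.zero_grad()                  # clear gradients left over from the last step
    loss.backward()                  # backprop: compute ∇L into each parameter's .grad
    opt.step()                       # update θ with the AdamW rule

    if step % 100 == 0:
        print(f"step {step}: loss {loss.item():.4f}")
```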

Four lines of math: forward, loss, backward, step. Everything else in real training scripts is plumbing: data loading, mixed precision, gradient accumulation, checkpointing, distributed-training synchronisation. The optimisation itself stays at these four lines.

What changes for LLM training

Two things scale up.

The dataset is gigantic (trillions of tokens) and you only see it once or a small number of times. Each gradient step is on a fresh batch, sampled without replacement.

The model is huge, so the gradient computation is the bottleneck. Backprop runs on every device in a cluster, the gradients get averaged across devices (all-reduce), the optimiser update happens, and the next batch comes in. The math from the formulas above is unchanged; the engineering around it is most of the work.
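Why averaging across devices leaves the math unchanged: when every device gets an equal-sized shard of the batch, the mean of the per-device mean gradients equals the mean gradient over the whole batch. The sketch below simulates four "devices" with plain Python lists and hypothetical per-example gradients for a single parameter. On a real cluster, the averaging would be done by a collective such as torch.distributed.all_reduce (which sums across processes by default), followed by division by the number of devices.

```python
import random

random.seed(1)

# Hypothetical per-example gradients for one scalar parameter.
batch = [random.gauss(0.0, 1.0) for _ in range(64)]
num_devices = 4
shard_size = len(batch) // num_devices   # equal shards: 16 examples per device

# Each "device" computes the mean gradient over its own shard of the batch.
local_grads = [
    sum(batch[d * shard_size:(d + 1) * shard_size]) / shard_size
    for d in range(num_devices)
]

# All-reduce step: sum the local gradients across devices, then divide by the device count.
synced_grad = sum(local_grads) / num_devices
full_batch_grad = sum(batch) / len(batch)

print(synced_grad, full_batch_grad)
```

Every device then applies the same synced gradient, so all copies of $\theta$ stay identical after the optimiser step.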

For fine-tuning a model that has already been pre-trained, the algorithm is the same. You just start from a pre-trained θ\theta instead of random initial weights, use a smaller dataset, and usually a smaller learning rate.

