See also: Machine learning terms, Vanishing gradient problem, Gradient clipping, Backpropagation
The exploding gradient problem is a fundamental difficulty encountered during training of deep neural networks and recurrent neural networks (RNNs) using gradient descent-based optimization. It occurs when gradients of the loss function grow exponentially large as they are propagated backward through the layers of a network during backpropagation, leading to extremely large weight updates that destabilize training. When gradients explode, network weights can overflow to numerical infinity, the loss can diverge to NaN ("Not a Number"), and the model becomes unable to learn meaningful representations from data.
The exploding gradient problem is the companion issue to the vanishing gradient problem. Both arise from the same underlying mechanism (repeated multiplication of Jacobian matrices during backpropagation), but they represent opposite extremes: vanishing gradients occur when gradients shrink toward zero, while exploding gradients occur when gradients grow without bound. Although the vanishing gradient problem received earlier attention in the literature, exploding gradients are equally destructive and can crash training within a single iteration. Understanding and mitigating this problem has been critical to the success of modern deep learning.
The exploding gradient problem was first identified alongside the vanishing gradient problem. In 1991, Sepp Hochreiter formally analyzed gradient flow in deep networks and recurrent architectures in his diploma thesis at the Technische Universität München, titled Untersuchungen zu dynamischen neuronalen Netzen ("Investigations into Dynamic Neural Networks"), supervised by Jürgen Schmidhuber. Hochreiter showed that error signals propagated through backpropagation either shrink or grow exponentially, identifying both the vanishing and exploding cases. Because this thesis was written in German, it did not circulate widely among the international research community at the time.
In 1994, Yoshua Bengio, Patrice Simard, and Paolo Frasconi published "Learning Long-Term Dependencies with Gradient Descent is Difficult," which independently arrived at similar conclusions. Bengio and colleagues provided both theoretical analysis and experimental evidence showing that recurrent networks trained with gradient descent struggle to capture long-range dependencies due to exponentially growing or shrinking gradients. This paper brought broader attention to these problems and influenced subsequent research directions.
In 2012, Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio published "On the Difficulty of Training Recurrent Neural Networks" (presented at ICML 2013), which provided the most comprehensive analysis of the problem to date. This paper examined exploding and vanishing gradients from three perspectives: analytical (formal conditions linked to the spectral properties of the recurrent weight matrix), geometric (the existence of "cliffs" in the loss surface where gradients change dramatically), and dynamical systems (behavior near the boundary of stability). The paper proposed gradient norm clipping as a practical countermeasure and demonstrated its effectiveness empirically. This work has accumulated thousands of citations and gradient clipping became standard practice in training RNNs and, later, transformers.
The mathematical root of the exploding gradient problem lies in the chain rule of calculus, which governs how gradients are computed during backpropagation.
Consider a deep feedforward network with L layers. The gradient of the loss function L with respect to the weights in layer i requires computing the product of partial derivatives across all intervening layers:
dL/dW_i = (dL/da_L) * (da_L/da_{L-1}) * (da_{L-1}/da_{L-2}) * ... * (da_{i+1}/da_i) * (da_i/dW_i)
Each term da_{k+1}/da_k represents the Jacobian matrix of layer k+1 with respect to the activations of layer k. This Jacobian depends on the weight matrix of layer k+1 and the derivative of the activation function used in that layer. When these Jacobian matrices are multiplied together, the resulting product depends on the singular values of the individual matrices. If the largest singular value of the weight matrices is consistently greater than 1, the repeated multiplication causes the product to grow exponentially with network depth, potentially as O(s^L) where s is the spectral norm and L is the network depth. This exponential growth is the mathematical root cause of exploding gradients.
For a simplified illustration, consider a network where each layer applies the same linear transformation W. The gradient involves computing W^L (the weight matrix raised to the L-th power). If the largest eigenvalue of W has magnitude greater than 1, then W^L grows exponentially with L, causing the gradient to explode.
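To make this concrete, the following NumPy sketch (illustrative only; the 64-unit width and the spectral norm of 1.1 are arbitrary choices, not taken from any particular model) propagates a unit-norm gradient backward through 50 identical linear layers:

```python
import numpy as np

# Build a 64x64 matrix whose singular values are all 1.1 by scaling an
# orthogonal matrix; W plays the role of every layer's weight matrix.
rng = np.random.default_rng(0)
W = 1.1 * np.linalg.qr(rng.standard_normal((64, 64)))[0]

grad = rng.standard_normal(64)
grad /= np.linalg.norm(grad)          # start from a unit-norm gradient at the output

for _ in range(50):                   # one chain-rule factor per layer
    grad = W.T @ grad

print(np.linalg.norm(grad))           # ~117, i.e. 1.1**50; at 1.5 per layer it would be ~6e8
```

Even a modest per-layer gain of 1.1 multiplies the gradient norm by more than 100 over 50 layers; larger gains overflow quickly.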
The problem is especially severe in RNNs because the same weight matrix is applied at every time step. For an RNN processing a sequence of length T, the gradient of the loss at time T with respect to the hidden state at time t involves computing:
dh_T/dh_t = product from k=t to T-1 of (dh_{k+1}/dh_k) = product from i=t+1 to T of diag(f'(h_{i-1})) * W_rec
where W_rec is the recurrent weight matrix and f' denotes the derivative of the activation function. Each factor in this product is the Jacobian of the recurrent transition function.
As Pascanu et al. (2013) showed formally, a necessary condition for gradients to explode is that the largest singular value of the recurrent weight matrix exceeds 1/sigma_max, where sigma_max is the largest value of the activation function's derivative. Equivalently, if the largest singular value of W_rec multiplied by the maximum value of the activation function's derivative exceeds 1, there exist directions in state space along which the gradient norm can grow exponentially with the sequence length T - t. In practice, for RNNs with tanh activations (where the maximum derivative is 1), a spectral norm of W_rec greater than 1 makes gradient explosion over long sequences possible.
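As an illustration, the following NumPy sketch (sizes and scalings are arbitrary) checks this condition for two candidate recurrent weight matrices:

```python
import numpy as np

def explosion_possible(W_rec: np.ndarray, max_activation_derivative: float = 1.0) -> bool:
    """True if the largest singular value of W_rec times the maximum activation
    derivative exceeds 1, i.e. gradient explosion over long sequences is possible."""
    largest_singular_value = np.linalg.norm(W_rec, ord=2)
    return largest_singular_value * max_activation_derivative > 1.0

rng = np.random.default_rng(0)
W_naive = rng.standard_normal((128, 128)) / np.sqrt(128)           # naive scaling: largest singular value ~ 2
W_small = 0.4 * np.linalg.qr(rng.standard_normal((128, 128)))[0]   # orthogonal matrix scaled to spectral norm 0.4

print(explosion_possible(W_naive))   # True: even "reasonably" scaled random init can satisfy the condition
print(explosion_possible(W_small))   # False
```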
The exploding and vanishing gradient problems are two sides of the same coin. Both originate from the repeated multiplication of Jacobian matrices during backpropagation, but they have different symptoms, different dominant contexts, and somewhat different solutions.
| Aspect | Exploding gradients | Vanishing gradients |
|---|---|---|
| Gradient behavior | Grows exponentially toward infinity | Shrinks exponentially toward zero |
| Singular value condition | Largest singular value > 1 | Largest singular value < 1 |
| Weight condition | Spectral radius > 1 | Spectral radius < 1 |
| Primary symptom | NaN/Inf loss, diverging weights | Training stalls, early layers stop learning |
| Effect on weights | Weights receive extremely large, erratic updates | Weights in early layers receive negligible updates |
| Training outcome | Network training diverges or oscillates wildly | Network fails to learn long-range dependencies |
| Ease of detection | Easier (catastrophic, visible failure) | Harder (silent, gradual stalling) |
| Most affected layers | Earlier layers (receive amplified gradients) | Earlier layers (receive diminished gradients) |
| Common in | Deep networks, RNNs with long sequences | Deep networks, RNNs with long sequences |
| Most direct fix | Gradient clipping, weight regularization, normalization | Gated architectures, ReLU activations, skip connections |
Both problems are addressed by overlapping sets of solutions. Batch normalization, residual connections, and proper initialization help with both. Gradient clipping specifically targets exploding gradients, while LSTM/GRU gating mechanisms were originally designed primarily to address vanishing gradients (though they help with both). It is worth noting that a network can experience both problems simultaneously in different parts of its parameter space. Some gradient components may vanish while others explode, particularly in large networks with varying weight magnitudes across layers.
The main cause of the exploding gradient problem can be traced back to the process of backpropagation used in training artificial neural networks. In backpropagation, gradients of the loss function are computed with respect to each parameter in the network, starting from the output layer and moving backward through the network's layers. During this process, the gradients are multiplied by the weights of the connections between the layers. If these weights are consistently large, the gradients can grow exponentially, leading to the exploding gradient problem. Deep networks and RNNs are particularly susceptible to this issue because of the increased number of layers and recurrent connections, which allow gradients to accumulate and grow rapidly.
Several specific factors contribute to exploding gradients:
| Cause | Description | Example |
|---|---|---|
| Large weight matrices | When weights are initialized too large or grow during training, Jacobian matrices have large singular values. | A 50-layer network with 1.5x amplification per layer yields 1.5^50, roughly 6.4 x 10^8 (over 600-million-fold) amplification. |
| Deep architectures | Each additional layer multiplies another factor into gradient computation. Networks with tens or hundreds of layers are highly susceptible. | Very deep CNNs or stacked transformer layers without residual connections. |
| Long sequences in RNNs | The same recurrent weight matrix applies at every time step. Longer sequences produce more severe explosions when the spectral radius exceeds 1. | Language modeling or speech recognition tasks with sequences of hundreds or thousands of tokens. |
| Poor weight initialization | Random initialization without accounting for fan-in and fan-out can easily produce weight matrices with singular values greater than 1. | Uniform random initialization without scaling for layer width. |
| High learning rate | Large gradient updates increase weights, causing even larger gradients in subsequent iterations, creating a positive feedback loop. | Selecting a learning rate that is orders of magnitude too large for the given architecture. |
| Activation function choice | Unbounded activation functions or poor weight-activation combinations can produce large Jacobian entries. | Using linear activations in very deep networks without normalization. |
When network weights are initialized with values that are too large, or when weights grow large during training, the Jacobian matrices at each layer have large singular values. Even modest weight magnitudes can cause problems in very deep networks because the multiplicative effect compounds across many layers. For example, if each layer amplifies the gradient by a factor of just 1.5, a 50-layer network would amplify gradients by 1.5^50, which is approximately 6.4 x 10^8 (over 600 million). This exponential amplification quickly leads to numerical overflow.
The depth of a neural network directly influences the severity of the exploding gradient problem. Each additional layer adds another multiplicative factor to the gradient computation. While shallow networks (with one or two hidden layers) rarely exhibit exploding gradients, architectures with tens or hundreds of layers are highly susceptible. This was one of the primary obstacles to training very deep networks before the introduction of residual connections and normalization techniques.
Recurrent neural networks are especially prone to exploding gradients because the same weight matrix is applied at every time step during backpropagation through time (BPTT). The gradient computation for a sequence of length T involves multiplying the recurrent weight matrix by itself T times (or more precisely, multiplying the Jacobian at each time step). If the spectral radius of the recurrent weight matrix exceeds 1, the gradient grows exponentially with the sequence length. This means that longer sequences produce more severe gradient explosions, making it difficult to train RNNs on tasks requiring long-range temporal reasoning.
The choice of weight initialization scheme has a significant impact on gradient behavior during the early stages of training. Random initialization with a standard normal distribution or uniform distribution that does not account for the network's fan-in and fan-out can easily produce weight matrices whose singular values are greater than 1. This sets up the conditions for gradient explosion from the very first training steps. Initialization strategies such as Xavier initialization and He initialization were specifically designed to address this issue.
A learning rate that is too high can indirectly contribute to exploding gradients. When large gradient updates are applied to the weights, the weights themselves can grow larger, which in turn produces even larger gradients in subsequent iterations. This creates a positive feedback loop where large gradients cause large weight updates, which cause even larger gradients, eventually leading to divergence. The interaction between learning rate and gradient magnitude is one reason why learning rate scheduling and warmup strategies are commonly used in deep learning.
Certain activation functions can contribute to gradient explosion. While saturating functions like sigmoid and tanh are more commonly associated with vanishing gradients (because their derivatives are bounded and often less than 1), unbounded activation functions or poor combinations of activations and weight magnitudes can lead to large Jacobian entries. In recurrent networks, even tanh activations can produce exploding gradients when the recurrent weight matrix has a sufficiently large spectral radius, because the weight matrix multiplication dominates the activation derivative.
Recognizing exploding gradients early is essential for preventing wasted computation and debugging training failures.
The most common indicators of exploding gradients, together with their typical severity, are summarized in the following table:
| Symptom | Description | Severity |
|---|---|---|
| NaN loss values | The loss function returns NaN (not a number), indicating numerical overflow in the computation | Critical: training must be stopped |
| Inf loss values | The loss grows to infinity, suggesting that weight updates have pushed parameters to extreme values | Critical: training must be stopped |
| Large oscillations in loss | The loss fluctuates wildly between iterations instead of decreasing smoothly | High: model is unlikely to converge |
| Extremely large weight values | Model parameters grow to very large magnitudes (e.g., 10^6 or higher) | High: typically precedes NaN/Inf |
| Gradient norms above expected | Gradient norms of 10^3 to 10^6 or beyond, far exceeding typical values of 0.1 to 10 | High: clipping or other intervention needed |
| Meaningless predictions | Network outputs constant values, all zeros, all ones, or random noise | High: model has lost learned representations |
| Instability after initial progress | The model trains normally for some time, then suddenly diverges | Moderate to High: may indicate a "cliff" in the loss landscape |
Practitioners detect and monitor for exploding gradients mainly by watching the loss for NaN or Inf values and by logging the global gradient norm at every training step, as the sketch below illustrates.
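As a concrete illustration, the sketch below (a hedged example in PyTorch; `model`, `loss`, and `step` are placeholders for whatever training loop is in use) computes the global gradient norm after each backward pass and flags suspicious values:

```python
import math
import torch

def global_grad_norm(model: torch.nn.Module) -> float:
    """L2 norm of all parameter gradients concatenated into one vector."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().pow(2).sum().item()
    return math.sqrt(total)

# Inside the training loop, after loss.backward() and before optimizer.step():
#     norm = global_grad_norm(model)
#     if not math.isfinite(norm) or norm > 1e3:
#         print(f"step {step}: suspicious gradient norm {norm:.3e}")
```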
Multiple techniques have been developed to address the exploding gradient problem. In practice, most modern training pipelines combine several of these approaches simultaneously. These range from simple heuristics applied during training to fundamental architectural changes that address the root cause.
Gradient clipping is the most direct and widely used solution, proposed by Pascanu, Mikolov, and Bengio (2013) as part of their analysis of training difficulties in recurrent neural networks. The core idea is simple: if the gradient exceeds a predefined threshold, it is scaled down before being used to update the weights. There are two main variants.
Clipping by value independently clips each gradient component to a fixed range [min_value, max_value]. For each gradient component g_i: if g_i > max_value, set g_i = max_value; if g_i < min_value, set g_i = min_value; otherwise g_i remains unchanged. The drawback of this approach is that it changes the direction of the gradient vector, which can interfere with optimization. Because each gradient component is clipped independently, the clipped gradient may no longer point in the direction of steepest descent.
Clipping by norm (also called gradient norm scaling) rescales the entire gradient vector if its L2 norm exceeds a threshold c, preserving the gradient direction while reducing its magnitude:
if ||g|| > c, then g_new = (c / ||g||) * g
This is the more commonly used variant because it preserves the relative proportions between gradient components. Because the direction is maintained, the optimizer still moves toward the correct region of parameter space, just with a smaller step. Common threshold values range from 0.5 to 10.0, with 1.0 being a typical default. A practical approach for selecting the threshold is to observe typical gradient norms during the early stages of training without clipping, then set the threshold somewhat above that range.
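In PyTorch, both variants are available as utility functions. The sketch below uses a stand-in linear model and illustrative thresholds; in practice one variant is chosen, not both:

```python
import torch

model = torch.nn.Linear(16, 16)                     # stand-in model
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()

# Clipping by value: clamp each gradient component to [-0.5, 0.5].
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

# Clipping by norm: rescale the entire gradient so its global L2 norm is at
# most 1.0, preserving its direction. Returns the norm measured before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(total_norm))
```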
Choosing an appropriate clipping threshold is important and somewhat problem-dependent. Common practices include:
| Domain | Typical Clipping Threshold | Notes |
|---|---|---|
| RNN training | 1.0 to 5.0 | Pascanu et al. (2013) used values in this range |
| Transformer / LLM training | 0.5 to 2.0 | A threshold of 1.0 is the most common default |
| Computer vision (CNNs) | 1.0 to 10.0 | Larger thresholds are typical for CNNs |
| General starting point | 1.0 | A reasonable default across many settings |
Monitoring the frequency of clipping events is also helpful: if gradients are clipped on every batch, the threshold may be too low and may be limiting the model's ability to learn. If clipping never activates, the threshold may be too high to provide any benefit.
Proper weight initialization prevents gradients from exploding at the start of training by ensuring that the variance of activations and gradients remains approximately constant across layers.
Xavier (Glorot) initialization (Glorot and Bengio, 2010), proposed in "Understanding the difficulty of training deep feedforward neural networks," sets the variance of weights to Var(W) = 2 / (n_in + n_out), where n_in is the number of inputs to the layer and n_out is the number of outputs. This ensures that the variance of both forward-pass activations and backward-pass gradients remains approximately constant across layers. Xavier initialization works well with linear activations and saturating nonlinearities like sigmoid and tanh.
He (Kaiming) initialization (He et al., 2015) sets the variance of weights to Var(W) = 2 / n_in, accounting for the fact that ReLU activations zero out approximately half of their inputs. The factor of 2 compensates for the fact that ReLU sets approximately half of its inputs to zero, which would otherwise reduce the signal variance by half at each layer. This approach was introduced in the paper "Delving Deep into Rectifiers" and is the standard initialization for networks using ReLU or its variants (Leaky ReLU, PReLU, ELU).
Both initialization strategies work by ensuring that the expected gain (the ratio of output variance to input variance) is approximately 1 at each layer. When the gain is above 1, gradients tend to explode; when it is below 1, gradients tend to vanish.
| Initialization method | Formula | Best suited for | Year introduced |
|---|---|---|---|
| Xavier (Glorot) | Var(W) = 2 / (n_in + n_out) | Sigmoid, tanh activations | 2010 |
| He (Kaiming) | Var(W) = 2 / n_in | ReLU-family activations | 2015 |
| LeCun | Var(W) = 1 / n_in | SELU activations | 1998 |
| Orthogonal | W = orthogonal matrix | RNNs (preserves gradient norms) | Various |
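A minimal sketch of applying these schemes with PyTorch's built-in initializers (layer sizes are arbitrary):

```python
import torch
from torch import nn

tanh_layer = nn.Linear(256, 256)
relu_layer = nn.Linear(256, 256)

# Xavier/Glorot: variance 2 / (n_in + n_out), suited to sigmoid/tanh layers.
nn.init.xavier_uniform_(tanh_layer.weight)

# He/Kaiming: variance 2 / n_in, suited to ReLU-family activations.
nn.init.kaiming_normal_(relu_layer.weight, nonlinearity='relu')

# Orthogonal initialization is often used for recurrent weight matrices,
# since it keeps all singular values equal to 1.
recurrent_weight = torch.empty(256, 256)
nn.init.orthogonal_(recurrent_weight)
```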
Batch normalization (Ioffe and Szegedy, 2015), introduced in "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift," normalizes the inputs to each layer within a mini-batch to have zero mean and unit variance, then applies learnable scale and shift parameters:
y = gamma * ((x - mu_batch) / sqrt(sigma_batch^2 + epsilon)) + beta
where mu_batch and sigma_batch are the batch mean and standard deviation, gamma and beta are learned parameters, and epsilon is a small constant for numerical stability. By preventing the signal from growing unboundedly across layers, batch normalization reduces the risk of gradient explosion and enables the use of higher learning rates. Empirically, batch normalization speeds convergence and reduces sensitivity to weight initialization.
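The formula above corresponds to the per-feature computation in the sketch below (training-mode statistics only; the running averages used at inference time are omitted):

```python
import torch

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm for a (batch, features) tensor."""
    mu = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    x_hat = (x - mu) / torch.sqrt(var + eps)
    return gamma * x_hat + beta

x = torch.randn(32, 8) * 5.0 + 3.0         # deliberately large, shifted activations
y = batch_norm_forward(x, gamma=torch.ones(8), beta=torch.zeros(8))
print(y.mean(dim=0).abs().max().item())    # ~0: each feature is re-centered
print(y.std(dim=0).mean().item())          # ~1: each feature is rescaled
```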
Layer normalization (Ba, Kiros, and Hinton, 2016) normalizes across the feature dimension rather than the batch dimension. This makes it suitable for RNNs and transformers, where batch statistics are less meaningful due to variable sequence lengths and autoregressive processing or where batch sizes may be small. Layer normalization is the standard normalization technique in transformer architectures.
RMSNorm (Zhang and Sennrich, 2019) simplifies layer normalization by removing the mean-centering step and normalizing only by the root mean square of the activations. It has been adopted in several large language models due to its computational efficiency.
Long Short-Term Memory (LSTM) networks, introduced by Hochreiter and Schmidhuber (1997), were designed specifically to address the vanishing and exploding gradient problems in RNNs. LSTMs introduce a cell state that runs through time with additive (rather than multiplicative) updates:
c_t = f_t * c_{t-1} + i_t * c_tilde_t
The additive nature of this update (rather than multiplicative) allows gradients to flow backward through many time steps without repeated multiplication by the weight matrix. Three gates regulate signal flow:
| Gate | Function | Effect on Gradients |
|---|---|---|
| Forget gate | Controls how much of the previous cell state to retain | Allows the gradient to pass through unchanged when the gate is open |
| Input gate | Controls how much new information enters the cell state | Prevents large activations from overwhelming the cell state |
| Output gate | Controls how much of the cell state is exposed | Regulates the magnitude of the hidden state used in downstream computations |
The forget gate bias is often initialized to a value near 1 to encourage gradient flow early in training. When the forget gate is close to 1, the gradient flows through almost unimpeded, acting as a "gradient highway."
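As an illustration, the snippet below relies on PyTorch's documented gate ordering (input, forget, cell, output) within the LSTM bias vectors to set the forget-gate bias to 1; the layer sizes are arbitrary:

```python
import torch
from torch import nn

lstm = nn.LSTM(input_size=128, hidden_size=256, num_layers=1)

hidden = lstm.hidden_size
with torch.no_grad():
    for name, param in lstm.named_parameters():
        if 'bias' in name:
            # Bias layout is [input | forget | cell | output], each chunk of size hidden_size.
            param[hidden:2 * hidden].fill_(1.0)
```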
Gated Recurrent Units (GRUs), proposed by Cho et al. (2014), achieve similar gradient stabilization benefits with a simpler two-gate architecture (reset and update gates). The update gate in the GRU functions similarly to the forget and input gates in the LSTM combined, controlling the balance between the previous hidden state and the new candidate state. GRUs provide comparable gradient flow benefits with fewer parameters.
Both LSTMs and GRUs significantly reduce, but do not entirely eliminate, gradient problems. Gradient clipping is still commonly used alongside these architectures when training on very long sequences.
Residual connections (also called skip connections), introduced by He et al. (2016) in the ResNet architecture, add a shortcut path that bypasses one or more layers:
y = F(x) + x
where F(x) represents the output of one or more layers and x is the input to that block. During backpropagation, the gradient flows through both the residual path (through F) and the identity shortcut (directly). Because the derivative of the identity function x is 1, the gradient through the shortcut path flows backward without any multiplicative degradation or amplification.
This architectural innovation was transformative for deep learning. Before residual connections, training networks with more than approximately 20 layers was extremely difficult because of gradient degradation. With residual connections, researchers successfully trained networks with over 1,000 layers. Residual connections are now a fundamental component of virtually all modern deep architectures, including transformers, which use residual connections around both the self-attention and feed-forward sublayers.
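A minimal residual block sketch (the two-linear-layer body and the feature size are illustrative choices, not the original ResNet design, which used convolutions):

```python
import torch
from torch import nn

class ResidualBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        # y = F(x) + x: the identity path gives gradients a route back that is
        # never multiplied by a weight matrix.
        return self.body(x) + x

block = ResidualBlock(dim=64)
print(block(torch.randn(8, 64)).shape)   # torch.Size([8, 64])
```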
Careful learning rate selection is critical for preventing gradient explosion. Helpful strategies include learning rate warmup, decay schedules such as cosine or linear decay, and reducing the learning rate when instability is observed.
Mixed-precision training uses lower-precision floating-point formats (such as FP16 or BF16) to accelerate computation. However, the limited numerical range of these formats means that very large gradient values can overflow to infinity, while very small values can underflow to zero. Loss scaling addresses this by multiplying the loss by a large constant before backpropagation, keeping gradient values in the representable range of the lower-precision format, and then dividing by the same constant after the gradient computation but before the weight update.
Dynamic loss scaling adjusts this scaling factor automatically during training. It starts with a high loss scale and reduces it whenever gradient overflow is detected. This approach is standard in modern frameworks such as PyTorch's torch.cuda.amp.GradScaler and NVIDIA's Apex library.
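A sketch of how dynamic loss scaling and gradient clipping combine in PyTorch's automatic mixed precision; `model`, `optimizer`, `loss_fn`, and `loader` are placeholders for whatever training setup is in use:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(inputs), targets)

    scaler.scale(loss).backward()         # backward pass on the scaled loss
    scaler.unscale_(optimizer)            # unscale so clipping sees true gradient values
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)                # skips the update if Inf/NaN gradients were found
    scaler.update()                       # grows or shrinks the loss scale automatically
```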
Weight regularization techniques penalize large weight values during training, which indirectly helps prevent the conditions that lead to exploding gradients.
L2 regularization (also called weight decay) adds a penalty term proportional to the sum of squared weights to the loss function: L_total = L_original + (lambda/2) * sum(W^2), where lambda is the regularization strength. By penalizing large weight values, L2 regularization keeps the singular values of weight matrices closer to 1, reducing the risk of gradient explosion. In each weight update, L2 regularization effectively shrinks each weight by a small fraction, pulling weights toward zero.
L1 regularization adds a penalty proportional to the absolute value of the weights. While L1 regularization promotes sparsity rather than simply keeping weights small, it also limits the overall magnitude of the weight matrices.
Decoupled weight decay (Loshchilov and Hutter, 2019), as implemented in the AdamW optimizer, applies weight decay independently of the gradient computation. With modern adaptive optimizers such as Adam, the distinction between L2 regularization and weight decay becomes important, because Adam's adaptive scaling of gradients interacts poorly with the L2 penalty term. Decoupled weight decay has been shown to outperform standard L2 regularization in practice, and AdamW is the default optimizer for most transformer-based model training.
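In PyTorch, decoupled weight decay is selected simply by choosing AdamW rather than adding an L2 term to the loss; the hyperparameter values below are common defaults, not recommendations, and the linear layer is only a stand-in model:

```python
import torch

model = torch.nn.Linear(16, 16)

# Adam's weight_decay argument adds an L2 penalty to the gradient (coupled),
# whereas AdamW applies weight decay directly in the parameter update (decoupled).
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
```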
Some approaches directly constrain the gradient norm as part of the training objective. For example, in training Wasserstein GANs with gradient penalty (WGAN-GP), a penalty term is added to the loss function that encourages the norm of the gradients to stay close to 1. While this technique was developed for a specific application (generative adversarial networks), the principle of explicitly penalizing large gradient norms can be applied more broadly to stabilize training.
For recurrent networks processing very long sequences, truncated backpropagation through time (TBPTT) limits the number of time steps over which gradients are propagated. Instead of computing gradients across the entire sequence, TBPTT only propagates gradients for a fixed window of k time steps. This directly limits the number of matrix multiplications in the gradient computation, reducing the potential for gradient explosion. The trade-off is that the network cannot learn dependencies longer than k steps through gradient-based optimization.
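A sketch of truncated BPTT in PyTorch (the window size, the `head` output layer, and the data layout are illustrative assumptions); the key step is detaching the hidden state between windows so no gradient can propagate more than k steps back:

```python
import torch

def train_tbptt(rnn, head, loss_fn, optimizer, inputs, targets, k: int = 35):
    """inputs, targets: (seq_len, batch, ...). Gradients flow at most k steps back."""
    hidden = None
    for start in range(0, inputs.size(0), k):
        optimizer.zero_grad()
        output, hidden = rnn(inputs[start:start + k], hidden)
        loss = loss_fn(head(output), targets[start:start + k])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)
        optimizer.step()
        # Detach so the next window's computation graph stops here; this is what
        # truncates backpropagation to at most k time steps.
        hidden = tuple(h.detach() for h in hidden) if isinstance(hidden, tuple) else hidden.detach()
```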
One of the most influential studies of the exploding gradient problem is the 2013 paper "On the Difficulty of Training Recurrent Neural Networks" by Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio, presented at the 30th International Conference on Machine Learning (ICML 2013).
Pascanu et al. analyzed the exploding and vanishing gradient problems from three complementary perspectives: analytical (deriving conditions on the singular values of the recurrent weight matrix), geometric (identifying "cliffs" in the loss surface where the gradient changes dramatically), and dynamical systems (examining behavior near the boundary of stability).
Based on this analysis, Pascanu et al. proposed gradient norm clipping as a simple and effective countermeasure. Their clipping strategy rescales the entire gradient vector when its norm exceeds a threshold, which prevents the optimizer from taking catastrophically large steps when it encounters a cliff in the loss surface. They also proposed a complementary soft constraint for the vanishing gradient problem. The paper demonstrated empirically that gradient clipping allowed successful training of RNNs on tasks that were previously intractable. The paper has been cited thousands of times and gradient clipping has become a standard practice in neural network training.
The exploding gradient problem has broad implications for training stability and the practical design of deep learning systems.
The exploding gradient problem can have several detrimental consequences for the training process and the resulting model, including divergent or wildly oscillating loss, wasted computation when training must be restarted from checkpoints, and models that fail to learn useful representations.
The exploding gradient problem interacts strongly with the choice of learning rate. A learning rate that would be appropriate for normal gradient magnitudes becomes far too large when gradients explode. This is why learning rate warmup, where the learning rate starts very small and gradually increases over the first few thousand training steps, has become standard practice. The warmup period allows the model to find a stable region of the loss landscape before applying full-magnitude updates. Learning rate scheduling strategies such as cosine decay and linear decay also help manage the interaction between learning rates and gradient magnitudes throughout training.
Modern adaptive optimizers like Adam, AdaGrad, and RMSProp maintain per-parameter learning rates based on the history of gradient magnitudes. These optimizers provide some implicit protection against exploding gradients because a parameter that consistently receives large gradients will have its effective learning rate reduced. However, this protection is not absolute. Sudden gradient spikes (such as those caused by encountering a "cliff" in the loss landscape) can still overwhelm adaptive optimizers because the running average of gradient magnitudes may not yet reflect the spike. For this reason, gradient clipping is typically used alongside adaptive optimizers rather than as a replacement.
While the exploding gradient problem was first extensively studied in the context of RNNs and deep feedforward networks, it remains highly relevant in the era of transformers and large language models.
Transformer architectures incorporate several design features that help manage gradient flow, including residual connections, layer normalization, and multi-head attention. Despite these built-in safeguards, transformer training is not immune to gradient instability, particularly at very large scales.
Research has shown that the backwards gradients from the query and key (Q/K) matrices in the attention mechanism are exponentially related to their variance. Incorrect initialization of these matrices can cause gradient explosion even in models with residual connections and normalization. For example, initializing Q/K matrices at just twice the standard Xavier values was shown to cause backwards gradients to explode by a factor of 10,000 through a 192-layer model.
Another source of instability in transformers is the amplification effect in residual branches. When the residual branch contributes a large fraction of the total signal (relative to the skip connection), small perturbations in the parameters can be amplified through the network, leading to training instability. This is why the ordering and placement of normalization layers (Pre-LN vs. Post-LN) has a significant effect on transformer training stability.
Loss spikes are a well-documented phenomenon during large language model pre-training. They occur when unusual or adversarial training examples trigger sudden gradient explosions, causing the loss to spike dramatically before (in favorable cases) recovering. In practice, loss spikes can cause a model to lose much of its learned representation, requiring rollback to an earlier checkpoint.
Google's PaLM (540 billion parameters) experienced roughly 20 loss spikes during training despite gradient clipping being enabled. The spikes occurred at highly irregular intervals and sometimes appeared late into training. The PaLM team's mitigation strategy involved restarting training from a checkpoint roughly 100 steps before the spike and skipping 200 to 500 data batches covering the batches seen before and during the spike. After restarting, the loss did not spike again at the same point.
Common strategies for handling loss spikes in LLM training include checkpoint rollback combined with skipping the offending batches, tighter or adaptive gradient clipping, auxiliary losses that penalize large logit values, and smaller initialization scales; these are collected in the summary table below.
Recent research has introduced adaptive gradient clipping methods that dynamically adjust the clipping threshold based on recent gradient statistics. ZClip, for example, uses exponential moving averages of gradient norms to set adaptive clipping thresholds, providing automatic protection against loss spikes without requiring manual threshold tuning.
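A minimal sketch of the general idea (tracking an exponential moving average of the gradient norm and clipping to a multiple of it); this illustrates the EMA-based approach, not the exact ZClip algorithm:

```python
import math
import torch

class EmaGradClipper:
    """Clip gradients to `factor` times an EMA of recent gradient norms (illustrative sketch)."""

    def __init__(self, decay: float = 0.99, factor: float = 2.0):
        self.decay, self.factor, self.ema = decay, factor, None

    def __call__(self, parameters) -> float:
        params = [p for p in parameters if p.grad is not None]
        norm = math.sqrt(sum(p.grad.detach().pow(2).sum().item() for p in params))
        if self.ema is None:
            self.ema = norm                          # first step: just record the norm
        else:
            threshold = self.factor * self.ema
            torch.nn.utils.clip_grad_norm_(params, max_norm=threshold)
            self.ema = self.decay * self.ema + (1 - self.decay) * min(norm, threshold)
        return norm

# Usage inside a training loop, after loss.backward() (clipper created once, outside the loop):
#     clipper(model.parameters())
#     optimizer.step()
```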
Research has identified two primary factors behind training instability in very large models: (1) rapid amplification of the norm of the residual stream during forward propagation and (2) intensification of gradients before and after layer normalizations. These spikes can produce gradients 1,000 times larger than typical values.
Initialization also plays a critical role: as the Q/K example above illustrates, initializing attention sub-layer parameters too large can cause backward gradients to explode by several orders of magnitude in very deep models.
As foundation models scale to billions or trillions of parameters, the risk of gradient instability increases. Larger models have more parameters and more layers, providing more opportunities for gradients to grow. The use of mixed-precision training (where computations are performed in 16-bit or lower precision floating point) further complicates gradient management, because the reduced numerical range of lower-precision formats makes overflow more likely. Techniques like loss scaling (multiplying the loss by a large constant before backpropagation and dividing the gradients afterward) are used to keep gradient values within the representable range of the floating point format.
Distributed training across many GPUs or TPUs introduces additional challenges, as gradient aggregation across devices can amplify numerical errors. Gradient clipping in distributed settings must be applied to the globally aggregated gradient, not to individual per-device gradients, to ensure consistent behavior.
| Technique | Description |
|---|---|
| Gradient clipping (norm) | Clip global gradient norm, typically with threshold 1.0. |
| Learning rate warmup | Gradually increase learning rate over initial training steps. |
| Real-time gradient norm monitoring | Track gradient norms per step; alert on anomalies. |
| Checkpoint rollback | Revert to earlier checkpoint when loss spikes are detected. |
| Batch skipping | Skip anomalous batches that trigger gradient spikes. |
| Auxiliary loss terms | Penalize large logit values to prevent output-layer instability. |
| Small initialization scales | Initialize sub-layer parameters with reduced standard deviations. |
| Spike-aware optimizers | Algorithms like SPAM that detect and reset momentum on gradient spikes. |
Imagine you are playing a game of telephone with a very long line of people. You whisper a number to the first person, and each person multiplies the number by 2 before passing it to the next person. After just 10 people, the number 1 becomes 1,024. After 20 people, it becomes over a million. After 30 people, it becomes over a billion. The number gets so large that nobody can even say it anymore.
That is what happens with exploding gradients in a neural network. Each layer of the network multiplies a correction signal as it passes backward through the network. If each layer makes the signal a little bit bigger, by the time it reaches the early layers the signal has grown so enormous that the network makes wild, nonsensical adjustments instead of careful, useful ones. The network "forgets" how to learn because the correction signals are just too big.
The fix is like telling each person in the telephone line: "If the number gets bigger than 100, just say 100 instead." That is gradient clipping, and it keeps the numbers manageable so the network can learn properly. Another way to think about it: imagine you are trying to learn how to stack blocks into a big tower, and each block needs just the right amount of force to stay in place. If the force you use at the bottom can amplify through each level, too much force anywhere can topple the whole tower. Putting a limit on how much force you can apply at any level keeps the tower standing.
| Technique | Category | Typical application | Key benefit |
|---|---|---|---|
| Gradient clipping by norm | Training-time | RNN and transformer training | Directly limits gradient magnitude; preserves direction |
| Gradient clipping by value | Training-time | General purpose | Simple implementation; caps individual components |
| Xavier initialization | Initialization | Sigmoid/tanh networks | Prevents initial gradient instability |
| He initialization | Initialization | ReLU networks | Adjusts variance for ReLU activations |
| Batch normalization | Architecture | CNNs, feedforward networks | Stabilizes forward and backward signal |
| Layer normalization | Architecture | Transformers, RNNs | Stabilizes without batch dependency |
| LSTM/GRU gating | Architecture | Sequence modeling | Additive cell updates preserve gradients |
| Residual connections | Architecture | Deep networks (100+ layers) | Identity shortcut ensures gradient flow |
| L2 regularization / weight decay | Regularization | General purpose, especially with SGD | Keeps weight magnitudes bounded |
| Learning rate warmup | Training-time | LLM and transformer pre-training | Prevents early-training instability |
| Mixed-precision loss scaling | Training-time | Low-precision training (FP16/BF16) | Keeps gradients in representable range |
| Truncated BPTT | Training-time | RNN training on long sequences | Limits number of matrix multiplications |