See also: gradient descent, backpropagation, loss function, mixed precision training, numerical stability
A NaN trap (short for "Not a Number" trap) is a failure mode in machine learning training where arithmetic operations produce the special IEEE 754 value NaN, which then propagates through all subsequent computations and renders the model's loss, gradients, and weights meaningless. Once a single NaN enters the computation graph, every downstream operation that touches it also becomes NaN, effectively halting any useful learning. The term "trap" reflects the self-reinforcing nature of the problem: NaN values are contagious and, without explicit detection, silently corrupt the entire training run.
NaN traps are one of the most common and frustrating bugs in deep learning. They can appear suddenly after thousands of stable training steps, and the root cause is often far removed from the point where the NaN first becomes visible. A model may report a finite loss for hours, then abruptly display loss: nan with no obvious trigger. Diagnosing the source requires understanding how floating-point arithmetic works, which operations are prone to producing undefined results, and how the interaction between learning rate, model architecture, and data preprocessing can push computations outside representable ranges.
This article covers the IEEE 754 foundations of NaN values, the specific causes of NaN traps in neural network training, techniques for prevention and detection, and practical debugging strategies.
To understand NaN traps, it helps to know how computers represent real numbers. Modern hardware uses the IEEE 754 standard for floating-point arithmetic, which defines several special values beyond ordinary numbers.
| Format | Exponent bits | Mantissa bits | Approximate range | Typical use |
|---|---|---|---|---|
| FP32 (float32) | 8 | 23 | ~1.2 x 10^-38 to ~3.4 x 10^38 | Default training precision |
| FP16 (float16) | 5 | 10 | ~6.0 x 10^-8 to ~65,504 | Mixed precision training, inference |
| BF16 (bfloat16) | 8 | 7 | ~1.2 x 10^-38 to ~3.4 x 10^38 | TPU and modern GPU training |
| FP64 (float64) | 11 | 52 | ~2.2 x 10^-308 to ~1.8 x 10^308 | Scientific computing, accumulation |
The narrower the format, the more likely a computation will exceed its representable range. FP16 can only represent values up to 65,504 before overflowing to infinity, and values smaller than roughly 6.0 x 10^-8 underflow to zero. Both overflow and underflow can lead to NaN through subsequent operations.
The IEEE 754 standard specifies that certain operations yield NaN because their results are mathematically undefined or indeterminate.
| Operation | Example | Result | Why |
|---|---|---|---|
| 0 / 0 | 0.0 / 0.0 | NaN | Indeterminate form |
| infinity - infinity | inf - inf | NaN | Indeterminate form |
| 0 x infinity | 0.0 * inf | NaN | Indeterminate form |
| infinity / infinity | inf / inf | NaN | Indeterminate form |
| sqrt of negative | sqrt(-1.0) | NaN | No real result |
| Remainder by zero | x % 0 | NaN | Undefined |
The defining property of NaN that makes it a "trap" is its contagious behavior. Any arithmetic operation involving a NaN operand produces NaN as its result:
- NaN + x = NaN for any finite x
- NaN * x = NaN for any x (including zero, unlike infinity)
- NaN < x is false, NaN > x is false, and NaN == NaN is also false

In a neural network, this means a single NaN generated in one layer's forward pass will corrupt every activation in subsequent layers, the loss value, all gradients during backpropagation, and ultimately every weight update. After a single NaN weight update, the model's parameters are permanently corrupted unless training is rolled back to a previous checkpoint.
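This contagion is easy to reproduce in a few lines of PyTorch:

```python
import torch

nan = torch.tensor(float("inf")) - torch.tensor(float("inf"))  # inf - inf = NaN
print(nan + 1.0)   # tensor(nan): addition propagates the NaN
print(nan * 0.0)   # tensor(nan): even multiplying by zero does not clear it
print(nan == nan)  # tensor(False): NaN compares unequal to everything, itself included
```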
NaN traps in deep learning training arise from several interacting factors. The sections below cover each cause in detail.
Exploding gradients are the most common proximate cause of NaN traps. During backpropagation, gradients are computed as products of Jacobian matrices across layers. In a network with L layers, the gradient of the loss with respect to early-layer parameters involves multiplying L matrices together. If the spectral norm of these matrices exceeds 1.0, the product grows exponentially with depth, and gradient magnitudes can reach 10^7 or higher within a few training steps before overflowing to infinity, then producing NaN when infinity is used in subsequent operations like subtraction or division.
Research on large language model training has shown that loss spikes (sudden jumps in training loss that sometimes lead to complete divergence) are preceded by exponential gradient norm growth. The gradient norm can spike by several orders of magnitude in just a handful of steps, and if left unchecked, the model's weights overflow to infinity or NaN.
Vanishing gradients are the opposite problem: gradients shrink exponentially as they propagate backward through many layers. While vanishing gradients alone do not directly produce NaN, they contribute to NaN traps indirectly. When gradients underflow to zero (especially in FP16), subsequent operations such as division by a zero gradient estimate or normalization by a zero variance can produce NaN.
The sigmoid activation function is particularly prone to this issue. Its derivative approaches zero for large positive or negative inputs, creating near-zero gradients that provide almost no learning signal. Glorot and Bengio (2010) demonstrated that deep networks with sigmoid activations frequently exhibit saturated units whose gradients are effectively zero, leading to stalled training and numerical instability.
Several common loss functions contain operations that produce NaN when given extreme inputs.
Cross-entropy loss. The cross-entropy loss involves computing log(p) where p is a predicted probability. If p reaches exactly 0.0, log(0) evaluates to negative infinity. If both the prediction and the target are zero, terms like 0 * log(0) produce 0 * (-inf) = NaN. Similarly, in binary cross-entropy, terms like log(1 - p) become -inf when the prediction saturates to exactly 1.0.
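Both failure modes can be reproduced in isolation:

```python
import torch

p = torch.tensor(0.0)      # a predicted probability that has saturated to exactly zero
print(torch.log(p))        # tensor(-inf)
print(0.0 * torch.log(p))  # tensor(nan): 0 * (-inf) is indeterminate
```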
KL divergence. The Kullback-Leibler divergence involves p * log(p / q). When q = 0 and p > 0, the expression evaluates to p * log(p / 0) = p * inf = inf, and further arithmetic involving this infinity can produce NaN.
Mean squared error with overflow. While MSE itself is simple, squaring very large prediction errors can overflow to infinity in reduced-precision formats. If a model prediction in FP16 reaches 60,000 and the target is 0, then (60000)^2 = 3.6 x 10^9, which far exceeds the FP16 maximum of 65,504.
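The overflow itself takes one line to demonstrate:

```python
import torch

err = torch.tensor(60000.0, dtype=torch.float16)
print(err * err)  # tensor(inf, dtype=torch.float16): 3.6 x 10^9 exceeds the FP16 max of 65,504
```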
An excessively high learning rate is one of the most frequent triggers for NaN traps. Large weight updates cause activations to grow, which produces larger gradients, which causes even larger weight updates. This positive feedback loop can escalate from normal training to NaN in just a few iterations. In transformer models, this instability is especially acute because of the multiplicative interactions in the attention mechanism, where the softmax function amplifies large logit values.
If weights are initialized with values that are too large, the activations in the first forward pass can overflow. If weights are initialized too close to zero, gradients may underflow. Proper initialization schemes exist to address this:
| Initialization | Formula (variance) | Best for | Source |
|---|---|---|---|
| Xavier / Glorot | sigma^2 = 2 / (n_in + n_out) | Tanh, sigmoid activations | Glorot and Bengio (2010) |
| He / Kaiming | sigma^2 = 2 / n_in | ReLU activations | He et al. (2015) |
| LeCun | sigma^2 = 1 / n_in | SELU activations | LeCun et al. (1998) |
Xavier initialization preserves the variance of activations and gradients across layers by setting the weight variance based on the fan-in and fan-out of each layer. He initialization modifies this for ReLU networks, accounting for the fact that ReLU zeros out roughly half of its inputs. Using the wrong initialization scheme (for example, Xavier initialization with ReLU activations) can lead to gradual variance growth or shrinkage that eventually triggers instability.
Bad input data is a surprisingly common cause of NaN traps. A single corrupted record, such as a NaN or infinity hiding in the features or labels, or a division by zero during feature engineering, injects a non-finite value directly into the forward pass before any gradient is ever computed.
Mixed precision training, which uses FP16 or BF16 for forward and backward passes to improve speed and reduce memory, introduces additional NaN risks. The FP16 format has a maximum value of 65,504, and during backpropagation, gradient values can easily exceed this threshold. Micikevicius et al. (2018) introduced the technique of loss scaling to address this: the loss is multiplied by a large factor (for example, 2^24) before backpropagation to shift small gradients into the representable range, then the gradients are divided by the same factor before the weight update. If the scaling factor is too large, gradients overflow to infinity or NaN instead.
BF16 (bfloat16) mitigates this problem by using the same 8 exponent bits as FP32, giving it the same dynamic range (~10^-38 to ~10^38). However, BF16 has only 7 mantissa bits, reducing precision and making it more susceptible to rounding errors that accumulate over many training steps.
| Precision issue | FP16 risk | BF16 risk | FP32 risk |
|---|---|---|---|
| Gradient overflow | High (max 65,504) | Low (same range as FP32) | Very low |
| Gradient underflow | Moderate | Moderate | Low |
| Accumulation rounding | Moderate | High (only 7 mantissa bits) | Low |
| Needs loss scaling | Usually yes | Usually no | No |
Batch normalization and layer normalization involve dividing by the standard deviation of activations. If all activations in a batch (or layer) are identical, the variance is zero, and the division produces infinity or NaN. Both operations add a small epsilon (typically 1e-5) to the denominator to prevent this, but edge cases can still arise: for example, when the variance is computed as E[x^2] - (E[x])^2, catastrophic cancellation can round the result to a negative value whose magnitude exceeds the epsilon, and the square root of a negative number is NaN.
Adaptive optimizers like Adam maintain running estimates of gradient means and variances. The update rule involves dividing by the square root of the second moment estimate plus epsilon: parameter = parameter - lr * m_hat / (sqrt(v_hat) + epsilon). If v_hat is extremely small or underflows to zero (especially in FP16), the division produces extremely large values or NaN. The default epsilon of 1e-8 works well in FP32 but may be too small for FP16 or BF16 training, where increasing it to 1e-6 or 1e-7 can prevent NaN.
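In PyTorch, this is a one-argument change (a sketch; the learning rate is an arbitrary placeholder and model is assumed to be defined):

```python
import torch

# A larger eps keeps sqrt(v_hat) + eps safely away from zero in reduced precision.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-6)
```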
Multiple strategies exist for preventing NaN traps. In practice, robust training pipelines combine several of these techniques.
Gradient clipping caps the magnitude of gradients before the optimizer step, preventing them from growing large enough to cause overflow.
Norm clipping (global norm clipping) treats all parameter gradients as a single vector, computes its L2 norm, and scales the entire vector down if the norm exceeds a threshold. This preserves the direction of the gradient while limiting its magnitude. In PyTorch:
```python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```
Value clipping (element-wise clipping) clips each individual gradient component independently to a fixed range (for example, [-1.0, 1.0]). This is simpler but does not preserve gradient direction, which can slow convergence.
| Clipping method | How it works | Preserves direction | Common threshold |
|---|---|---|---|
| Norm clipping | Scales gradient vector if L2 norm exceeds threshold | Yes | 0.5 to 5.0 (commonly 1.0) |
| Value clipping | Clips each component to [-c, c] | No | 0.5 to 1.0 |
Norm clipping is the standard approach for training large language models and transformers. A max_norm of 1.0 is a typical starting point.
Using the correct weight initialization scheme for the chosen activation function keeps activations and gradients in a stable range from the start of training. Xavier initialization is appropriate for tanh and sigmoid activations; He initialization is appropriate for ReLU and its variants. Using He initialization with ReLU prevents the systematic variance shrinkage that would otherwise occur because ReLU zeros out approximately half of its inputs.
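PyTorch exposes these schemes as built-in initializers; a minimal sketch for a ReLU network (model is assumed to be an existing nn.Module):

```python
import torch.nn as nn

def init_weights(module):
    # He initialization for linear layers feeding ReLU activations.
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        nn.init.zeros_(module.bias)

model.apply(init_weights)
```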
Many common operations have numerically stable reformulations that avoid overflow, underflow, or indeterminate forms.
Log-sum-exp trick. The softmax function involves computing exp(x_i) / sum(exp(x_j)). If any x_i is large (say, 1000), exp(1000) overflows to infinity. The standard fix subtracts the maximum value before exponentiating: softmax(x)_i = exp(x_i - max(x)) / sum(exp(x_j - max(x))). This is mathematically equivalent due to the shift-invariance property of softmax, but it guarantees the largest exponent is exp(0) = 1, preventing overflow. Blanchard et al. (2021) provided a rigorous analysis of this technique.
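A direct implementation of the shifted softmax, as a sketch (framework built-ins already do this internally):

```python
import torch

def stable_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Shift by the maximum so the largest exponent is exp(0) = 1.
    shifted = x - x.max(dim=dim, keepdim=True).values
    exps = torch.exp(shifted)
    return exps / exps.sum(dim=dim, keepdim=True)

print(stable_softmax(torch.tensor([1000.0, 999.0])))  # finite; naive exp(1000) overflows
```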
Safe logarithm. Adding a small epsilon before taking the log prevents log(0) = -inf:
```python
# Unsafe
loss = -torch.log(predictions)

# Safe
epsilon = 1e-8
loss = -torch.log(predictions + epsilon)
```
Fused softmax-cross-entropy. Most frameworks provide a combined softmax and cross-entropy operation (for example, torch.nn.CrossEntropyLoss in PyTorch or tf.nn.softmax_cross_entropy_with_logits in TensorFlow) that operates on raw logits rather than probabilities. These fused operations use the log-sum-exp trick internally and avoid computing probabilities that could be exactly zero.
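For example, PyTorch's fused version stays finite even for logits that would overflow a naive softmax:

```python
import torch

logits = torch.tensor([[1000.0, -1000.0]])  # exp(1000) would overflow on its own
target = torch.tensor([0])
print(torch.nn.functional.cross_entropy(logits, target))  # tensor(0.), finite
```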
Kahan summation. When accumulating many small values (for example, summing gradients across a large batch), standard floating-point addition loses precision because small values are rounded away when added to a much larger running sum. Kahan summation maintains a separate compensation variable to track lost low-order bits, effectively doubling the precision of the accumulation. This technique has become especially relevant for pure BF16 training, where the 7-bit mantissa makes accumulation errors significant.
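A plain-Python sketch of the algorithm:

```python
def kahan_sum(values):
    total = 0.0
    compensation = 0.0  # running record of low-order bits lost to rounding
    for v in values:
        y = v - compensation            # re-inject previously lost bits
        t = total + y                   # low-order bits of y may be rounded away here
        compensation = (t - total) - y  # recover exactly what was lost
        total = t
    return total
```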
Dynamic loss scaling adjusts the scaling factor automatically during training. The standard procedure, as implemented by the framework scalers below, works roughly as follows:

1. Start with a large scaling factor.
2. After each backward pass, check the gradients for infinity or NaN.
3. If any are found, skip the weight update for that step and halve the scaling factor.
4. If a fixed number of consecutive steps (for example, 2,000) pass without overflow, double the scaling factor.
This approach "rides the edge" of the highest usable loss scale to maximize use of the FP16 dynamic range while avoiding overflow. PyTorch provides torch.amp.GradScaler for this purpose, and TensorFlow provides tf.keras.mixed_precision.LossScaleOptimizer.
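A minimal mixed-precision training step with PyTorch's scaler (model, optimizer, criterion, and dataloader are assumed to be defined, running on a CUDA device):

```python
import torch

scaler = torch.amp.GradScaler("cuda")

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()  # backpropagate the scaled loss
    scaler.step(optimizer)         # unscales first; skips the step if grads contain inf/NaN
    scaler.update()                # halves the scale on overflow, grows it otherwise
```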
Batch normalization (Ioffe and Szegedy, 2015) and layer normalization (Ba et al., 2016) re-center and re-scale activations at each layer, preventing the gradual drift of activation magnitudes that leads to overflow or underflow. By normalizing activations to zero mean and unit variance, these layers allow higher learning rates and make training less sensitive to initialization. The epsilon parameter in the normalization formula (x_normalized = (x - mean) / sqrt(variance + epsilon)) prevents division by zero when the variance is very small.
Thorough data preprocessing eliminates NaN traps that originate from the input data.
Before any batch reaches the model, features and labels should be checked for non-finite values with torch.isnan() and torch.isinf() (or np.isnan() and np.isinf() in NumPy), and features should be normalized so that no single outlier can overflow the first layer.

Learning rate warmup starts training with a very small learning rate and gradually increases it over a specified number of steps. This allows the model to find a stable region of the loss landscape before applying large updates. Warmup is standard practice for training transformers and large language models, where the interaction between the attention mechanism and large initial gradients can trigger NaN within the first few hundred steps without warmup.
Learning rate scheduling (cosine decay, step decay, or adaptive methods like Adam, AdaGrad, and RMSProp) prevents the learning rate from remaining too high as training progresses and the model approaches a minimum.
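Linear warmup can be expressed with a standard PyTorch scheduler (the step count and learning rate are illustrative placeholders, and model is assumed to be defined):

```python
import torch

warmup_steps = 1000
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps),  # ramp to 1.0, then hold
)
# Call scheduler.step() once per training step.
```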
When NaN values appear during training, systematic debugging is needed to locate the root cause.
The first line of defense is monitoring the loss and gradient norm at every training step. A sudden spike in gradient norm (for example, jumping from 1.0 to 10^4 in a single step) usually precedes a NaN loss by one or two steps. Logging frameworks like Weights & Biases, TensorBoard, and MLflow can display these metrics in real time.
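A minimal way to compute the global gradient norm after loss.backward(), suitable for feeding any of these loggers:

```python
import torch

def global_grad_norm(model: torch.nn.Module) -> float:
    # L2 norm over all parameter gradients; call after loss.backward().
    squares = [p.grad.norm() ** 2 for p in model.parameters() if p.grad is not None]
    return torch.sqrt(torch.stack(squares).sum()).item()
```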
PyTorch provides torch.autograd.detect_anomaly(), a context manager that checks for NaN or infinity values at every backward-pass operation. When it detects an invalid value, it raises an error with a stack trace pointing to the exact operation that produced the NaN.
```python
with torch.autograd.detect_anomaly():
    output = model(input_data)
    loss = criterion(output, target)
    loss.backward()
```
This mode adds significant overhead and should only be used for debugging, not production training. Once the problematic operation is identified, the anomaly detection can be removed.
Inserting explicit NaN checks at strategic points in the forward pass can narrow down the source.
```python
def forward(self, x):
    x = self.layer1(x)
    assert not torch.isnan(x).any(), "NaN detected after layer1"
    x = self.activation(x)
    assert not torch.isnan(x).any(), "NaN detected after activation"
    x = self.layer2(x)
    assert not torch.isnan(x).any(), "NaN detected after layer2"
    return x
```
In TensorFlow, the tf.debugging.enable_check_numerics() function adds runtime checks for NaN and infinity to all floating-point tensors. Like PyTorch's anomaly detection, this should be used only during debugging because of the performance cost.
Saving model checkpoints at regular intervals (for example, every 1,000 steps) allows training to resume from the last known good state when a NaN trap occurs. For large language model training, where a single run can cost millions of dollars in compute, checkpoint-based recovery is the standard approach. Research has shown that skipping 200 to 500 data batches from the point of the loss spike and resuming from a checkpoint roughly 100 steps before the spike often allows training to proceed past the problematic data.
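A periodic checkpointing sketch (the interval and filename pattern are placeholders):

```python
import torch

if step % 1000 == 0 and torch.isfinite(loss):  # never checkpoint a corrupted state
    torch.save(
        {"step": step, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
        f"checkpoint_{step:07d}.pt",
    )
```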
The pre-training of large language models presents especially acute NaN trap challenges because of the scale involved. Models with billions of parameters trained on trillions of tokens run for weeks or months on thousands of GPUs, and a single NaN trap can waste days of compute time.
Several architectural modifications have been developed to improve stability, most of them concerning how residual connections and normalization layers are arranged and scaled relative to each other.
Zhang et al. (2023) studied loss spikes in LLM pre-training and found that the fundamental condition for stability requires "small sub-layers and large shortcut," meaning the residual connections should dominate over the sublayer transformations in magnitude.
The following checklist provides a systematic approach to diagnosing and resolving NaN traps.
| Step | Action | What to look for |
|---|---|---|
| 1 | Check input data | NaN or infinity values in features or labels |
| 2 | Inspect loss value over time | Sudden spike or gradual increase before NaN |
| 3 | Monitor gradient norm | Exponential growth preceding the NaN |
| 4 | Reduce learning rate | Test whether a 10x smaller learning rate prevents NaN |
| 5 | Enable anomaly detection | Use torch.autograd.detect_anomaly() or tf.debugging.enable_check_numerics() |
| 6 | Check for division by zero | Look for normalization layers, custom losses, or divisions in the model |
| 7 | Verify loss function inputs | Ensure log() never receives zero; ensure softmax inputs are finite |
| 8 | Try FP32 training | If using mixed precision, switch to FP32 to isolate precision-related issues |
| 9 | Add gradient clipping | Set max_norm=1.0 and see if NaN disappears |
| 10 | Review initialization | Confirm the right scheme is used for the activation function |
| Unsafe operation | When it fails | NaN-safe alternative |
|---|---|---|
| log(x) | x = 0 | log(x + epsilon) |
| x / y | y = 0 | x / (y + epsilon) or check y != 0 |
| sqrt(x) | x < 0 (floating-point rounding) | sqrt(clamp(x, min=0)) or sqrt(x + epsilon) |
| exp(x) | x is large (overflow) | exp(clamp(x, max=88)) for FP32 |
| softmax(x) | x contains very large values | Subtract max(x) before exponentiating |
| x / std(x) | std = 0 (constant input) | x / (std(x) + epsilon) |
| pow(x, n) | x is large, n > 1 | Clamp x, or use log-space: exp(n * log(x)) |
| 1 / (1 + exp(-x)) (sigmoid) | x is very negative (exp(-x) overflows) | Use the framework's built-in sigmoid function |
Imagine you are stacking blocks to build a tower. Each block sits on the one below it, and you have to do a little math to figure out where each block goes. Now, what happens if you try to divide by zero? Your calculator shows "Error," right? In computers, instead of showing "Error," the calculator writes a special answer called "NaN" (which stands for "Not a Number").
Here is the tricky part. If you use that "NaN" answer in your next math problem, the answer is also NaN. And if you use that answer in the next problem, that is NaN too. It is like getting paint on your hands; everything you touch gets paint on it. Pretty soon, your whole tower of blocks is covered in paint and you cannot use any of them.
That is what happens in a computer when it trains a neural network. Millions of math problems are solved in a chain, one after another. If just one of those problems gives a NaN answer, the whole chain gets ruined. The way to fix it is to be careful about your math (do not divide by zero, do not let numbers get too big) and to check for NaN early so you can stop and fix the problem before it spreads everywhere.