Metrics API (tf.metrics)
Last reviewed
May 11, 2026
Sources
No citations yet
Review status
Needs citations
Revision
v2 · 2,200 words
See also: Machine learning terms
The Metrics API in TensorFlow is a collection of classes and utilities for computing evaluation scores that summarize how well a model is doing. The current API lives under tf.keras.metrics (and equivalently keras.metrics after the Keras 3 rewrite) and replaces the older graph-mode tf.metrics module from TensorFlow 1.x. The newer API is the recommended one for any modern code; the legacy tf.compat.v1.metrics namespace exists only for backward compatibility with TF 1.x training loops.
Metrics differ from loss functions. A loss feeds into the optimizer and shapes the gradients. A metric is only read, never differentiated through. You can use almost any loss as a metric, but not the other way around: some metrics (Precision, Recall, AUC, F1) are not differentiable and cannot be optimized directly.
Evaluation metrics are quantitative measures used to assess the performance of machine learning models. They provide a means to compare the effectiveness of different models, fine-tune model parameters, and monitor the training process. Commonly used evaluation metrics include accuracy, precision, recall, F1 score and AUC for classification, and mean squared error or mean absolute error for regression.
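The TensorFlow Metrics API provides a collection of pre-built metrics, as well as the ability to create custom metrics for specific use cases. The API is designed to be consistent and easy to use, with each metric represented as a class inheriting from the base class tf.keras.metrics.Metric. The primary methods provided by this base class are:
- update_state(): accumulates statistics from one batch of labels and predictions into the metric's state variables.
- result(): computes the current metric value from the accumulated state and returns it as a tf.Tensor.
- reset_state(): clears the accumulated state, typically at the start of each epoch. The older spelling reset_states() is still accepted as a legacy alias.

The Metrics API allows users to monitor model performance during training, validation, and testing. It is used mainly through the Keras API and, in older code, through the TensorFlow Estimators API.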
The defining characteristic of tf.keras.metrics is that every metric is streaming, also called stateful. Instead of buffering every prediction in memory and computing a value at the end, each metric stores a small set of running statistics in variables created with self.add_weight(). Those variables (typically a total and a count, or a confusion-matrix tensor, or a histogram of thresholded counts) are updated batch by batch and the result is read out on demand.
This matters more than it first looks. The naive alternative, averaging the per-batch value of a metric, gives the wrong answer for non-linear metrics. The mean of per-batch AUC scores is not the AUC of the full dataset; the mean of per-batch Precision is biased when class balance varies across batches. A streaming implementation accumulates the underlying counts (true positives, false positives, threshold histograms) and recomputes the metric from those counts whenever result() is called, so the value at the end of an epoch is mathematically the value over the whole epoch.
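To make the difference concrete, here is a small framework-free sketch in plain Python. The helper names `precision` and `batch_counts` and the toy batches are invented for this illustration and are not part of TensorFlow; the point is only that averaging per-batch precision and recomputing precision from accumulated counts give different numbers.

```python
def precision(tp, fp):
    """Precision from raw counts; 0.0 when nothing was predicted positive."""
    return tp / (tp + fp) if (tp + fp) else 0.0


def batch_counts(y_true, y_pred):
    """Count true positives and false positives in one batch of 0/1 labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    return tp, fp


# Two toy batches with different class balance (illustrative data only).
batches = [
    ([1, 1, 0, 0], [1, 1, 1, 1]),  # tp=2, fp=2 -> batch precision 0.5
    ([1, 0, 0, 0], [1, 0, 0, 0]),  # tp=1, fp=0 -> batch precision 1.0
]

# Naive approach: average the per-batch precision values.
per_batch = [precision(*batch_counts(t, p)) for t, p in batches]
naive = sum(per_batch) / len(per_batch)

# Streaming approach: accumulate the counts, then compute precision once.
total_tp = total_fp = 0
for t, p in batches:
    tp, fp = batch_counts(t, p)
    total_tp += tp
    total_fp += fp
streaming = precision(total_tp, total_fp)

print(naive, streaming)  # 0.75 versus 0.6
```

The streaming value (3 true positives out of 5 positive predictions) is the precision over the whole dataset; the naive average overweights the small, clean second batch.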
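The streaming design also plays well with distributed training. In multi-GPU or TPU setups, each replica updates its own copy of the state variables, and the copies are aggregated (summed, for count-style state) before result() is computed; separate metric instances can also be combined explicitly with merge_state(). A value gathered across eight accelerators therefore matches the value you would have gotten on a single device, up to floating-point rounding.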
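The reason this works is that the state is additive. The following plain-Python sketch (the `MeanState` class is hypothetical and only mimics the idea, it is not how TensorFlow implements replicas) splits one dataset across two simulated replicas, merges their state by summation, and checks that the result matches a single-device run.

```python
from dataclasses import dataclass


@dataclass
class MeanState:
    """Hypothetical stand-in for the (total, count) state of a mean metric."""
    total: float = 0.0
    count: int = 0

    def update(self, values):
        # Accumulate one batch of per-sample values.
        self.total += sum(values)
        self.count += len(values)

    def merge(self, other):
        # Additive state merges by summing each field.
        self.total += other.total
        self.count += other.count

    def result(self):
        return self.total / self.count if self.count else 0.0


data = [0.2, 0.4, 0.9, 0.5, 0.1, 0.3]

# Single device: one state object sees everything.
single = MeanState()
single.update(data)

# Two simulated replicas each see a shard, then their state is merged.
replica_a, replica_b = MeanState(), MeanState()
replica_a.update(data[:4])
replica_b.update(data[4:])
replica_a.merge(replica_b)

print(single.result(), replica_a.result())
```

The same argument applies to confusion-matrix cells and per-threshold counts, since those are also sums over samples.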
The tf.keras.metrics module ships with dozens of pre-built classes. The most commonly used ones fall into a handful of buckets.
| Group | Class | Use case |
|---|---|---|
| Accuracy | Accuracy | Generic equality between y_true and y_pred |
| Accuracy | BinaryAccuracy | Binary labels, threshold defaults to 0.5 |
| Accuracy | CategoricalAccuracy | One-hot labels, softmax predictions |
| Accuracy | SparseCategoricalAccuracy | Integer labels, softmax predictions |
| Accuracy | TopKCategoricalAccuracy | One-hot labels, correct if true class is in the top K |
| Accuracy | SparseTopKCategoricalAccuracy | Integer labels, top-K variant |
| Classification | Precision, Recall | Threshold-based precision and recall |
| Classification | TruePositives, TrueNegatives, FalsePositives, FalseNegatives | Confusion-matrix cells |
| Classification | AUC | Riemann-sum approximation of ROC or PR area |
| Classification | F1Score, FBetaScore | Harmonic mean of precision and recall (weighted by beta for FBetaScore) |
| Classification | PrecisionAtRecall, RecallAtPrecision, SensitivityAtSpecificity, SpecificityAtSensitivity | Operating-point metrics |
| Probabilistic | BinaryCrossentropy, CategoricalCrossentropy, SparseCategoricalCrossentropy, KLDivergence, Poisson | Log-likelihood style metrics that double as losses |
| Regression | MeanSquaredError, MeanAbsoluteError, RootMeanSquaredError, MeanAbsolutePercentageError, MeanSquaredLogarithmicError, LogCoshError, CosineSimilarity, R2Score | Numerical regression |
| Segmentation | IoU, BinaryIoU, MeanIoU, OneHotIoU, OneHotMeanIoU | Per-pixel overlap for semantic segmentation |
| Hinge | Hinge, SquaredHinge, CategoricalHinge | SVM-style margin metrics |
| Wrappers | Mean, Sum, MeanMetricWrapper, MeanTensor | Generic reduction helpers |
The wrapper classes are quietly the most important. Mean keeps a total and a count and lets you compute the running average of any tensor you push into it, which is how loss values are reported during training. MeanMetricWrapper turns any stateless function with the signature fn(y_true, y_pred) into a streaming metric by passing the per-sample values through Mean. Most of the regression metrics in the table above are implemented as thin MeanMetricWrapper subclasses around the matching tf.keras.losses function.
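The wrapper idea can be sketched without TensorFlow. In the plain-Python illustration below, `RunningMean` and `FunctionMetric` are hypothetical stand-ins for Mean and MeanMetricWrapper (not the Keras classes themselves), and `absolute_error` plays the role of a stateless fn(y_true, y_pred).

```python
class RunningMean:
    """Keeps a total and a count, like the Mean metric described above."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0


class FunctionMetric:
    """Turns a stateless per-sample function into a streaming metric."""

    def __init__(self, fn):
        self.fn = fn
        self.mean = RunningMean()

    def update_state(self, y_true, y_pred):
        # fn returns one value per sample; the running mean absorbs them.
        self.mean.update_state(self.fn(y_true, y_pred))

    def result(self):
        return self.mean.result()

    def reset_state(self):
        self.mean.reset_state()


def absolute_error(y_true, y_pred):
    return [abs(t - p) for t, p in zip(y_true, y_pred)]


mae = FunctionMetric(absolute_error)
mae.update_state([1.0, 2.0], [1.5, 2.0])  # per-sample errors 0.5 and 0.0
mae.update_state([3.0], [1.0])            # per-sample error 2.0
print(mae.result())                       # mean over all three samples
```

Because the mean is taken over samples rather than over batches, unequal batch sizes do not skew the result.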
A few classes deserve a closer look:
- AUC evaluates the curve at num_thresholds evenly spaced thresholds (200 by default) and accumulates per-threshold counts of true positives, false positives, true negatives and false negatives. The area is then computed by a Riemann sum with summation_method='interpolation' by default, which gives a smoother estimate than the simple step approximation. For severely imbalanced classification, the curve='PR' mode is usually more informative than the default ROC.
- MeanIoU keeps a num_classes by num_classes confusion matrix as its state. At each step the new predictions are flattened, compared against the labels, and added into the matrix; result() reads the diagonal and the row and column sums to produce the per-class IoU and averages them. An ignore_class argument lets you skip a void label (255 in Cityscapes, for example) during accumulation.

When the built-ins do not cover the case at hand, you subclass tf.keras.metrics.Metric directly. The pattern is:
```python
import tensorflow as tf


class BinaryTruePositives(tf.keras.metrics.Metric):
    def __init__(self, name='binary_true_positives', **kwargs):
        super().__init__(name=name, **kwargs)
        # Scalar state variable holding the running count of true positives.
        self.tp = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred, tf.bool)
        # A true positive is a sample where label and prediction are both True.
        values = tf.logical_and(y_true, y_pred)
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            values = values * sample_weight
        self.tp.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.tp

    def reset_state(self):
        self.tp.assign(0.0)
```
The key things to get right: state variables must be created with self.add_weight() so they participate in TensorFlow's variable bookkeeping and survive checkpointing; update_state should only update state and must not return anything; result must return the metric value as a tensor (a scalar for most metrics); and reset_state must zero out every variable, otherwise metric values bleed between epochs. Marking the class with @tf.keras.saving.register_keras_serializable() makes the metric load cleanly when a saved model is restored.
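The same four pieces (state, update_state, result, reset_state) also apply when the state is not a scalar. As a framework-free illustration, the sketch below accumulates a confusion matrix in plain Python, in the spirit of the segmentation metrics such as MeanIoU. The class name `ConfusionMatrixIoU` is hypothetical, and the handling of classes that never appear (they are left out of the average) is an assumption of this sketch rather than a statement about the Keras implementation.

```python
class ConfusionMatrixIoU:
    """Plain-Python sketch of a mean-IoU metric; not the Keras class."""

    def __init__(self, num_classes, ignore_class=None):
        self.num_classes = num_classes
        self.ignore_class = ignore_class
        self.reset_state()

    def reset_state(self):
        # matrix[t][p] counts samples with true class t predicted as class p.
        self.matrix = [[0] * self.num_classes for _ in range(self.num_classes)]

    def update_state(self, y_true, y_pred):
        for t, p in zip(y_true, y_pred):
            if t == self.ignore_class:
                continue  # void labels do not touch the state
            self.matrix[t][p] += 1

    def result(self):
        ious = []
        for c in range(self.num_classes):
            tp = self.matrix[c][c]
            row = sum(self.matrix[c])              # samples truly of class c
            col = sum(r[c] for r in self.matrix)   # samples predicted as class c
            union = row + col - tp
            if union:  # assumption: skip classes with no samples at all
                ious.append(tp / union)
        return sum(ious) / len(ious) if ious else 0.0


metric = ConfusionMatrixIoU(num_classes=2, ignore_class=255)
metric.update_state([0, 0, 1, 255], [0, 1, 1, 0])  # the 255 label is skipped
metric.update_state([1, 1], [1, 0])
print(metric.result())  # class 0 IoU is 1/3, class 1 IoU is 1/2
```

Only the matrix is stored between batches, so memory use depends on the number of classes and not on the number of pixels seen.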
Metrics integrate with model.compile() either as a string shortcut or as instantiated objects:
```python
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        # Both metrics accept integer labels and raw logits.
        tf.keras.metrics.SparseCategoricalAccuracy(),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3),
    ],
)
```
During model.fit(), Keras calls update_state on every batch and result at the end of each epoch, then logs the value to the returned History object and to any TensorBoard callbacks. model.evaluate() performs the same accumulation on a separate dataset. Outside fit, you can drive a metric by hand with update_state, result and reset_state, which is what custom training loops written with tf.GradientTape do.
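A custom loop that drives metrics by hand might look like the sketch below. The toy data, the single Dense layer and the learning rate are arbitrary placeholders chosen so the example runs quickly; only the metric calls (reset_state, update_state, result) are the point.

```python
import tensorflow as tf

# Toy data and model; shapes and sizes are arbitrary placeholders.
tf.random.set_seed(0)
x = tf.random.normal((64, 4))
y = tf.random.uniform((64,), maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

train_loss = tf.keras.metrics.Mean(name="loss")
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")

history = []
for epoch in range(2):
    # Clear state so each epoch reports only its own batches.
    train_loss.reset_state()
    train_acc.reset_state()
    for xb, yb in dataset:
        with tf.GradientTape() as tape:
            logits = model(xb, training=True)
            loss = loss_fn(yb, logits)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Accumulate running statistics batch by batch.
        train_loss.update_state(loss)
        train_acc.update_state(yb, logits)
    # Read the epoch-level values once all batches have been seen.
    history.append((float(train_loss.result()), float(train_acc.result())))

print(history)
```

Placing the reset_state calls at the top of the epoch loop is one common convention; resetting at the end of the epoch after reading result works equally well.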
For multi-output models the metrics argument accepts a list of lists (one per output) or a dictionary keyed by output name, so different heads can be evaluated with different metrics.
A few patterns recur in bug reports and forum threads:
- Forgetting reset_state() in a custom loop. The metric keeps accumulating across epochs and reports a slow-moving average that drifts toward the long-run mean.
- Using tf.keras.metrics.Accuracy for classification with softmax outputs. Accuracy compares tensors elementwise, so you need to argmax the predictions first, or use SparseCategoricalAccuracy or CategoricalAccuracy, which do the argmax internally.
- Treating AUC as an exact integral. For small num_thresholds and skewed score distributions the approximation error can be visible. Bumping num_thresholds to 500 or 1000 usually closes the gap with scikit-learn's roc_auc_score.

In TensorFlow 1.x, tf.metrics was a module of functions, not classes. Each function returned a tuple of (value_tensor, update_op) that had to be run in a tf.Session, and the metric state lived in local variables that had to be initialized explicitly. The function-and-tuple style was clean inside static graphs but it sits awkwardly next to eager execution.
The TF 2 migration moved the entire surface to tf.keras.metrics, where each metric is a Python object that owns its state. The TF 1.x functions still exist under tf.compat.v1.metrics for older code; using them in new projects is discouraged because they do not integrate with model.fit(), eager mode debugging or the modern distribution strategies.
With the release of Keras 3 (default in TensorFlow 2.16, March 2024), the same metric classes can run on top of TensorFlow, JAX or PyTorch. The API surface is unchanged for users; the implementation now dispatches through keras.ops instead of calling TensorFlow primitives directly. Code written against tf.keras.metrics continues to work, and the same metric subclass can run unmodified under any of the three supported backends as long as the update_state body avoids backend-specific calls.
PyTorch does not ship a metrics library in its core distribution. The de-facto standard is TorchMetrics, maintained by the PyTorch Lightning team. Its design mirrors the Keras one closely: a Metric base class with state variables registered via add_state, an update() method for per-batch accumulation, a compute() method for the final value, and a reset() method to zero the state between epochs. The library ships more than 100 metric implementations covering classification, regression, retrieval, image quality (SSIM, PSNR, FID, LPIPS) and natural language tasks (BLEU, ROUGE, BERTScore, Word Error Rate), plus a functional interface for one-shot stateless computation.
The biggest practical difference is distributed-training support. TorchMetrics has explicit knobs (sync_on_compute, dist_sync_fn, process_group) for how state should be merged across processes, and uses torch.distributed.all_gather by default. In Keras the aggregation of metric state across replicas is handled largely inside the framework, with merge_state() available for combining metric instances by hand, so the user has less direct control.
Imagine you are playing a game where you need to guess the color of hidden cards. To know how good you are at guessing, you need a way to measure your performance. In machine learning, we have something similar called metrics that help us measure how well our computer programs (models) are doing their jobs.
The Metrics API in TensorFlow is like a toolbox for those measurements. It has a small machine inside each tool that keeps a running tally. Every time you make a guess, it adds the result to the tally instead of forgetting it. At the end of the round it tells you the score for everything you guessed so far, then wipes the tally clean for the next round.
There are different tools for different games. One tool counts how many guesses you got exactly right, another rewards you for being close, and another one checks whether the right answer was at least in your top three guesses. You can also build your own tool if none of the ready-made ones fits.