# Metrics API (tf.metrics)

> Source: https://aiwiki.ai/wiki/metrics_api_tf_metrics
> Updated: 2026-05-11
> Categories: Machine Learning
> From AI Wiki (https://aiwiki.ai), a free encyclopedia of artificial intelligence. Quote with attribution.

*See also: [Machine learning terms](/wiki/machine_learning_terms)*

## Overview

The **Metrics API** in [TensorFlow](/wiki/tensorflow) is a collection of classes and utilities for computing evaluation scores that summarize how well a model is doing. The current API lives under `tf.keras.metrics` (and equivalently `keras.metrics` after the Keras 3 rewrite) and replaces the older graph-mode `tf.metrics` module from TensorFlow 1.x. The newer API is the recommended one for any modern code; the legacy `tf.compat.v1.metrics` namespace exists only for backward compatibility with TF 1.x training loops.

Metrics differ from loss functions. A loss feeds into the optimizer and shapes the gradients. A metric is only read, never differentiated through. You can use almost any loss as a metric, but not the other way around: some metrics (Precision, Recall, AUC, F1) are not differentiable and cannot be optimized directly.

## Evaluation metrics

Evaluation metrics are quantitative measures used to assess the performance of machine learning models. They provide a means to compare different models, tune hyperparameters, and monitor the training process. Some commonly used evaluation metrics include:

### Classification metrics

- **Accuracy**: The proportion of correctly classified instances out of the total instances in a dataset.
- **Precision**: The proportion of true positive instances among the instances predicted as positive by the model.
- **Recall**: The proportion of true positive instances among the actual positive instances in the dataset.
- **F1-score**: The harmonic mean of precision and recall, which considers both false positives and false negatives.
- **AUC-ROC**: The area under the Receiver Operating Characteristic ([ROC](/wiki/roc_curve)) curve, which plots the true positive rate against the false positive rate at various classification thresholds.
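As a rough illustration of the threshold-based definitions above, the sketch below computes them in plain Python from binary labels and hard 0/1 predictions. The helper name `classification_scores` is hypothetical, and it assumes predictions have already been thresholded. It is not how `tf.keras.metrics` is implemented internally.

```python
def classification_scores(y_true, y_pred):
    """Return (accuracy, precision, recall, F1) for binary 0/1 labels and predictions."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    accuracy = (tp + tn) / len(pairs)
    # 0 / 0 is reported as 0.0 rather than raising, similar to a divide-no-NaN convention.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1

# tp = 2, fp = 2, fn = 1, tn = 1
acc, prec, rec, f1 = classification_scores([1, 0, 1, 1, 0, 0], [1, 1, 0, 1, 1, 0])
# acc = 0.5, prec = 0.5, rec = 2/3, f1 = 4/7
```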

### Regression metrics

- **Mean Absolute Error (MAE)**: The average of the absolute differences between the predicted and actual values.
- **Mean Squared Error ([MSE](/wiki/mean_squared_error))**: The average of the squared differences between predicted and actual values.
- **Root Mean Squared Error (RMSE)**: The square root of the MSE, expressed in the same units as the target. It equals the standard deviation of the residuals only when the residuals have zero mean.
- **R-squared**: The proportion of the variance in the dependent variable that is predictable from the independent variables, indicating how well the model fits the data.
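A matching plain-Python sketch of the regression formulas, assuming two equal-length lists of numbers. `regression_scores` is a hypothetical helper, not a TensorFlow function; R² is undefined (division by zero) when every target value is identical.

```python
import math

def regression_scores(y_true, y_pred):
    """Return (MAE, MSE, RMSE, R^2) for two equal-length sequences of numbers."""
    n = len(y_true)
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    mae = sum(abs(r) for r in residuals) / n
    mse = sum(r * r for r in residuals) / n
    rmse = math.sqrt(mse)  # same units as the target
    mean_true = sum(y_true) / n
    ss_res = sum(r * r for r in residuals)              # unexplained variation
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # total variation around the mean
    r2 = 1.0 - ss_res / ss_tot  # raises ZeroDivisionError if all targets are equal
    return mae, mse, rmse, r2

mae, mse, rmse, r2 = regression_scores([3.0, 5.0, 2.0, 7.0], [2.5, 5.0, 3.0, 8.0])
# mae = 0.625, mse = 0.5625, rmse = 0.75, r2 ≈ 0.847
```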

## TensorFlow Metrics API (tf.metrics)

The TensorFlow Metrics API provides a collection of pre-built metrics, as well as the ability to create custom metrics for specific use cases. Every metric is a class inheriting from the base class **tf.keras.metrics.Metric**, so all of them share the same small interface. The primary methods provided by this base class are:

- **update_state(y_true, y_pred, sample_weight=None)**: Processes a batch of labels and predictions and accumulates them into internal weight variables. It returns nothing; its job is to mutate state.
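- **result()**: Reads the accumulated state and returns the current value of the metric as a tensor. For most metrics this is a scalar; a few, such as `Precision` configured with several thresholds, return one value per threshold.
- **reset_state()**: Returns every state variable to its initial value (zero for the built-in metrics). Keras calls this automatically at the start of every epoch in `model.fit()` and at the start of each `model.evaluate()` run, so the value reported for one epoch is not contaminated by the previous one. Older `tf.keras` releases also accept the deprecated alias `reset_states()`.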

The Metrics API lets users monitor model performance during training, validation, and testing, and plugs directly into the [Keras](/wiki/keras) `compile()`/`fit()` workflow. The now-deprecated TensorFlow Estimator API also accepted these metric objects.
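A minimal sketch of the three-method lifecycle, using the built-in `tf.keras.metrics.Mean`, whose state is a running `total` and `count`. It assumes an eager TensorFlow 2.x environment; the values in the comments follow directly from the inputs.

```python
import tensorflow as tf

metric = tf.keras.metrics.Mean()
metric.update_state(tf.constant([1.0, 2.0, 3.0]))  # total = 6, count = 3
metric.update_state(tf.constant([10.0]))           # total = 16, count = 4
running = float(metric.result())                   # 16 / 4 = 4.0
metric.reset_state()                               # total and count back to 0
after_reset = float(metric.result())               # 0.0: 0 / 0 is reported as 0, not NaN
```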

## Streaming and stateful behaviour

The defining characteristic of `tf.keras.metrics` is that every metric is streaming, also called stateful. Instead of buffering every prediction in memory and computing a value at the end, each metric stores a small set of running statistics in variables created with `self.add_weight()`. Those variables (typically a `total` and a `count`, or a confusion-matrix tensor, or a histogram of thresholded counts) are updated batch by batch and the result is read out on demand.

This matters more than it first looks. The naive alternative, averaging the per-batch value of a metric, gives the wrong answer for non-linear metrics. The mean of per-batch AUC scores is not the AUC of the full dataset; the mean of per-batch Precision is biased when class balance varies across batches. A streaming implementation accumulates the underlying counts (true positives, false positives, threshold histograms) and recomputes the metric from those counts whenever `result()` is called, so the value at the end of an epoch is mathematically the value over the whole epoch.
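The difference is easy to see with a toy example. The sketch below uses made-up confusion counts for two batches and plain Python rather than TensorFlow: averaging the per-batch precision gives 0.625, while pooling the counts first, as a streaming metric does, gives the full-dataset value of 2/5 = 0.4.

```python
# Hypothetical per-batch confusion counts for a binary classifier.
batches = [
    {"tp": 1, "fp": 0},  # batch 1: precision 1/1
    {"tp": 1, "fp": 3},  # batch 2: precision 1/4
]

def precision(tp, fp):
    return tp / (tp + fp) if tp + fp else 0.0

# Naive approach: average the per-batch metric values.
naive = sum(precision(b["tp"], b["fp"]) for b in batches) / len(batches)

# Streaming approach: accumulate the counts, then compute the metric once.
total_tp = sum(b["tp"] for b in batches)
total_fp = sum(b["fp"] for b in batches)
streaming = precision(total_tp, total_fp)
```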

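The streaming design also plays well with distributed training. Because the state is made of additive quantities (counts, totals, confusion-matrix cells), the copies accumulated on each replica in a multi-GPU or TPU setup can be summed across replicas before `result()` is computed, so a value gathered across eight accelerators matches, up to floating-point rounding, the value you would get on a single device over the same data. In `tf.keras` the distribution strategy performs this aggregation automatically, and Keras 2 metrics additionally expose a `merge_state()` method for combining separately accumulated instances by hand.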

## Available built-in metrics

The `tf.keras.metrics` module ships with dozens of pre-built classes. The most commonly used ones fall into a handful of buckets.

| Group | Class | Use case |
|---|---|---|
| Accuracy | `Accuracy` | Generic equality between `y_true` and `y_pred` |
| Accuracy | `BinaryAccuracy` | Binary labels, threshold defaults to 0.5 |
| Accuracy | `CategoricalAccuracy` | One-hot labels, softmax predictions |
| Accuracy | `SparseCategoricalAccuracy` | Integer labels, softmax predictions |
| Accuracy | `TopKCategoricalAccuracy` | One-hot labels, correct if true class is in the top K |
| Accuracy | `SparseTopKCategoricalAccuracy` | Integer labels, top-K variant |
| Classification | `Precision`, `Recall` | Threshold-based precision and recall |
| Classification | `TruePositives`, `TrueNegatives`, `FalsePositives`, `FalseNegatives` | Confusion-matrix cells |
| Classification | `AUC` | Riemann-sum approximation of ROC or PR area |
| Classification | `F1Score`, `FBetaScore` | Harmonic mean of precision and recall |
| Classification | `PrecisionAtRecall`, `RecallAtPrecision`, `SensitivityAtSpecificity`, `SpecificityAtSensitivity` | Operating-point metrics |
| Probabilistic | `BinaryCrossentropy`, `CategoricalCrossentropy`, `SparseCategoricalCrossentropy`, `KLDivergence`, `Poisson` | Log-likelihood style metrics that double as losses |
| Regression | `MeanSquaredError`, `MeanAbsoluteError`, `RootMeanSquaredError`, `MeanAbsolutePercentageError`, `MeanSquaredLogarithmicError`, `LogCoshError`, `CosineSimilarity`, `R2Score` | Numerical regression |
| Segmentation | `IoU`, `BinaryIoU`, `MeanIoU`, `OneHotIoU`, `OneHotMeanIoU` | Per-pixel overlap for semantic [segmentation](/wiki/image_segmentation) |
| Hinge | `Hinge`, `SquaredHinge`, `CategoricalHinge` | SVM-style margin metrics |
| Wrappers | `Mean`, `Sum`, `MeanMetricWrapper`, `MeanTensor` | Generic reduction helpers |

The wrapper classes are quietly the most important. `Mean` keeps a `total` and a `count` and lets you compute the running average of any tensor you push into it, which is how loss values are reported during training. `MeanMetricWrapper` turns any stateless function with the signature `fn(y_true, y_pred)` into a streaming metric by passing the per-sample values through `Mean`. Most of the regression metrics in the table above are implemented as thin `MeanMetricWrapper` subclasses around the matching `tf.keras.losses` function.
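As a sketch of the wrapper pattern, the snippet below turns a hypothetical per-sample function, `within_half_unit` (1.0 when a prediction lands within 0.5 of the target), into a streaming metric with `tf.keras.metrics.MeanMetricWrapper`. The function and the metric name are illustrative assumptions, not part of the library.

```python
import tensorflow as tf

def within_half_unit(y_true, y_pred):
    # Hypothetical per-sample score: 1.0 if |y_true - y_pred| <= 0.5, else 0.0.
    return tf.cast(tf.abs(y_true - y_pred) <= 0.5, tf.float32)

metric = tf.keras.metrics.MeanMetricWrapper(within_half_unit, name='within_half_unit')

metric.update_state(tf.constant([1.0, 2.0, 3.0, 4.0]),
                    tf.constant([1.2, 2.9, 3.1, 4.0]))  # 3 of 4 within tolerance
after_first = float(metric.result())                    # 0.75

metric.update_state(tf.constant([0.0, 0.0]),
                    tf.constant([1.0, 0.0]))            # 1 of 2 within tolerance
after_second = float(metric.result())                   # 4 / 6 ≈ 0.667 over both batches
```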

A few classes deserve a closer look:

- **AUC** discretizes the ROC or PR curve into `num_thresholds` evenly spaced thresholds (200 by default) and accumulates per-threshold counts of true positives, false positives, true negatives and false negatives. The area is then computed by a Riemann sum with `summation_method='interpolation'` by default, which gives a smoother estimate than the simple step approximation. For severely imbalanced classification, the `curve='PR'` mode is usually more informative than the default ROC.
- **MeanIoU** maintains a full `num_classes` by `num_classes` confusion matrix as its state. At each step the new predictions are flattened, compared against the labels, and added into the matrix; `result()` reads the diagonal and the row and column sums to produce the per-class IoU and averages them. An `ignore_class` argument lets you skip a void label (255 in Cityscapes, for example) during accumulation.
- **Precision** and **Recall** accept a list of `thresholds`, a `top_k` cutoff, and a `class_id` that restricts the counts to a single positive class. When several thresholds are given, each one gets its own independent counts inside the same metric object, and `result()` returns one value per threshold.
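The `MeanIoU` bookkeeping can be sketched in plain Python, assuming flat lists of integer labels and predictions. `mean_iou` is a hypothetical helper that follows the description above: accumulate a confusion matrix, skip `ignore_class` during accumulation, then derive each class's IoU from the diagonal and the row and column sums. This sketch leaves classes with an empty union out of the average; the real metric does the same work with tensors.

```python
def mean_iou(labels, preds, num_classes, ignore_class=None):
    # Confusion matrix: rows are true classes, columns are predicted classes.
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        if ignore_class is not None and t == ignore_class:
            continue  # void pixels never enter the matrix
        cm[t][p] += 1
    ious = []
    for c in range(num_classes):
        tp = cm[c][c]
        row = sum(cm[c])                                # pixels whose true class is c
        col = sum(cm[r][c] for r in range(num_classes))  # pixels predicted as c
        union = row + col - tp
        if union:  # classes absent from both labels and predictions are skipped
            ious.append(tp / union)
    return sum(ious) / len(ious) if ious else 0.0

# Per-class IoU: 1/2, 2/3, 1/1; the 255 pixel is ignored.
score = mean_iou([0, 0, 1, 1, 2, 255], [0, 1, 1, 1, 2, 0], num_classes=3, ignore_class=255)
# score = 13/18 ≈ 0.722
```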

## Custom metrics

When the built-ins do not cover the case at hand, you subclass `tf.keras.metrics.Metric` directly. The pattern is:

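```python
import tensorflow as tf

class BinaryTruePositives(tf.keras.metrics.Metric):
    """Streaming count of true positives for binary classification."""

    def __init__(self, threshold=0.5, name='binary_true_positives', **kwargs):
        super().__init__(name=name, **kwargs)
        self.threshold = threshold  # scores above this count as positive
        # A state variable, so the count persists across batches and checkpoints.
        self.tp = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        # Threshold the scores: casting probabilities straight to bool would
        # treat any non-zero score (e.g. 0.1) as a positive prediction.
        y_pred = tf.cast(y_pred, self.dtype) > self.threshold
        values = tf.cast(tf.logical_and(y_true, y_pred), self.dtype)
        if sample_weight is not None:
            # Assumes sample_weight has the same shape as y_true and y_pred.
            sample_weight = tf.cast(sample_weight, self.dtype)
            values = values * sample_weight
        self.tp.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.tp

    def reset_state(self):
        # Zero the counter so values do not bleed between epochs.
        self.tp.assign(0.0)

    def get_config(self):
        # Include the extra constructor argument so save/load round-trips.
        return {**super().get_config(), 'threshold': self.threshold}
```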

The key things to get right: state variables must be created with `self.add_weight()` so they participate in TensorFlow's variable bookkeeping and survive checkpointing; `update_state` only mutates state and does not need to return anything; `result` must return a tensor (usually a scalar) computed from that state; and if you override `reset_state`, it must reset every variable, otherwise metric values bleed between epochs (the base-class implementation already zeroes all variables). Marking the class with `@tf.keras.saving.register_keras_serializable()`, and implementing `get_config()` when the constructor takes extra arguments, lets the metric load cleanly when a saved model is restored.

## Use during training and evaluation

Metrics integrate with `model.compile()` either as a string shortcut or as instantiated objects:

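```python
import tensorflow as tf

# Assumes `model` ends in a Dense layer without softmax (it outputs logits)
# and that the labels are one-hot encoded.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[
        tf.keras.metrics.CategoricalAccuracy(),
        # AUC thresholds span [0, 1], so it must be told the outputs are logits.
        tf.keras.metrics.AUC(curve='PR', from_logits=True),
    ],
)
# String shortcuts also work, e.g. metrics=['accuracy'].
```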

During `model.fit()`, Keras calls `update_state` on every batch, reads `result()` to display the running value in the progress bar, and logs the end-of-epoch value to the returned `History` object and to any callbacks such as TensorBoard. `model.evaluate()` performs the same accumulation on a separate dataset. Outside `fit`, you can drive a metric by hand with `update_state`, `result` and `reset_state`, which is what custom training loops written with `tf.GradientTape` do.
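A minimal custom training loop showing where each call goes, under stated assumptions: a tiny synthetic regression dataset, a one-layer model, and `MeanAbsoluteError` as the tracked metric. All data and hyperparameters here are illustrative, not recommendations.

```python
import tensorflow as tf

# Illustrative synthetic data: the target is the sum of three random features.
x = tf.random.normal((64, 3), seed=0)
y = tf.reduce_sum(x, axis=1, keepdims=True)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()
train_mae = tf.keras.metrics.MeanAbsoluteError()

history = []
for epoch in range(3):
    for xb, yb in dataset:
        with tf.GradientTape() as tape:
            pred = model(xb, training=True)
            loss = loss_fn(yb, pred)
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        train_mae.update_state(yb, pred)          # accumulate this batch
    history.append(float(train_mae.result()))     # value over the whole epoch
    train_mae.reset_state()                       # start the next epoch clean
```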

For multi-output models the `metrics` argument accepts a list of lists (one per output) or a dictionary keyed by output name, so different heads can be evaluated with different metrics.

## Common pitfalls

A few patterns recur in bug reports and forum threads:

- Forgetting `reset_state()` in a custom loop. The metric keeps accumulating across epochs and reports a slow-moving average that drifts toward the long-run mean.
- Using `tf.keras.metrics.Accuracy` for classification with softmax outputs. `Accuracy` compares tensors elementwise, so you need to argmax the predictions first, or use `SparseCategoricalAccuracy` or `CategoricalAccuracy`, which do the argmax internally.
- Reusing the same metric instance for both training and validation. The state is shared, so the running value mixes the two phases. Allocate one instance per evaluation context.
- Treating `AUC` as an exact integral. For small `num_thresholds` and skewed score distributions the approximation error can be visible. Bumping `num_thresholds` to 500 or 1000 usually closes the gap with scikit-learn's `roc_auc_score`.
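The `Accuracy` pitfall in code, with made-up probabilities for three samples: `Accuracy` needs the class IDs produced by an explicit `argmax`, whereas `SparseCategoricalAccuracy` takes the probability matrix directly. Both report 2/3 here because the third sample's top score is on the wrong class.

```python
import tensorflow as tf

y_true = tf.constant([2, 0, 1])              # integer class labels
probs = tf.constant([[0.1, 0.2, 0.7],        # argmax 2: correct
                     [0.8, 0.1, 0.1],        # argmax 0: correct
                     [0.5, 0.3, 0.2]])       # argmax 0: wrong, true class is 1

# Accuracy compares elementwise, so reduce the probabilities to class IDs first.
acc = tf.keras.metrics.Accuracy()
acc.update_state(y_true, tf.argmax(probs, axis=-1, output_type=tf.int32))

# SparseCategoricalAccuracy performs the argmax internally.
sparse_acc = tf.keras.metrics.SparseCategoricalAccuracy()
sparse_acc.update_state(y_true, probs)
```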

## tf.metrics vs tf.compat.v1.metrics

In TensorFlow 1.x, `tf.metrics` was a module of functions, not classes. Each function returned a tuple of `(value_tensor, update_op)` that had to be run in a `tf.Session`, and the metric state lived in local variables that had to be initialized explicitly. The function-and-tuple style was clean inside static graphs but does not fit eager execution: the functions raise an error when eager execution is enabled.

The TF 2 migration moved the entire surface to `tf.keras.metrics`, where each metric is a Python object that owns its state. The TF 1.x functions still exist under `tf.compat.v1.metrics` for older code; using them in new projects is discouraged because they do not integrate with `model.fit()`, eager mode debugging or the modern distribution strategies.
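For contrast, a sketch of the TF 1.x calling convention run through `tf.compat.v1` inside an explicit graph, since these functions refuse to run eagerly. The labels and predictions are made up; note the `(value, update_op)` pair and the manual initialization of local variables.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    labels = tf.constant([1, 0, 1, 1])
    predictions = tf.constant([1, 1, 1, 0])
    # TF 1.x style: a value tensor plus an op that updates hidden local variables.
    accuracy, update_op = tf.compat.v1.metrics.accuracy(labels, predictions)
    init = tf.compat.v1.local_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)              # state variables must be initialized explicitly
    sess.run(update_op)         # accumulate one "batch"
    value = sess.run(accuracy)  # 2 of 4 predictions match -> 0.5
```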

## Keras 3 and the multi-backend story

With the release of Keras 3 (default in TensorFlow 2.16, March 2024), the same metric classes can run on top of TensorFlow, JAX or PyTorch. The API surface is unchanged for users; the implementation now dispatches through `keras.ops` instead of calling TensorFlow primitives directly. Code written against `tf.keras.metrics` continues to work, and the same metric subclass can run unmodified under any of the three supported backends as long as the `update_state` body avoids backend-specific calls.

## PyTorch equivalent: TorchMetrics

PyTorch does not ship a metrics library in its core distribution. The de-facto standard is [TorchMetrics](/wiki/torchmetrics), maintained by the PyTorch Lightning team. Its design mirrors the Keras one closely: a `Metric` base class with state variables registered via `add_state`, an `update()` method for per-batch accumulation, a `compute()` method for the final value, and a `reset()` method to zero the state between epochs. The library ships more than 100 metric implementations covering classification, regression, retrieval, image quality (SSIM, PSNR, FID, LPIPS) and natural language tasks (BLEU, ROUGE, BERTScore, Word Error Rate), plus a functional interface for one-shot stateless computation.

The biggest practical difference is distributed-training support. TorchMetrics has explicit knobs (`sync_on_compute`, `dist_sync_fn`, `process_group`) for how state should be merged across processes, and uses `torch.distributed.all_gather` by default. In Keras, cross-replica aggregation of metric state is handled automatically by the distribution strategy, which gives the user less direct control.

## Explain like I'm 5 (ELI5)

Imagine you are playing a game where you need to guess the color of hidden cards. To know how good you are at guessing, you need a way to measure your performance. In machine learning, we have something similar called metrics that help us measure how well our computer programs (models) are doing their jobs.

The Metrics API in TensorFlow is like a toolbox for those measurements. It has a small machine inside each tool that keeps a running tally. Every time you make a guess, it adds the result to the tally instead of forgetting it. At the end of the round it tells you the score for everything you guessed so far, then wipes the tally clean for the next round.

There are different tools for different games. One tool counts how many guesses you got exactly right, another rewards you for being close, and another one checks whether the right answer was at least in your top three guesses. You can also build your own tool if none of the ready-made ones fits.

## References

- [tf.keras.metrics module reference (TensorFlow v2.16.1)](https://www.tensorflow.org/api_docs/python/tf/keras/metrics)
- [tf.keras.metrics.Metric base class](https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Metric)
- [tf.keras.metrics.AUC](https://www.tensorflow.org/api_docs/python/tf/keras/metrics/AUC)
- [tf.keras.metrics.Precision](https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision)
- [tf.keras.metrics.MeanIoU](https://www.tensorflow.org/api_docs/python/tf/keras/metrics/MeanIoU)
- [Keras 3 metrics API documentation](https://keras.io/api/metrics/)
- [Keras accuracy metrics guide](https://keras.io/api/metrics/accuracy_metrics/)
- [Training with built-in methods (TensorFlow guide)](https://www.tensorflow.org/guide/keras/training_with_built_in_methods)
- [Migrate metrics and optimizers (TF1 to TF2)](https://www.tensorflow.org/guide/migrate/metrics_optimizers)
- [TorchMetrics documentation](https://lightning.ai/docs/torchmetrics/stable/)

