# tf.keras

> Source: https://aiwiki.ai/wiki/tf_keras
> Updated: 2026-06-27
> Categories: Deep Learning, Developer Tools, Machine Learning, Neural Networks, Programming Languages
> From AI Wiki (https://aiwiki.ai), a free encyclopedia of artificial intelligence. Quote with attribution.

**tf.keras** is the high-level [deep learning](/wiki/deep_learning) API built directly into the [TensorFlow](/wiki/tensorflow) [machine learning](/wiki/machine_learning) framework, and it has been TensorFlow's official and recommended high-level API since the release of TensorFlow 2.0 on September 30, 2019. [1][6] It is TensorFlow's packaging of [Keras](/wiki/keras), the user-facing [neural network](/wiki/neural_network) API originally created by Francois Chollet and first released on March 27, 2015. [2][13] The TensorFlow documentation describes it plainly: "Keras is the high-level API of the TensorFlow platform. It provides an approachable, highly-productive interface for solving machine learning (ML) problems, with a focus on modern deep learning." [6]

tf.keras lets developers build, train, evaluate, and deploy deep learning models using a small set of composable objects (layers, models, optimizers, loss functions, metrics, and callbacks) instead of the verbose graph-and-session code that earlier versions of TensorFlow required. [5] It emphasizes rapid prototyping, modularity, and extensibility, and supports a range of model-building paradigms, from simple sequential stacks of layers to highly customized architectures defined with arbitrary Python code. The API handles the full lifecycle of a project: data preprocessing, model construction, training with built-in or custom loops, evaluation, prediction, model saving, and deployment. Since TensorFlow 2.16 (March 2024), tf.keras points to the multi-backend Keras 3 by default, which means the same Keras code can also run on [JAX](/wiki/jax) and [PyTorch](/wiki/pytorch). [13][14]

## Explain like I'm 5 (ELI5)

Imagine you want to build a robot that can recognize pictures of cats and dogs. Building that robot from scratch would be really hard because you would need to know how every tiny part works. tf.keras is like a box of special building blocks that snap together easily. You pick the blocks you need, connect them in a line or a pattern, show the robot lots of pictures so it can learn, and then ask it to guess what new pictures are. The building blocks handle all the complicated math inside so you can focus on what you want your robot to learn.

## What is tf.keras?

tf.keras is the implementation of the Keras API that ships inside the TensorFlow package, importable as `tf.keras` (or `from tensorflow import keras`). When Google announced TensorFlow 2.0 in 2019, it designated Keras as "the official high-level API of TensorFlow" and made tf.keras the front door through which most users interact with the framework. [5] Before that, building a model in TensorFlow typically meant manually constructing a static computation graph and running it inside a session. tf.keras replaced that workflow with an object-oriented, [Python](/wiki/python)-native interface in which layers are objects you stack or connect, and a model is trained with a single call to `model.fit()`. [8]

Keras itself predates TensorFlow. Francois Chollet released it in March 2015 as a high-level library that originally ran on the Theano backend, with the tagline "Deep Learning for humans." [2][13] tf.keras is therefore best understood not as a separate library but as TensorFlow's tightly integrated, fully supported edition of Keras, with deep hooks into TensorFlow features such as [eager execution](/wiki/eager_execution), `tf.data` input pipelines, `tf.distribute` distributed training, TPUs, and the SavedModel export format used by TensorFlow Serving, TensorFlow Lite, and TensorFlow.js. [6]

## What is the difference between tf.keras and Keras?

The relationship between "Keras" and "tf.keras" has changed over time, and the distinction depends on which era you are talking about.

- **Standalone Keras (2015-2019)** was a separate `pip install keras` package that ran on top of an interchangeable backend: first Theano, then TensorFlow (added after Google open-sourced TensorFlow in November 2015), and later the Microsoft Cognitive Toolkit (CNTK). [2][13]
- **tf.keras (2017 onward)** is the copy of the Keras API bundled inside TensorFlow itself. It was introduced around TensorFlow 1.4 in 2017 and promoted to the official high-level API in TensorFlow 2.0. [1][5] During the TensorFlow 2.x era (Keras 2.x), tf.keras supported only the TensorFlow backend and added TensorFlow-specific capabilities such as eager execution, distribution strategies, and TPU training. tf.keras in TensorFlow 2.0 implemented the Keras 2.3.0 API, so for most code, switching from standalone Keras to tf.keras was as simple as changing the import statement.
- **Keras 3 (2023 onward)** reunified the two. Keras 3 is once again a multi-backend framework, and from TensorFlow 2.16 (March 2024) the name `tf.keras` resolves to Keras 3. [13][14] The practical effect is that tf.keras and standalone `keras` now refer to the same modern codebase; the main difference is whether you import it through TensorFlow or install the standalone `keras` package directly.

In short: tf.keras is TensorFlow's built-in distribution of Keras. The biggest functional difference from plain, low-level TensorFlow is the level of abstraction. tf.keras gives you ready-made models, layers, and a `fit`/`evaluate`/`predict` training loop, whereas raw TensorFlow gives you the underlying tensors, automatic differentiation (`tf.GradientTape`), and operations that those high-level objects are built on. The two are designed to interoperate: you can drop low-level TensorFlow ops into a custom Keras layer or training step whenever the high-level API is not enough. [9]

## When did tf.keras become the official TensorFlow API?

The relationship between Keras and TensorFlow has evolved through several distinct phases.

| Period | Milestone |
|---|---|
| March 27, 2015 | Francois Chollet releases Keras as an open-source library supporting Theano as its backend [2] |
| November 2015 | Google open-sources TensorFlow; Keras is refactored to support TensorFlow as a backend [1] |
| 2017 | Keras gains support for the Microsoft Cognitive Toolkit (CNTK) backend |
| 2017 | Keras is integrated into TensorFlow as `tf.keras` (around TensorFlow 1.4), available alongside the standalone package [1] |
| September 30, 2019 | TensorFlow 2.0 launches with `tf.keras` as its official and recommended high-level API [5] |
| 2020-2023 | Keras versions 2.4 through 2.15 support only the TensorFlow backend |
| November 28, 2023 | Keras 3.0 launches as a multi-backend framework supporting TensorFlow, [JAX](/wiki/jax), and [PyTorch](/wiki/pytorch) [13] |
| March 2024 | TensorFlow 2.16 ships with Keras 3 as the default `tf.keras`; the legacy Keras 2 is published as the separate `tf-keras` PyPI package [14] |

Before TensorFlow 2.0, most TensorFlow code required verbose, low-level session management and graph construction. The adoption of Keras as the standard interface simplified the developer experience and lowered the barrier to entry for newcomers to deep learning. In December 2018, ahead of the 2.0 release, the TensorFlow team announced it would be "standardizing on Keras" as the framework's high-level API. [5]

## What are the core design principles of tf.keras?

tf.keras is built around four guiding principles that shape its API design. Chollet summarized the overarching philosophy as "progressive disclosure of complexity," meaning users can start with a few lines of high-level code and gradually drop down to lower-level control only when they need it. [13]

**User-friendliness.** The API minimizes the number of steps required to go from idea to working model. Error messages are clear, and the documentation provides extensive examples for common workflows.

**Modularity.** Models are built by connecting interchangeable components (layers, optimizers, loss functions, metrics, and callbacks). Each component has a consistent interface, making it straightforward to swap one for another.

**Extensibility.** When built-in components are insufficient, users can create custom layers, loss functions, metrics, callbacks, and even training loops by subclassing base classes and overriding specific methods.

**Python-native.** Unlike earlier TensorFlow APIs that relied on declarative graph definitions, tf.keras uses standard Python control flow. Combined with TensorFlow 2's default [eager execution](/wiki/eager_execution), this makes debugging intuitive because operations execute immediately and return concrete values. [6]

## How do you build a model with tf.keras?

tf.keras offers three progressively more flexible ways to define a model. [7]

### Sequential API

The Sequential API is the simplest approach. It creates a model as a linear stack of layers where each layer has exactly one input and one output. It is best suited for straightforward feedforward networks.

```python
import tensorflow as tf

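model = tf.keras.Sequential([
    # Declare the input shape with an explicit Input object; Keras 3 warns
    # when `input_shape` is passed to a layer instead.
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])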
```

The Sequential API cannot represent models that share layers, have multiple inputs or outputs, or contain branches.

### Functional API

The Functional API treats layers as functions that operate on tensor objects. This allows construction of models with non-linear topologies, shared layers, and multiple inputs or outputs. [7]

```python
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(128, activation='relu')(inputs)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
```

The Functional API is the most commonly used approach in practice because it balances flexibility with clarity. The resulting model object retains full knowledge of its graph topology, which enables features like model visualization, feature extraction from intermediate layers, and transfer learning.
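
As a minimal sketch of the feature-extraction use mentioned above, a second model can reuse the same input tensor and expose an intermediate layer's output. The layer name `hidden` and the random input batch are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf

# Build a small Functional model and keep a handle on the hidden tensor
inputs = tf.keras.Input(shape=(784,))
hidden = tf.keras.layers.Dense(128, activation='relu', name='hidden')(inputs)
outputs = tf.keras.layers.Dense(10, activation='softmax')(hidden)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# A second model that shares the same layers but stops at the hidden layer
feature_extractor = tf.keras.Model(inputs=inputs, outputs=hidden)

batch = np.random.rand(4, 784).astype('float32')  # illustrative random data
features = feature_extractor(batch)               # one 128-d vector per sample
```

Both models share the same layer objects, so training `model` also changes what `feature_extractor` returns.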

### Model subclassing

Model subclassing provides maximum flexibility by letting developers define the forward pass using arbitrary Python code. The user subclasses `tf.keras.Model` and implements the `__init__` method (to define layers) and the `call` method (to specify the forward computation).

```python
class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dropout = tf.keras.layers.Dropout(0.2)
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        return self.dense2(x)
```

This approach is popular in research settings where non-standard architectures or dynamic computation graphs are needed. The tradeoff is that subclassed models sacrifice some of the introspection features available to Sequential and Functional models (for example, the ability to serialize the model architecture to JSON).

### Comparison of model-building APIs

| Feature | Sequential | Functional | Subclassing |
|---|---|---|---|
| Multiple inputs/outputs | No | Yes | Yes |
| Shared layers | No | Yes | Yes |
| Non-linear topology | No | Yes | Yes |
| Dynamic control flow in forward pass | No | No | Yes |
| Model graph introspection | Yes | Yes | Limited |
| Serialization to JSON | Yes | Yes | No |
| Ease of use | Highest | High | Moderate |
| Typical use case | Simple feedforward nets | Most production models | Research prototypes |

## What layers does tf.keras provide?

Layers are the fundamental building blocks of tf.keras models. Each layer receives input tensors, applies a transformation, and produces output tensors. tf.keras provides a large library of built-in layers organized by category.

| Category | Example layers | Typical use |
|---|---|---|
| Core | [Dense](/wiki/dense_layer), Activation, Embedding, Masking | General-purpose transformations |
| [Convolutional](/wiki/convolutional_neural_network) | Conv1D, Conv2D, Conv3D, SeparableConv2D, DepthwiseConv2D | Image and signal processing |
| Pooling | MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D | Spatial downsampling |
| [Recurrent](/wiki/recurrent_neural_network) | SimpleRNN, [LSTM](/wiki/long_short-term_memory_lstm), [GRU](/wiki/recurrent_neural_network), Bidirectional | Sequential and time-series data |
| Normalization | [BatchNormalization](/wiki/batch_normalization), LayerNormalization, GroupNormalization | Stabilizing and accelerating training |
| Regularization | [Dropout](/wiki/dropout), SpatialDropout1D, GaussianNoise | Reducing [overfitting](/wiki/overfitting) |
| [Attention](/wiki/attention) | MultiHeadAttention, Attention, AdditiveAttention | [Transformer](/wiki/transformer)-style models |
| Reshaping | Flatten, Reshape, Permute, RepeatVector | Tensor shape manipulation |
| Merging | Concatenate, Add, Multiply, Average | Combining multiple inputs |
| Preprocessing | TextVectorization, Normalization, Discretization, CategoryEncoding, Resizing, Rescaling | Data preprocessing within the model |

Users can define custom layers by subclassing `tf.keras.layers.Layer` and implementing the `build` and `call` methods. Custom layers can maintain trainable weights, support serialization, and integrate seamlessly with any model-building API.
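
The following sketch shows one way such a custom layer might look. The class name `SimpleDense` and its `units` argument are hypothetical, chosen only for illustration; the layer re-implements a plain affine transformation rather than anything new.

```python
import tensorflow as tf

class SimpleDense(tf.keras.layers.Layer):
    """Hypothetical example layer computing inputs @ kernel + bias."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Weights are created lazily, once the input dimension is known
        self.kernel = self.add_weight(
            name='kernel', shape=(input_shape[-1], self.units),
            initializer='glorot_uniform', trainable=True)
        self.bias = self.add_weight(
            name='bias', shape=(self.units,),
            initializer='zeros', trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel) + self.bias

    def get_config(self):
        # Expose constructor arguments so the layer can be serialized
        config = super().get_config()
        config.update({'units': self.units})
        return config

layer = SimpleDense(4)
y = layer(tf.ones((2, 3)))  # the first call triggers build()
```

Creating the weights in `build` rather than `__init__` lets the layer infer its input dimension from the first batch it sees.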

## How do you train a model in tf.keras?

The standard tf.keras training workflow follows a compile-fit-evaluate-predict pattern. [8]

### Compile

The `compile` method configures the model for training by specifying the [optimizer](/wiki/optimizer), [loss function](/wiki/loss_function), and metrics.

```python
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
```

### Fit

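The `fit` method trains the model for a specified number of [epochs](/wiki/epoch). It accepts NumPy arrays, `tf.data.Dataset` objects, or Python generators as input, and returns a `History` object containing per-epoch training and validation metrics. The `validation_split` argument shown here applies only to array inputs; with a dataset or generator, pass `validation_data` instead.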

```python
history = model.fit(
    x_train, y_train,
    epochs=10,
    batch_size=32,
    validation_split=0.2
)
```

### Evaluate

The `evaluate` method computes loss and metric values on a held-out test set.

```python
test_loss, test_acc = model.evaluate(x_test, y_test)
```

### Predict

The `predict` method generates output predictions for new input samples, processing them in batches.

```python
predictions = model.predict(x_new)
```
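
The four steps can be run end to end on synthetic data, as in the sketch below. The array shapes, layer sizes, and epoch count are arbitrary assumptions chosen to keep the example fast, and the random labels mean the resulting accuracy is not meaningful.

```python
import numpy as np
import tensorflow as tf

# Synthetic data: 256 samples, 20 features, 3 classes (illustrative only)
rng = np.random.default_rng(0)
x = rng.random((256, 20)).astype('float32')
y = rng.integers(0, 3, size=(256,))
x_train, y_train, x_test, y_test = x[:200], y[:200], x[200:], y[200:]

model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax'),
])

# Compile, fit, evaluate, predict
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
history = model.fit(x_train, y_train, epochs=2, batch_size=32,
                    validation_split=0.2, verbose=0)
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
probabilities = model.predict(x_test, verbose=0)  # one row of class scores per sample
```

`history.history` is a dictionary keyed by metric name (for example `loss` and `val_loss`), with one value per epoch.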

## How do custom training loops work with GradientTape?

For cases that require full control over the training process, tf.keras supports custom training loops using `tf.GradientTape`. This is useful for implementing advanced techniques such as custom [gradient](/wiki/gradient_descent) clipping, multi-optimizer training, or adversarial training. [9]

```python
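# Assumes `model`, `train_dataset` (a tf.data.Dataset yielding (x, y) batches),
# and `num_epochs` are already defined.
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

for epoch in range(num_epochs):
    for x_batch, y_batch in train_dataset:
        # Record the forward pass so gradients can be computed afterwards
        with tf.GradientTape() as tape:
            predictions = model(x_batch, training=True)
            loss = loss_fn(y_batch, predictions)
        # Backpropagate, then let the optimizer update the weights
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))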
```

When the body of the inner loop is factored out into a training step function, the `@tf.function` decorator can be applied to that function to compile it into a TensorFlow graph for improved performance, while still allowing development and debugging in eager mode.
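
A minimal sketch of such a compiled training step follows. The tiny model, the random data, and the function name `train_step` are assumptions made for illustration; the model outputs raw logits here, so the loss is created with `from_logits=True`.

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(3),  # raw logits, no softmax
])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function  # traced into a graph on the first call, reused afterwards
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        logits = model(x_batch, training=True)
        loss = loss_fn(y_batch, logits)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

# Illustrative random data: 64 samples, 8 features, 3 classes
x = tf.random.normal((64, 8))
y = tf.random.uniform((64,), maxval=3, dtype=tf.int32)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

for epoch in range(2):
    for x_batch, y_batch in train_dataset:
        loss = train_step(x_batch, y_batch)
```

Removing the decorator runs the same function eagerly, which is often convenient while debugging.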

## What optimizers does tf.keras include?

tf.keras includes a suite of gradient-based [optimizers](/wiki/optimizer) for updating model weights during training.

| Optimizer | Description | Typical use case |
|---|---|---|
| [SGD](/wiki/stochastic_gradient_descent_sgd) | Classic stochastic gradient descent with optional momentum and Nesterov acceleration | Baseline experiments, fine-tuning |
| [Adam](/wiki/adam_optimizer) | Adaptive learning rates using first and second moment estimates of gradients | General-purpose default optimizer |
| RMSprop | Divides learning rate by a moving average of squared gradients | Recurrent networks |
| AdaGrad | Adapts learning rate per parameter based on historical gradient magnitudes | Sparse data problems |
| AdamW | Adam with decoupled [weight decay](/wiki/weight_decay) regularization | Transformer-based models |
| Nadam | Adam with Nesterov momentum | Tasks benefiting from look-ahead gradients |
| Adamax | Variant of Adam using the infinity norm | Problems with very sparse gradients |

All optimizers accept a `learning_rate` parameter that can be a static float or a `tf.keras.optimizers.schedules.LearningRateSchedule` instance (such as `ExponentialDecay`, `CosineDecay`, or `PolynomialDecay`) for automated learning rate adjustment during training.
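
As a small illustration, the sketch below attaches an `ExponentialDecay` schedule to an SGD optimizer. The specific numbers (initial rate 0.1, halving every 100 steps) are arbitrary assumptions.

```python
import tensorflow as tf

# learning_rate(step) = 0.1 * 0.5 ** (step / 100)
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1,
    decay_steps=100,
    decay_rate=0.5,
)
optimizer = tf.keras.optimizers.SGD(learning_rate=schedule, momentum=0.9)

# A schedule is callable, so it can be inspected at any step
lr_at_start = float(schedule(0))     # approximately 0.1
lr_at_100 = float(schedule(100))     # approximately 0.05
```

The optimizer evaluates the schedule at its own step counter, so no extra code is needed in the training loop.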

## What loss functions are available in tf.keras?

tf.keras provides loss functions for common supervised and unsupervised learning tasks.

| Loss function | Task type | Description |
|---|---|---|
| MeanSquaredError | [Regression](/wiki/regression) | Average of squared differences between predictions and targets |
| MeanAbsoluteError | Regression | Average of absolute differences |
| BinaryCrossentropy | Binary [classification](/wiki/classification) | Log loss for two-class problems |
| CategoricalCrossentropy | Multi-class classification | Log loss for one-hot encoded labels |
| SparseCategoricalCrossentropy | Multi-class classification | Log loss for integer-encoded labels |
| Huber | Regression | Combines MSE and MAE; less sensitive to outliers |
| KLDivergence | Distribution matching | Measures divergence between two probability distributions |
| CosineSimilarity | Embedding learning | Negative cosine similarity between prediction and target vectors (values closer to -1 indicate greater similarity) |

Custom loss functions can be implemented as plain Python functions or by subclassing `tf.keras.losses.Loss`.
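
Both styles are sketched below. The names `per_sample_mse` and `WeightedMSE`, and the `weight` hyperparameter, are hypothetical and exist only for this example.

```python
import tensorflow as tf

# Style 1: a plain function returning one loss value per sample
def per_sample_mse(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

# Style 2: a Loss subclass, convenient when the loss has hyperparameters
class WeightedMSE(tf.keras.losses.Loss):
    def __init__(self, weight=2.0, name='weighted_mse'):
        super().__init__(name=name)
        self.weight = weight

    def call(self, y_true, y_pred):
        return self.weight * tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

y_true = tf.constant([[0.0, 1.0], [1.0, 0.0]])
y_pred = tf.constant([[0.0, 0.0], [1.0, 1.0]])

plain = float(tf.reduce_mean(per_sample_mse(y_true, y_pred)))  # 0.5
weighted = float(WeightedMSE(weight=2.0)(y_true, y_pred))      # 1.0
```

Either object can be passed as the `loss` argument of `model.compile()`.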

## What are tf.keras callbacks?

[Callbacks](/wiki/callback) are objects that execute actions at specific points during training (at the start or end of an epoch, before or after a batch, and so on). tf.keras includes several built-in callbacks.

| Callback | Purpose |
|---|---|
| ModelCheckpoint | Saves the model (or just its weights) periodically or when a monitored metric improves |
| EarlyStopping | Stops training when a monitored metric has not improved for a specified number of epochs (patience) |
| ReduceLROnPlateau | Reduces the learning rate by a factor when a metric has stopped improving |
| TensorBoard | Logs metrics, histograms, and graphs for visualization in TensorBoard |
| LearningRateScheduler | Adjusts the learning rate at the start of each epoch using a user-defined schedule function |
| CSVLogger | Streams epoch results to a CSV file |
| TerminateOnNaN | Stops training when a NaN loss is encountered |
| LambdaCallback | Wraps arbitrary functions for use as callbacks |

Multiple callbacks can be combined and passed to `model.fit()` as a list. Users can also create custom callbacks by subclassing `tf.keras.callbacks.Callback`.
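
The sketch below combines two built-in callbacks with a custom one. `EpochCounter` is a hypothetical class written for this example, and the data, model, and patience values are arbitrary assumptions.

```python
import numpy as np
import tensorflow as tf

class EpochCounter(tf.keras.callbacks.Callback):
    """Hypothetical custom callback that counts completed epochs."""

    def __init__(self):
        super().__init__()
        self.epochs_seen = 0

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_seen += 1

# Illustrative regression data
rng = np.random.default_rng(0)
x = rng.random((128, 8)).astype('float32')
y = rng.random((128, 1)).astype('float32')

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

counter = EpochCounter()
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=1),
    counter,
]
history = model.fit(x, y, epochs=3, validation_split=0.25,
                    callbacks=callbacks, verbose=0)
```

Callbacks that monitor `val_loss` need validation data, which is why `validation_split` is set here.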

## How does transfer learning work with tf.keras applications?

tf.keras provides a collection of pretrained [computer vision](/wiki/computer_vision) models through `tf.keras.applications`. These models have been trained on the [ImageNet](/wiki/imagenet) dataset and can be used for image classification, [feature extraction](/wiki/feature_extraction), or [fine-tuning](/wiki/fine-tuning) on custom datasets.

| Model family | Example variants | Parameters (approx.) | Top-1 accuracy on ImageNet |
|---|---|---|---|
| [VGG](/wiki/vgg) | VGG16, VGG19 | 138M (VGG16) | 71.3% (VGG16) |
| [ResNet](/wiki/resnet) | ResNet50, ResNet101, ResNet152 | 25.6M (ResNet50) | 74.9% (ResNet50) |
| InceptionV3 | InceptionV3 | 23.9M | 77.9% |
| [EfficientNet](/wiki/efficientnet) | EfficientNetB0 through B7 | 5.3M (B0) to 66M (B7) | 77.1% (B0) to 84.3% (B7) |
| MobileNet | MobileNetV2, MobileNetV3 | 3.4M (V2) | 71.3% (V2) |
| DenseNet | DenseNet121, DenseNet169, DenseNet201 | 8.1M (DenseNet121) | 75.0% (DenseNet121) |
| NASNet | NASNetMobile, NASNetLarge | 5.3M (Mobile) | 74.4% (Mobile) |

A typical transfer learning workflow involves loading a pretrained model without its classification head (by setting `include_top=False`), freezing the pretrained layers, adding new trainable layers for the target task, training the new layers, and optionally unfreezing some pretrained layers for fine-tuning.

```python
# `num_classes` is the number of target classes and must be defined beforehand.
base_model = tf.keras.applications.EfficientNetB0(
    weights='imagenet', include_top=False, input_shape=(224, 224, 3)
)
base_model.trainable = False  # freeze the pretrained weights

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation='relu'),  # new trainable head
    tf.keras.layers.Dense(num_classes, activation='softmax')
])
```

## How does tf.keras handle data input pipelines?

tf.keras integrates closely with the `tf.data.Dataset` API for building efficient input pipelines. Using `tf.data`, developers can load data from files, apply transformations, shuffle, batch, and prefetch data for optimal hardware utilization.

```python
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(buffer_size=1024).batch(32).prefetch(tf.data.AUTOTUNE)
model.fit(dataset, epochs=10)
```
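
A slightly fuller sketch adds a `map` transformation before shuffling and batching. The random data and the `scale` function are assumptions for illustration only.

```python
import numpy as np
import tensorflow as tf

# Illustrative data: 100 samples with 8 raw features in [0, 255]
x_train = np.random.randint(0, 256, size=(100, 8)).astype('float32')
y_train = np.random.randint(0, 2, size=(100,))

def scale(features, label):
    # Per-element transformation applied lazily as data is read
    return features / 255.0, label

dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(scale, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(buffer_size=100)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)  # overlap data preparation with training
)

first_features, first_labels = next(iter(dataset))
```

Placing `prefetch` last lets the pipeline prepare the next batch while the current one is being consumed.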

### Preprocessing layers

tf.keras includes built-in preprocessing layers that can be embedded directly into the model graph. This ensures that preprocessing logic is packaged with the model at serving time.

| Layer | Function |
|---|---|
| TextVectorization | Converts raw strings into integer token sequences |
| Normalization | Applies feature-wise normalization (zero mean, unit variance) |
| Discretization | Bins continuous features into discrete integer ranges |
| CategoryEncoding | Converts integer categories into one-hot or multi-hot representations |
| Resizing | Resizes images to a target height and width |
| Rescaling | Scales pixel values (for example, from [0, 255] to [0, 1]) |
| RandomFlip | Randomly flips images horizontally or vertically during training |
| RandomRotation | Applies random rotation augmentation during training |
| RandomZoom | Applies random zoom augmentation during training |

Image augmentation layers are only active during training and are automatically bypassed during inference.
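
The sketch below illustrates two of these layers in isolation: a `Normalization` layer that learns its statistics with `adapt`, and a `RandomFlip` layer called with `training=False`. The small arrays are made-up example data.

```python
import numpy as np
import tensorflow as tf

# Normalization learns per-feature mean and variance from sample data
data = np.array([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]], dtype='float32')
normalizer = tf.keras.layers.Normalization(axis=-1)
normalizer.adapt(data)
normalized = np.asarray(normalizer(data))  # each column now has mean ~0, std ~1

# Augmentation layers pass inputs through unchanged outside of training
flip = tf.keras.layers.RandomFlip('horizontal')
images = tf.random.uniform((2, 8, 8, 3))
at_inference = flip(images, training=False)
```

Inside `model.fit()` the `training` flag is set automatically, so it rarely needs to be passed by hand.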

## How do you save and serialize a tf.keras model?

tf.keras supports multiple formats for saving and loading models. [12]

| Format | File extension | Contents | Notes |
|---|---|---|---|
| Keras native | `.keras` | Architecture, weights, optimizer state, training config | Recommended format for Keras 3; supports all model types |
| SavedModel | directory | Execution graph, weights, signatures | Used by TensorFlow Serving, TensorFlow Lite, and TensorFlow.js |
| HDF5 | `.h5` | Architecture, weights, optimizer state | Legacy format; does not support custom objects without extra configuration |
| Weights only | `.weights.h5` | Weights only | Useful when the architecture is defined in code |

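The SavedModel format saves the execution graph, which means subclassed models can be loaded without access to the original source code. The HDF5 format uses the model configuration (architecture as JSON) and therefore requires the same class definitions to be available at load time. Under Keras 3, `model.save()` writes only the `.keras` and legacy `.h5` formats, and a SavedModel directory is produced with `model.export()` instead.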

Models saved in the SavedModel format can be converted for deployment on mobile and edge devices using TensorFlow Lite (`tf.lite.TFLiteConverter`) or in web browsers using TensorFlow.js.
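
As a minimal sketch, the following round-trips a small model through the `.keras` format and through a weights-only file. The helper `build_model`, the file names, and the temporary directory are assumptions made for the example.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def build_model():
    # Hypothetical helper so the same architecture can be rebuilt from code
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(2),
    ])

model = build_model()
x = np.ones((1, 4), dtype='float32')
expected = model.predict(x, verbose=0)

with tempfile.TemporaryDirectory() as tmp:
    # Full model: architecture and weights in a single .keras file
    full_path = os.path.join(tmp, 'model.keras')
    model.save(full_path)
    restored = tf.keras.models.load_model(full_path)
    same_full = np.allclose(expected, restored.predict(x, verbose=0))

    # Weights only: the architecture must be rebuilt in code before loading
    weights_path = os.path.join(tmp, 'model.weights.h5')
    model.save_weights(weights_path)
    fresh = build_model()
    fresh.load_weights(weights_path)
    same_weights = np.allclose(expected, fresh.predict(x, verbose=0))
```

In both cases the reloaded model should reproduce the original predictions.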

## How does distributed training work in tf.keras?

tf.keras integrates with TensorFlow's distribution strategies (`tf.distribute`) to enable training across multiple GPUs, multiple machines, and TPUs with minimal code changes. [10]

| Strategy | Hardware | Approach |
|---|---|---|
| MirroredStrategy | Multiple GPUs on one machine | Synchronous training with all-reduce gradient aggregation |
| MultiWorkerMirroredStrategy | Multiple GPUs across multiple machines | Synchronous training with all-reduce across a network |
| TPUStrategy | Google TPU pods | Synchronous training optimized for TPU interconnects |
| ParameterServerStrategy | CPU/GPU cluster with parameter servers | Asynchronous or synchronous training with centralized variable storage |
| CentralStorageStrategy | Single machine with multiple GPUs | Variables stored on CPU, computations replicated on GPUs |

Using a distribution strategy typically requires only wrapping the model creation and compilation inside a `strategy.scope()` block.

```python
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([...])
    model.compile(...)
model.fit(dataset, epochs=10)
```

## How does mixed precision training work in tf.keras?

tf.keras supports [mixed precision](/wiki/mixed_precision_training) training through `tf.keras.mixed_precision`. Mixed precision uses 16-bit floating-point (float16 or bfloat16) for most operations while keeping critical computations in 32-bit float, which can deliver significant performance improvements on supported hardware. [11]

On NVIDIA GPUs with Tensor Cores (compute capability 7.0 and above, such as the V100, A100, and RTX series), mixed precision can improve training throughput by over 3x. On Google TPUs, using bfloat16 mixed precision can provide up to 60% throughput improvement. [11]

Enabling mixed precision in tf.keras requires a single line of code at the start of the program (on TPUs, the policy name is `'mixed_bfloat16'` instead):

```python
tf.keras.mixed_precision.set_global_policy('mixed_float16')
```

For custom training loops, a `LossScaleOptimizer` handles dynamic loss scaling to prevent underflow in float16 gradients.
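
To make the split between compute and variable dtypes concrete, the sketch below applies the policy to a single layer through its `dtype` argument rather than globally, which is an assumption made to keep the example self-contained. It also constructs a `LossScaleOptimizer`; the methods for scaling the loss inside a custom loop vary between Keras versions and are not shown.

```python
import tensorflow as tf

# A per-layer policy: computations in float16, variables kept in float32
layer = tf.keras.layers.Dense(4, dtype='mixed_float16')
outputs = layer(tf.ones((2, 3)))

# Final activations are commonly kept in float32 for numerical stability
probabilities = tf.keras.layers.Activation('softmax', dtype='float32')(outputs)

# Wraps an optimizer so that loss scaling can be applied in custom loops
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam())

compute_dtype = layer.compute_dtype    # 'float16'
variable_dtype = layer.variable_dtype  # 'float32'
```

With `model.fit()`, loss scaling is handled automatically once a `mixed_float16` policy is active, so the wrapper matters mainly for custom loops.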

## What is the relationship between tf.keras and Keras 3?

With the release of Keras 3.0 on November 28, 2023, Keras became a multi-backend framework again, capable of running on TensorFlow, JAX, and PyTorch. [13] The Keras team described it as a full rewrite: "Keras 3 is a full rewrite of Keras that enables you to run your Keras workflows on top of either JAX, TensorFlow, PyTorch, or OpenVINO (for inference-only), and that unlocks brand new large-scale model training and deployment capabilities." [13] Starting with TensorFlow 2.16 (released March 2024), `tf.keras` points to Keras 3 by default. [14]

Keras 3 introduced several changes relevant to tf.keras users.

- **Multi-backend support.** Models written with Keras 3 can run on any supported backend without code changes. The backend is selected via the `KERAS_BACKEND` environment variable or `keras.config.set_backend()`. As the Keras team put it, "You can pick the framework that suits you best, and switch from one to another based on your current goals." [13]
- **Performance improvements.** Keras 3 benchmarks show consistent improvements over Keras 2, with some workloads seeing throughput gains exceeding 100%. [13]
- **OpenVINO backend.** In addition to TensorFlow, JAX, and PyTorch, Keras 3 added experimental support for Intel's OpenVINO as an inference backend. [13]
- **Legacy compatibility.** Users who need Keras 2 behavior can install the `tf-keras` package from PyPI and set the environment variable `TF_USE_LEGACY_KERAS=1`, which makes `tf.keras` resolve to the older Keras 2 codebase. [14]

The transition means that `tf.keras` is no longer a TensorFlow-only API but part of a broader cross-framework ecosystem, though TensorFlow remains a first-class supported backend.
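
Backend selection through the environment variable can be sketched as follows, assuming the `keras` package (version 3) and the chosen backend are both installed. The variable has to be set before Keras is first imported.

```python
import os

# Must be set before the first `import keras`; 'jax' and 'torch' are the
# other values, provided those frameworks are installed.
os.environ['KERAS_BACKEND'] = 'tensorflow'

import keras

backend_name = keras.backend.backend()  # name of the active backend
```

Setting `TF_USE_LEGACY_KERAS=1` works the same way: it has to be in the environment before TensorFlow is imported.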

## What is tf.keras used for?

tf.keras is used across a wide range of deep learning applications.

| Domain | Typical models | tf.keras components used |
|---|---|---|
| [Image classification](/wiki/image_classification_models) | [CNNs](/wiki/convolutional_neural_network), EfficientNet, ResNet | Conv2D, MaxPooling2D, Dense, tf.keras.applications |
| [Object detection](/wiki/object_detection) | Feature pyramid networks, SSD | Functional API, custom layers |
| [Natural language processing](/wiki/natural_language_processing) | [Transformers](/wiki/transformer), text classifiers | Embedding, MultiHeadAttention, TextVectorization |
| [Time series forecasting](/wiki/time_series_analysis) | LSTM, GRU, temporal convolutional networks | LSTM, GRU, Conv1D, custom training loops |
| [Generative models](/wiki/generative_model) | [VAEs](/wiki/variational_autoencoder), [GANs](/wiki/generative_adversarial_network) | Model subclassing, custom training loops, GradientTape |
| [Recommendation systems](/wiki/recommender_system) | Collaborative filtering, two-tower models | Embedding, Functional API |
| Speech recognition | CTC-based models, conformers | LSTM, Conv1D, attention layers |

## How does tf.keras compare with PyTorch and JAX?

| Feature | tf.keras | PyTorch (torch.nn) | JAX + Flax/Haiku |
|---|---|---|---|
| Default execution | Eager + graph via @tf.function | Eager | Functional transforms (jit, grad) |
| Model definition | Sequential, Functional, Subclassing | Module subclassing | Functional or module-based |
| Built-in training loop | `model.fit()` | Manual loop (or third-party trainers) | Manual loop |
| Pretrained models | tf.keras.applications | torchvision.models | Third-party libraries |
| Distributed training | tf.distribute strategies | torch.distributed, FSDP | pjit, pmap |
| Deployment | TF Serving, TF Lite, TF.js | TorchServe, ONNX | TF via SavedModel export |
| Mixed precision | One-line global policy | torch.cuda.amp | jmp library |

Notably, because Keras 3 can run on the PyTorch and JAX backends, the boundary between these ecosystems has softened: a single Keras codebase can target TensorFlow, JAX, or PyTorch as its computational engine. [13]

## What are the limitations of tf.keras?

While tf.keras covers most deep learning workflows, it has some constraints worth noting.

- **Graph-mode debugging.** Although eager execution is the default, using `@tf.function` for performance can make debugging harder because Python control flow may behave differently inside traced functions.
- **Subclassed model serialization.** Subclassed models cannot be serialized to a JSON architecture description unless they implement `get_config()`, which can complicate model sharing and reproducibility.
- **Dynamic architectures.** Models requiring truly dynamic architectures (where the graph structure changes from sample to sample) are more naturally expressed in frameworks with fully imperative execution.
- **Backend migration.** Code that uses TensorFlow-specific operations inside Keras layers will not be portable to JAX or PyTorch backends in Keras 3 without modification. [13]

## See also

- [Keras](/wiki/keras)
- [TensorFlow](/wiki/tensorflow)
- [Estimator (tf.estimator)](/wiki/estimator_tf_estimator)

## References

1. Abadi, M. et al. (2015). "TensorFlow: Large-Scale Machine Learning on Heterogeneous Systems." Software available from tensorflow.org. https://arxiv.org/abs/1603.04467
2. Chollet, F. (2015). Keras. GitHub. https://github.com/keras-team/keras
3. Abadi, M. et al. (2016). "TensorFlow: A System for Large-Scale Machine Learning." *Proceedings of the 12th USENIX Symposium on Operating Systems Design and Implementation (OSDI '16)*. https://arxiv.org/abs/1605.08695
4. Chollet, F. (2017). *Deep Learning with Python*. Manning Publications.
5. TensorFlow Team. (2018). "Standardizing on Keras: Guidance on High-level APIs in TensorFlow 2.0." TensorFlow Blog. https://blog.tensorflow.org/2018/12/standardizing-on-keras-guidance.html
6. TensorFlow Documentation. "Keras: The high-level API for TensorFlow." https://www.tensorflow.org/guide/keras
7. TensorFlow Documentation. "The Functional API." https://www.tensorflow.org/guide/keras/functional_api
8. TensorFlow Documentation. "Training and evaluation with the built-in methods." https://www.tensorflow.org/guide/keras/training_with_built_in_methods
9. TensorFlow Documentation. "Writing a training loop from scratch." https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch
10. TensorFlow Documentation. "Distributed training with Keras." https://www.tensorflow.org/tutorials/distribute/keras
11. TensorFlow Documentation. "Mixed precision." https://www.tensorflow.org/guide/mixed_precision
12. TensorFlow Documentation. "Save, serialize, and export models." https://www.tensorflow.org/guide/keras/serialization_and_saving
13. Keras Team. (2023). "Introducing Keras 3." https://keras.io/keras_3/
14. TensorFlow Blog. (2024). "What's new in TensorFlow 2.16." https://blog.tensorflow.org/2024/03/whats-new-in-tensorflow-216.html

