tf.keras is the high-level neural network API integrated directly into the TensorFlow machine learning framework. Developed and maintained by the Google Brain team, tf.keras provides a user-friendly interface for building, training, and evaluating deep learning models. It is based on the standalone Keras library originally created by François Chollet in 2015, which was formally adopted as TensorFlow's official high-level API with the release of TensorFlow 2.0 in September 2019.
tf.keras emphasizes rapid prototyping, modularity, and extensibility. It supports a range of model-building paradigms, from simple sequential stacks of layers to highly customized architectures using model subclassing. The API handles the full lifecycle of a deep learning project: data preprocessing, model construction, training with built-in or custom loops, evaluation, prediction, model saving, and deployment.
Imagine you want to build a robot that can recognize pictures of cats and dogs. Building that robot from scratch would be really hard because you would need to know how every tiny part works. tf.keras is like a box of special building blocks that snap together easily. You pick the blocks you need, connect them in a line or a pattern, show the robot lots of pictures so it can learn, and then ask it to guess what new pictures are. The building blocks handle all the complicated math inside so you can focus on what you want your robot to learn.
The relationship between Keras and TensorFlow has evolved through several distinct phases.
| Period | Milestone |
|---|---|
| March 2015 | François Chollet releases Keras as an open-source library supporting Theano as its backend |
| November 2015 | Google open-sources TensorFlow; Chollet refactors Keras to support TensorFlow as a backend |
| 2017 | Keras gains support for the Microsoft Cognitive Toolkit (CNTK) backend |
| Mid-2017 | Keras is integrated into TensorFlow as tf.keras, available alongside the standalone package |
| September 2019 | TensorFlow 2.0 launches with tf.keras as its official and recommended high-level API |
| 2020-2023 | Keras versions 2.4 through 2.15 support only the TensorFlow backend |
| November 2023 | Keras 3.0 launches as a multi-backend framework supporting TensorFlow, JAX, and PyTorch |
| March 2024 | TensorFlow 2.16 ships with Keras 3 as the default tf.keras; the legacy Keras 2 is published as the separate tf-keras PyPI package |
Before TensorFlow 2.0, most TensorFlow code required verbose, low-level session management and graph construction. The adoption of Keras as the standard interface simplified the developer experience and lowered the barrier to entry for newcomers to deep learning.
tf.keras is built around four guiding principles that shape its API design.
User-friendliness. The API minimizes the number of steps required to go from idea to working model. Error messages are clear, and the documentation provides extensive examples for common workflows.
Modularity. Models are built by connecting interchangeable components (layers, optimizers, loss functions, metrics, and callbacks). Each component has a consistent interface, making it straightforward to swap one for another.
Extensibility. When built-in components are insufficient, users can create custom layers, loss functions, metrics, callbacks, and even training loops by subclassing base classes and overriding specific methods.
Python-native. Unlike earlier TensorFlow APIs that relied on declarative graph definitions, tf.keras uses standard Python control flow. Combined with TensorFlow 2's default eager execution, this makes debugging intuitive because operations execute immediately and return concrete values.
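To make the eager-execution point concrete, here is a minimal sketch (assuming a default TensorFlow 2.x installation) in which an operation runs immediately and returns a concrete value:

import tensorflow as tf

# Eager execution is the default in TensorFlow 2.x: no session or graph setup.
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = tf.matmul(x, x)   # executes immediately
print(y.numpy())      # concrete values are available right away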
tf.keras offers three progressively more flexible ways to define a model.
The Sequential API is the simplest approach. It creates a model as a linear stack of layers where each layer has exactly one input and one output. It is best suited for straightforward feedforward networks.
import tensorflow as tf
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])
The Sequential API cannot represent models that share layers, have multiple inputs or outputs, or contain branches.
The Functional API treats layers as functions that operate on tensor objects. This allows construction of models with non-linear topologies, shared layers, and multiple inputs or outputs.
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(128, activation='relu')(inputs)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
The Functional API is the most commonly used approach in practice because it balances flexibility with clarity. The resulting model object retains full knowledge of its graph topology, which enables features like model visualization, feature extraction from intermediate layers, and transfer learning.
Model subclassing provides maximum flexibility by letting developers define the forward pass using arbitrary Python code. The user subclasses tf.keras.Model and implements the __init__ method (to define layers) and the call method (to specify the forward computation).
class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dropout = tf.keras.layers.Dropout(0.2)
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        return self.dense2(x)
This approach is popular in research settings where non-standard architectures or dynamic computation graphs are needed. The tradeoff is that subclassed models sacrifice some of the introspection features available to Sequential and Functional models (for example, the ability to serialize the model architecture to JSON).
| Feature | Sequential | Functional | Subclassing |
|---|---|---|---|
| Multiple inputs/outputs | No | Yes | Yes |
| Shared layers | No | Yes | Yes |
| Non-linear topology | No | Yes | Yes |
| Dynamic control flow in forward pass | No | No | Yes |
| Model graph introspection | Yes | Yes | Limited |
| Serialization to JSON | Yes | Yes | No |
| Ease of use | Highest | High | Moderate |
| Typical use case | Simple feedforward nets | Most production models | Research prototypes |
Layers are the fundamental building blocks of tf.keras models. Each layer receives input tensors, applies a transformation, and produces output tensors. tf.keras provides a large library of built-in layers organized by category.
| Category | Example layers | Typical use |
|---|---|---|
| Core | Dense, Activation, Embedding, Masking | General-purpose transformations |
| Convolutional | Conv1D, Conv2D, Conv3D, SeparableConv2D, DepthwiseConv2D | Image and signal processing |
| Pooling | MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D | Spatial downsampling |
| Recurrent | SimpleRNN, LSTM, GRU, Bidirectional | Sequential and time-series data |
| Normalization | BatchNormalization, LayerNormalization, GroupNormalization | Stabilizing and accelerating training |
| Regularization | Dropout, SpatialDropout1D, GaussianNoise | Reducing overfitting |
| Attention | MultiHeadAttention, Attention, AdditiveAttention | Transformer-style models |
| Reshaping | Flatten, Reshape, Permute, RepeatVector | Tensor shape manipulation |
| Merging | Concatenate, Add, Multiply, Average | Combining multiple inputs |
| Preprocessing | TextVectorization, Normalization, Discretization, CategoryEncoding, Resizing, Rescaling | Data preprocessing within the model |
Users can define custom layers by subclassing tf.keras.layers.Layer and implementing the build and call methods. Custom layers can maintain trainable weights, support serialization, and integrate seamlessly with any model-building API.
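As a minimal sketch of that pattern (the layer name SimpleDense and its default width are illustrative, not from the text above), weights are created in build once the input shape is known and applied in call:

class SimpleDense(tf.keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Trainable weights are created lazily, once the input shape is known.
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='glorot_uniform', trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='zeros', trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b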
The standard tf.keras training workflow follows a compile-fit-evaluate-predict pattern.
The compile method configures the model for training by specifying the optimizer, loss function, and metrics.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
The fit method trains the model for a specified number of epochs. It accepts NumPy arrays, tf.data.Dataset objects, or Python generators as input. It returns a History object containing per-epoch training and validation metrics.
history = model.fit(
    x_train, y_train,
    epochs=10,
    batch_size=32,
    validation_split=0.2
)
The evaluate method computes loss and metric values on a held-out test set.
test_loss, test_acc = model.evaluate(x_test, y_test)
The predict method generates output predictions for new input samples, processing them in batches.
predictions = model.predict(x_new)
For cases that require full control over the training process, tf.keras supports custom training loops using tf.GradientTape. This is useful for implementing advanced techniques such as custom gradient clipping, multi-optimizer training, or adversarial training.
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

for epoch in range(num_epochs):
    for x_batch, y_batch in train_dataset:
        with tf.GradientTape() as tape:
            # Record the forward pass so gradients can be computed.
            predictions = model(x_batch, training=True)
            loss = loss_fn(y_batch, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
The @tf.function decorator can be applied to the training step function to compile it into a TensorFlow graph for improved performance, while still allowing development and debugging in eager mode.
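A minimal sketch of that pattern, wrapping the inner step of the loop above (all names are reused from that example):

@tf.function
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        predictions = model(x_batch, training=True)
        loss = loss_fn(y_batch, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

# The first call traces the function into a graph; subsequent calls reuse it.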
tf.keras includes a suite of gradient-based optimizers for updating model weights during training.
| Optimizer | Description | Typical use case |
|---|---|---|
| SGD | Classic stochastic gradient descent with optional momentum and Nesterov acceleration | Baseline experiments, fine-tuning |
| Adam | Adaptive learning rates using first and second moment estimates of gradients | General-purpose default optimizer |
| RMSprop | Divides learning rate by a moving average of squared gradients | Recurrent networks |
| AdaGrad | Adapts learning rate per parameter based on historical gradient magnitudes | Sparse data problems |
| AdamW | Adam with decoupled weight decay regularization | Transformer-based models |
| Nadam | Adam with Nesterov momentum | Tasks benefiting from look-ahead gradients |
| Adamax | Variant of Adam using the infinity norm | Problems with very sparse gradients |
All optimizers accept a learning_rate parameter that can be a static float or a tf.keras.optimizers.schedules.LearningRateSchedule instance (such as ExponentialDecay, CosineDecay, or PolynomialDecay) for automated learning rate adjustment during training.
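For example, a minimal sketch pairing Adam with an exponential decay schedule (the hyperparameter values are illustrative):

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,  # starting learning rate
    decay_steps=10000,           # apply decay every 10,000 steps
    decay_rate=0.9               # multiply the rate by 0.9 each time
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)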
tf.keras provides loss functions for common supervised and unsupervised learning tasks.
| Loss function | Task type | Description |
|---|---|---|
| MeanSquaredError | Regression | Average of squared differences between predictions and targets |
| MeanAbsoluteError | Regression | Average of absolute differences |
| BinaryCrossentropy | Binary classification | Log loss for two-class problems |
| CategoricalCrossentropy | Multi-class classification | Log loss for one-hot encoded labels |
| SparseCategoricalCrossentropy | Multi-class classification | Log loss for integer-encoded labels |
| Huber | Regression | Combines MSE and MAE; less sensitive to outliers |
| KLDivergence | Distribution matching | Measures divergence between two probability distributions |
| CosineSimilarity | Embedding learning | Cosine similarity between prediction and target vectors (negated when used as a loss so it can be minimized) |
Custom loss functions can be implemented as plain Python functions or by subclassing tf.keras.losses.Loss.
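As a minimal sketch of the plain-function form (the function name and formula are an illustrative mean-squared-error variant, not from the source):

def custom_mse(y_true, y_pred):
    # A custom loss receives (y_true, y_pred) and returns per-sample losses.
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

model.compile(optimizer='adam', loss=custom_mse)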
Callbacks are objects that execute actions at specific points during training (at the start or end of an epoch, before or after a batch, and so on). tf.keras includes several built-in callbacks.
| Callback | Purpose |
|---|---|
| ModelCheckpoint | Saves the model (or just its weights) periodically or when a monitored metric improves |
| EarlyStopping | Stops training when a monitored metric has not improved for a specified number of epochs (patience) |
| ReduceLROnPlateau | Reduces the learning rate by a factor when a metric has stopped improving |
| TensorBoard | Logs metrics, histograms, and graphs for visualization in TensorBoard |
| LearningRateScheduler | Adjusts the learning rate at the start of each epoch using a user-defined schedule function |
| CSVLogger | Streams epoch results to a CSV file |
| TerminateOnNaN | Stops training when a NaN loss is encountered |
| LambdaCallback | Wraps arbitrary functions for use as callbacks |
Multiple callbacks can be combined and passed to model.fit() as a list. Users can also create custom callbacks by subclassing tf.keras.callbacks.Callback.
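A sketch of combining built-in callbacks in a single fit call (the monitored metric, patience, and file name are illustrative choices):

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3,
                                     restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint('best_model.keras', monitor='val_loss',
                                       save_best_only=True),
]
history = model.fit(x_train, y_train, validation_split=0.2,
                    epochs=50, callbacks=callbacks)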
tf.keras provides a collection of pretrained computer vision models through tf.keras.applications. These models have been trained on the ImageNet dataset and can be used for image classification, feature extraction, or fine-tuning on custom datasets.
| Model family | Example variants | Parameters (approx.) | Top-1 accuracy on ImageNet |
|---|---|---|---|
| VGG | VGG16, VGG19 | 138M (VGG16) | 71.3% (VGG16) |
| ResNet | ResNet50, ResNet101, ResNet152 | 25.6M (ResNet50) | 74.9% (ResNet50) |
| InceptionV3 | InceptionV3 | 23.9M | 77.9% |
| EfficientNet | EfficientNetB0 through B7 | 5.3M (B0) to 66M (B7) | 77.1% (B0) to 84.3% (B7) |
| MobileNet | MobileNetV2, MobileNetV3 | 3.4M (V2) | 71.3% (V2) |
| DenseNet | DenseNet121, DenseNet169, DenseNet201 | 8.1M (DenseNet121) | 75.0% (DenseNet121) |
| NASNet | NASNetMobile, NASNetLarge | 5.3M (Mobile) | 74.4% (Mobile) |
A typical transfer learning workflow involves loading a pretrained model without its classification head (by setting include_top=False), freezing the pretrained layers, adding new trainable layers for the target task, training the new layers, and optionally unfreezing some pretrained layers for fine-tuning.
base_model = tf.keras.applications.EfficientNetB0(
    weights='imagenet', include_top=False, input_shape=(224, 224, 3)
)
base_model.trainable = False  # freeze the pretrained weights
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])
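For the optional fine-tuning step mentioned above, a common follow-up sketch (the low learning rate is an illustrative choice, and train_dataset is a placeholder name for prepared image data) unfreezes the base model and recompiles:

base_model.trainable = True  # unfreeze the pretrained layers
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # very low rate
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
model.fit(train_dataset, epochs=5)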
tf.keras integrates closely with the tf.data.Dataset API for building efficient input pipelines. Using tf.data, developers can load data from files, apply transformations, shuffle, batch, and prefetch data for optimal hardware utilization.
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(buffer_size=1024).batch(32).prefetch(tf.data.AUTOTUNE)
model.fit(dataset, epochs=10)
tf.keras includes built-in preprocessing layers that can be embedded directly into the model graph. This ensures that preprocessing logic is packaged with the model at serving time.
| Layer | Function |
|---|---|
| TextVectorization | Converts raw strings into integer token sequences |
| Normalization | Applies feature-wise normalization (zero mean, unit variance) |
| Discretization | Bins continuous features into discrete integer ranges |
| CategoryEncoding | Converts integer categories into one-hot or multi-hot representations |
| Resizing | Resizes images to a target height and width |
| Rescaling | Scales pixel values (for example, from [0, 255] to [0, 1]) |
| RandomFlip | Randomly flips images horizontally or vertically during training |
| RandomRotation | Applies random rotation augmentation during training |
| RandomZoom | Applies random zoom augmentation during training |
Image augmentation layers are only active during training and are automatically bypassed during inference.
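A minimal sketch embedding rescaling and augmentation layers directly in a Functional model (the layer parameters and the tiny architecture are illustrative):

augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.1),
])
inputs = tf.keras.Input(shape=(224, 224, 3))
x = tf.keras.layers.Rescaling(1.0 / 255)(inputs)  # always active
x = augmentation(x)                               # active only in training
x = tf.keras.layers.Conv2D(16, 3, activation='relu')(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)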
tf.keras supports multiple formats for saving and loading models.
| Format | File extension | Contents | Notes |
|---|---|---|---|
| Keras native | .keras | Architecture, weights, optimizer state, training config | Recommended format for Keras 3; supports all model types |
| SavedModel | directory | Execution graph, weights, signatures | Used by TensorFlow Serving, TensorFlow Lite, and TensorFlow.js |
| HDF5 | .h5 | Architecture, weights, optimizer state | Legacy format; does not support custom objects without extra configuration |
| Weights only | .weights.h5 | Weights only | Useful when the architecture is defined in code |
The SavedModel format saves the execution graph, which means subclassed models can be loaded without access to the original source code. The HDF5 format uses the model configuration (architecture as JSON) and therefore requires the same class definitions to be available at load time.
Models saved in the SavedModel format can be converted for deployment on mobile and edge devices using TensorFlow Lite (tf.lite.TFLiteConverter) or in web browsers using TensorFlow.js.
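A sketch of saving, reloading, and converting a trained model (the file names are illustrative):

model.save('classifier.keras')  # Keras native format
restored = tf.keras.models.load_model('classifier.keras')

# Convert for mobile/edge deployment with TensorFlow Lite.
converter = tf.lite.TFLiteConverter.from_keras_model(restored)
tflite_bytes = converter.convert()
with open('classifier.tflite', 'wb') as f:
    f.write(tflite_bytes)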
tf.keras integrates with TensorFlow's distribution strategies (tf.distribute) to enable training across multiple GPUs, multiple machines, and TPUs with minimal code changes.
| Strategy | Hardware | Approach |
|---|---|---|
| MirroredStrategy | Multiple GPUs on one machine | Synchronous training with all-reduce gradient aggregation |
| MultiWorkerMirroredStrategy | Multiple GPUs across multiple machines | Synchronous training with all-reduce across a network |
| TPUStrategy | Google TPU pods | Synchronous training optimized for TPU interconnects |
| ParameterServerStrategy | CPU/GPU cluster with parameter servers | Asynchronous or synchronous training with centralized variable storage |
| CentralStorageStrategy | Single machine with multiple GPUs | Variables stored on CPU, computations replicated on GPUs |
Using a distribution strategy typically requires only wrapping the model creation and compilation inside a strategy.scope() block.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([...])
    model.compile(...)
model.fit(dataset, epochs=10)
tf.keras supports mixed precision training through tf.keras.mixed_precision. Mixed precision uses 16-bit floating-point (float16 or bfloat16) for most operations while keeping critical computations in 32-bit float, which can deliver significant performance improvements on supported hardware.
On NVIDIA GPUs with Tensor Cores (compute capability 7.0 and above, such as the V100, A100, and RTX series), mixed precision can improve training throughput by over 3x. On Google TPUs, using bfloat16 mixed precision can provide up to 60% throughput improvement.
Enabling mixed precision in tf.keras requires a single line of code at the start of the program:
tf.keras.mixed_precision.set_global_policy('mixed_float16')
For custom training loops, a LossScaleOptimizer handles dynamic loss scaling to prevent underflow in float16 gradients.
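A sketch of that loss-scaling pattern in a custom step, following the TensorFlow 2.x tf.keras API and reusing x_batch, y_batch, model, and loss_fn from the earlier loop (the wrapped optimizer is an illustrative choice):

optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam()
)
with tf.GradientTape() as tape:
    predictions = model(x_batch, training=True)
    loss = loss_fn(y_batch, predictions)
    scaled_loss = optimizer.get_scaled_loss(loss)  # scale up to avoid underflow
scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
gradients = optimizer.get_unscaled_gradients(scaled_grads)  # undo the scaling
optimizer.apply_gradients(zip(gradients, model.trainable_variables))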
With the release of Keras 3.0 in November 2023, Keras became a multi-backend framework capable of running on TensorFlow, JAX, and PyTorch. Starting with TensorFlow 2.16 (released March 2024), tf.keras points to Keras 3 by default.
Keras 3 introduced several changes relevant to tf.keras users.
- The active backend can be selected with the KERAS_BACKEND environment variable or with keras.config.set_backend() (a minimal example appears below).
- Code that depends on legacy Keras 2 behavior can install the tf-keras package from PyPI and set the environment variable TF_USE_LEGACY_KERAS=1.

The transition means that tf.keras is no longer a TensorFlow-only API but part of a broader cross-framework ecosystem, though TensorFlow remains a first-class supported backend.
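For the backend selection mentioned above, a minimal sketch (the choice of jax is illustrative) sets the environment variable before Keras is imported:

import os
os.environ['KERAS_BACKEND'] = 'jax'  # must be set before importing keras

import keras  # now runs on the JAX backend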
tf.keras is used across a wide range of deep learning applications.
| Domain | Typical models | tf.keras components used |
|---|---|---|
| Image classification | CNNs, EfficientNet, ResNet | Conv2D, MaxPooling2D, Dense, tf.keras.applications |
| Object detection | Feature pyramid networks, SSD | Functional API, custom layers |
| Natural language processing | Transformers, text classifiers | Embedding, MultiHeadAttention, TextVectorization |
| Time series forecasting | LSTM, GRU, temporal convolutional networks | LSTM, GRU, Conv1D, custom training loops |
| Generative models | VAEs, GANs | Model subclassing, custom training loops, GradientTape |
| Recommendation systems | Collaborative filtering, two-tower models | Embedding, Functional API |
| Speech recognition | CTC-based models, conformers | LSTM, Conv1D, attention layers |
The table below compares tf.keras with the model-building stacks of other major deep learning frameworks.

| Feature | tf.keras | PyTorch (torch.nn) | JAX + Flax/Haiku |
|---|---|---|---|
| Default execution | Eager + graph via @tf.function | Eager | Functional transforms (jit, grad) |
| Model definition | Sequential, Functional, Subclassing | Module subclassing | Functional or module-based |
| Built-in training loop | model.fit() | Manual loop (or third-party trainers) | Manual loop |
| Pretrained models | tf.keras.applications | torchvision.models | Third-party libraries |
| Distributed training | tf.distribute strategies | torch.distributed, FSDP | pjit, pmap |
| Deployment | TF Serving, TF Lite, TF.js | TorchServe, ONNX | TF via SavedModel export |
| Mixed precision | One-line global policy | torch.cuda.amp | jmp library |
While tf.keras covers most deep learning workflows, it has some constraints worth noting.
In particular, compiling code with @tf.function for performance can make debugging harder, because Python control flow may behave differently inside traced functions than in eager mode.