Keras
Last reviewed
May 24, 2026
Sources
No citations yet
Review status
Needs citations
Revision
v3 · 6,350 words
Keras is an open-source, high-level neural network API written in Python, designed to simplify the process of building, training, and deploying deep learning models. Created by François Chollet and first released on March 27, 2015, Keras provides a user-friendly interface that abstracts away much of the complexity involved in constructing neural networks.[1][2] It is licensed under the Apache 2.0 license and has grown into one of the most widely adopted deep learning libraries in the world, with the Keras team reporting more than 2.5 million developers using the framework as of late 2023.[3]
Keras acts as a frontend for lower-level computational frameworks. In its current iteration, Keras 3, it supports TensorFlow, JAX, PyTorch, and OpenVINO as backends, allowing developers to write code once and run it across multiple frameworks without modification.[3][4] As of May 2026, the current stable release is version 3.14.1, published May 7, 2026.[2][5]
| Field | Value |
|---|---|
| Original author | François Chollet |
| Initial release | March 27, 2015 |
| Current stable release | 3.14.1 (May 7, 2026)[5] |
| Repository | github.com/keras-team/keras |
| Written in | Python |
| License | Apache License 2.0 |
| Supported backends (3.x) | TensorFlow, JAX, PyTorch, OpenVINO |
| Reported users | 2.5 million+ developers (2023)[3] |
| Minimum Python | 3.11 (from Keras 3.13)[6] |
The development of Keras is closely tied to the career of its creator, François Chollet, a French software engineer and AI researcher.
Chollet began work on what became Keras while researching recurrent neural networks (RNNs). At the time he found no good reusable open-source implementation of RNNs and LSTMs for the Python data-science ecosystem. The available options had clear limitations: Caffe was popular in computer vision but only worked for narrow use cases and was not very extensible, while Torch 7 required coding in Lua, which lacked the advantages of the Python ecosystem around NumPy, SciPy, and scikit-learn.[7][8] Chollet decided to build his own library, and that effort became Keras.
According to Chollet and Wikipedia's project history, Keras was developed as part of the research effort of project ONEIROS (Open-ended Neuro-Electronic Intelligent Robot Operating System), a research initiative on which Chollet was working in early 2015.[2][9] The first commit was published on GitHub on March 27, 2015.[2][7] At launch, Keras stood out as the first Python deep learning library to support both recurrent and convolutional networks in a single API.[7]
The name "Keras" comes from the Ancient Greek word keras (κέρας, meaning "horn"), a reference to the literary image of the "Gate of Horn" from Homer's Odyssey, through which true visions pass to mortals.[2] Chollet released the first version in March 2015 and joined Google shortly afterward.[10]
In its early versions, Keras supported multiple backends. Keras was originally built on top of Theano, the University of Montreal symbolic-math library, and added TensorFlow as a backend in 2016. Microsoft's Cognitive Toolkit (CNTK) and Intel's PlaidML (which targeted non-NVIDIA GPUs via OpenCL) were also supported in various 2.x releases.[11][12] This backend-agnostic design was one of Keras's defining features: users could write model code once and switch between backends by changing a single configuration setting in ~/.keras/keras.json.[13]
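As a rough illustration of that configuration mechanism, the sketch below edits a keras.json-style file using only Python's standard library. It works in a temporary directory instead of the real ~/.keras folder, the helper name `set_backend` is hypothetical, and the field layout shown is an assumption based on the commonly documented shape of the file rather than something taken from this article.

```python
import json
import tempfile
from pathlib import Path


def set_backend(config_path: Path, backend: str) -> dict:
    """Rewrite the "backend" field of a keras.json-style file (hypothetical helper)."""
    config = json.loads(config_path.read_text())
    config["backend"] = backend
    config_path.write_text(json.dumps(config, indent=4))
    return config


# Work in a temporary directory rather than the real ~/.keras folder.
workdir = Path(tempfile.mkdtemp())
config_path = workdir / "keras.json"

# Assumed typical contents of keras.json in multi-backend Keras.
config_path.write_text(json.dumps({
    "floatx": "float32",
    "epsilon": 1e-07,
    "backend": "theano",
    "image_data_format": "channels_last",
}, indent=4))

# Switching backends is a one-field change; the model code is not touched.
updated = set_backend(config_path, "tensorflow")
print(updated["backend"])
```

The point of the design is that the backend choice lives outside the model code, so the same script can be rerun against a different engine.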
Keras 2, released in March 2017, stabilized the API and brought improvements to the layer system, model saving, and preprocessing utilities. Through this period, Keras's user base grew rapidly, and it became one of the most popular deep learning tools on Kaggle, in industry, and in university courses.[7] Keras 2.3.0, released in September 2019, was the first version of multi-backend Keras with full TensorFlow 2 support and was announced as the last major release of the multi-backend line.[14]
When TensorFlow 2.0 launched in September 2019, Keras was integrated as TensorFlow's official high-level API under the tf.keras namespace.[15][16] Earlier, in December 2018, the TensorFlow team had already announced that they were standardizing on Keras and would deprecate or remove competing high-level APIs (such as tf.estimator and tf.slim).[17] This integration gave Keras access to TensorFlow's full ecosystem, including TensorFlow Serving, TensorFlow Lite, and TensorFlow.js for deployment across servers, mobile devices, and web browsers.[15][16]
During this period (Keras 2.4 through 2.15), TensorFlow was the only supported backend. The standalone multi-backend Keras package was no longer maintained in favor of tf.keras, and the team explicitly recommended that users switch.[11][14] The integration also added TensorFlow-specific features such as eager execution by default, TPU training, native support for distributed training via tf.distribute.Strategy, and the SavedModel format.[15][16]
In 2023, Chollet announced Keras 3, a full rewrite of the library that restored multi-backend support. The project was developed under the codename "Keras Core" during its initial development phase (April to July 2023) and a public beta test (July to September 2023). In September 2023, the project repository at keras-team/keras-core was renamed back to keras-team/keras, and the official Keras 3.0 release shipped on November 28, 2023.[18][19][20]
Keras 3 supports four backends:
| Backend | Use case | Notes |
|---|---|---|
| JAX | High-performance training and inference | Typically delivers the best performance on GPU, TPU, and CPU for many architectures[21] |
| TensorFlow | Production deployment, mobile/web | Access to TF Serving, TF Lite, TF.js ecosystem and tf.distribute[4] |
| PyTorch | Research, integration with PyTorch ecosystem | Keras layers function as native PyTorch Modules; full support including DistributedDataParallel[4] |
| OpenVINO | Inference-only optimization | Added in Keras 3.8 (January 2025) for accelerated CPU/iGPU/NPU inference[22] |
The OpenVINO backend, contributed in collaboration with Intel, supports inference but not training, because OpenVINO does not implement gradient computation; users typically train with JAX, TensorFlow, or PyTorch and switch to OpenVINO for deployment.[22]
As of May 2026, the latest stable version is Keras 3.14.1.[5] Starting with version 3.13.0, Keras requires Python 3.11 or higher.[6] TensorFlow 2.16 and later versions ship with Keras 3 as the default Keras implementation, while the legacy Keras 2 remains available through the tf_keras maintenance package (installed via pip install tf_keras and selected by setting the environment variable TF_USE_LEGACY_KERAS=1 before importing TensorFlow).[23]
François Chollet earned a Master of Engineering degree from ENSTA Paris (part of the Polytechnic Institute of Paris) in 2012. He created Keras in 2015 and joined Google the same year, where he served as a Senior Staff Engineer for over nine years before departing in November 2024.[10][24]
Beyond Keras, Chollet has made several contributions to the AI field:
Although Chollet left Google, Google continued sponsoring Keras development, and Chollet stated he would stay involved with the project from outside Google.[24]
Keras was designed around a small number of guiding principles that Chollet has articulated repeatedly in interviews and in the project documentation: user-friendliness, modularity, ease of extensibility, and what he calls "progressive disclosure of complexity."[32][33]
The Keras documentation describes the project as one that "follows the principle of progressive disclosure of complexity": simple workflows should be quick and easy, while arbitrarily advanced workflows should be possible via a clear path that builds on what users already know.[4] In practice this means that beginners can train a model in fewer than ten lines using high-level APIs (Sequential, compile, fit), while advanced users can override train_step, write fully custom training loops, or even drop down to native backend code, all using the same components.[32]
A Keras model is a graph of standalone, configurable modules. Layers, loss functions, metrics, optimizers, weight initializers, regularizers, and callbacks are independent objects that can be combined, swapped, or subclassed.[4] The Functional API in particular treats layers as callable objects on tensors, which makes shared layers, multi-input/multi-output models, and skip connections straightforward.[34]
Keras 3 returns to the original multi-backend philosophy and goes further than the 2.x version by exposing a unified operations namespace (keras.ops) that lets users write custom components once and run them on any backend (see The keras.ops namespace below).[4]
Keras offers three distinct APIs for building neural network models, each suited to different levels of complexity and customization.[34]
The Sequential API is the simplest way to build a model in Keras. It allows users to create models by stacking layers in a linear sequence, one after another. This API is ideal for straightforward architectures where data flows through each layer in order without branching or merging.[34]
import keras
from keras import layers

model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.3),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])
The Sequential API is best suited for beginners or for building simple models like basic classifiers and regressors. Its limitation is that it only supports single-input, single-output stacks of layers.[34]
The Functional API provides greater flexibility by allowing users to define models as directed acyclic graphs of layers. This API supports multiple inputs and outputs, shared layers, and non-linear topologies such as skip connections and residual blocks.[34][35]
inputs = keras.Input(shape=(784,))
x = layers.Dense(128, activation='relu')(inputs)
x = layers.Dropout(0.3)(x)
x = layers.Dense(64, activation='relu')(x)
outputs = layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
The Functional API strikes a balance between ease of use and flexibility. It is the recommended approach for most use cases, including architectures with branching (such as Inception-style networks) and models that require multiple input or output tensors.[34][35]
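To make the multi-input and skip-connection cases concrete, here is a minimal sketch of a Functional model with two inputs and one residual connection. The input names, shapes, and layer sizes are illustrative assumptions, not values from the Keras documentation.

```python
import keras
from keras import layers

# Two inputs: a flattened image and a small vector of extra features
# (both shapes are illustrative).
image_in = keras.Input(shape=(784,), name="image")
meta_in = keras.Input(shape=(8,), name="meta")

# Residual (skip) connection: add the block's input back to its output.
x = layers.Dense(64, activation="relu")(image_in)
residual = x
x = layers.Dense(64, activation="relu")(x)
x = layers.Add()([x, residual])

# Merge the two branches and produce class probabilities.
merged = layers.Concatenate()([x, meta_in])
outputs = layers.Dense(10, activation="softmax")(merged)

model = keras.Model(inputs=[image_in, meta_in], outputs=outputs)
print(model.output_shape)
```

Because each layer call returns a symbolic tensor, reusing `residual` later in the graph is all that is needed to express the skip connection.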
The Subclassing API gives users full control over the model by defining a custom class that inherits from keras.Model. Users implement the __init__ method to define layers and the call method to specify the forward pass logic.[34]
class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = layers.Dense(128, activation='relu')
        self.dropout = layers.Dropout(0.3)
        self.dense2 = layers.Dense(64, activation='relu')
        self.out = layers.Dense(10, activation='softmax')

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dropout(x)
        x = self.dense2(x)
        return self.out(x)
This approach is suited for advanced research or highly customized models that require conditional logic, loops, or other dynamic behaviors during the forward pass. It offers maximum flexibility but requires a deeper understanding of the framework.[34][35]
| Feature | Sequential | Functional | Subclassing |
|---|---|---|---|
| Ease of use | Very easy | Moderate | Advanced |
| Multiple inputs/outputs | No | Yes | Yes |
| Shared layers | No | Yes | Yes |
| Non-linear topology | No | Yes | Yes |
| Dynamic forward pass | No | No | Yes |
| Model visualization (plot_model) | Yes | Yes | Limited |
| Best for | Beginners, simple models | Most use cases | Research, custom architectures |
The three APIs are not mutually exclusive: a Functional model can include a subclassed Layer, and a subclassed Model can use the Functional API internally. Mixing and matching is encouraged when it improves clarity.[34]
Keras 3 introduced the keras.ops namespace, which provides a unified set of operations that work identically across all backends. This includes a NumPy-compatible API (for example, ops.matmul, ops.sum, ops.stack, ops.einsum) and neural-network-specific functions (for example, ops.softmax, ops.binary_crossentropy, ops.conv).[4][36] Any custom layer, loss, metric, or optimizer written with keras.ops runs on JAX, TensorFlow, PyTorch, and (for inference) OpenVINO. Internally, calls to keras.ops dispatch to the equivalent operation in the active backend, while preserving the same input/output semantics. According to the Keras 3 announcement, numerical results match to within 1e-7 precision in float32 across backends.[4]
All stateful objects in Keras 3 (layers, models, optimizers, metrics) expose a parallel stateless API for use in pure-functional contexts, particularly JAX. Layers and models have a stateless_call() method that mirrors __call__, optimizers have stateless_apply() mirroring apply(), and metrics have stateless_update_state() and stateless_result().[4][37] This makes Keras components usable inside jax.grad, jax.jit, and jax.pmap without further wrapping.[37]
model.fit(), model.evaluate(), and model.predict() accept input data in many formats regardless of the active backend, including NumPy arrays, pandas DataFrames, tf.data.Dataset, torch.utils.data.DataLoader, and keras.utils.PyDataset (a parallelizable Python generator).[4][38] A model running on the JAX backend can iterate over a PyTorch DataLoader, and a model on the PyTorch backend can consume a tf.data.Dataset.[4]
Keras 3 includes a distribution API (keras.distribution) that simplifies data parallelism and model parallelism.[4][39] Initially implemented on the JAX backend (with TensorFlow and PyTorch implementations rolling out across the 3.x line), it lets users distribute training across many GPUs or TPUs with a few lines of code. The core abstractions are DeviceMesh (a logical grid of accelerators, analogous to jax.sharding.Mesh) and TensorLayout (which describes how a tensor is sharded across the mesh).[39]
For pure data parallelism, two lines suffice:
distribution = keras.distribution.DataParallel(
    devices=keras.distribution.list_devices()
)
keras.distribution.set_distribution(distribution)
For model parallelism, users define a LayoutMap that matches variable names with regular expressions and assigns each match a TensorLayout. The underlying framework distributes the program and tensors according to the sharding directives through single-program multiple-data (SPMD) expansion.[39] The API keeps model definition, training logic, and sharding configuration separate, so a model can be scaled up by editing only the layout map.
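The regex-to-layout idea can be illustrated without any accelerators. The sketch below is plain Python, not the `keras.distribution` API: the variable paths, the rules, the mesh axis name "model", and the helper `resolve_layout` are all hypothetical, and the first-match lookup is an assumed simplification of how a layout map resolves a variable.

```python
import re

# Hypothetical variable paths, written in the style of Keras variable names.
variable_paths = [
    "dense_1/kernel",
    "dense_1/bias",
    "embedding/embeddings",
]

# Regex -> sharding spec, one entry per tensor axis. "model" names a mesh
# axis to shard along; None means "replicate along this axis".
layout_rules = {
    r"dense.*kernel": (None, "model"),
    r"dense.*bias": ("model",),
}


def resolve_layout(path, rules):
    """Return the spec of the first matching rule, or None (fully replicated)."""
    for pattern, spec in rules.items():
        if re.search(pattern, path):
            return spec
    return None


layouts = {path: resolve_layout(path, layout_rules) for path in variable_paths}
for path, spec in layouts.items():
    print(path, "->", spec)
```

Keeping the rules in a separate mapping mirrors the separation described above: changing how a model is sharded means editing the rules, not the model definition.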
Keras provides access to more than 40 pre-trained model architectures through Keras Applications, including ResNet and ResNetV2, VGG16/VGG19, InceptionV3 and InceptionResNetV2, Xception, MobileNet (v1/v2/v3), EfficientNet (B0 to B7), EfficientNetV2, NASNet, DenseNet, and ConvNeXt.[40] These models come with pre-trained weights (typically trained on ImageNet with 1,000 classes) and can be used for prediction, feature extraction, fine-tuning, and transfer learning.[40]
For the JAX and TensorFlow backends, models support XLA (Accelerated Linear Algebra) compilation. model.compile(..., jit_compile="auto") is the default and enables XLA where possible. XLA fuses operations into optimized kernels for the target hardware (CPU/GPU/TPU) and is one of the main reasons JAX often achieves the best training throughput on Keras 3 benchmarks.[4][21]
Keras organizes its layer library into 16 categories. The table below lists the most commonly used layers.[41]
| Layer | Category | Description |
|---|---|---|
| Dense | Core | Fully connected layer; each neuron connects to every neuron in the previous layer |
| Conv2D | Convolution | 2D convolution layer for processing image data |
| LSTM | Recurrent | Long Short-Term Memory layer for sequential data; handles the vanishing gradient problem |
| GRU | Recurrent | Gated Recurrent Unit; a simpler alternative to LSTM with comparable performance |
| Embedding | Core | Maps integer indices (e.g., word IDs) to dense vectors; used in NLP models |
| Dropout | Regularization | Randomly sets a fraction of input units to zero during training to prevent overfitting |
| BatchNormalization | Normalization | Normalizes layer inputs to have zero mean and unit variance, stabilizing training |
| LayerNormalization | Normalization | Normalizes across features rather than the batch dimension; common in transformers |
| MultiHeadAttention | Attention | Implements the multi-head attention mechanism used in transformer architectures |
| Flatten | Reshaping | Flattens a multi-dimensional input into a 1D vector |
| MaxPooling2D | Pooling | Downsamples spatial dimensions by taking the maximum value in each pooling window |
| Concatenate | Merging | Concatenates a list of inputs along a specified axis |
In addition to these, Keras provides preprocessing layers for text, image, and audio data; activation layers (ReLU, Softmax, GELU, Swish); weight initializers (GlorotNormal, HeNormal); weight regularizers (L1, L2); and backend-specific layers (TorchModuleWrapper, TFSMLayer, JaxLayer, FlaxLayer) for interoperability with native PyTorch Modules, TensorFlow SavedModels, and JAX/Flax layers.[4][41]
Keras provides a streamlined workflow for training and evaluating models, centered around three methods: compile, fit, and evaluate/predict.[42]
Before training, the model must be compiled with an optimizer, a loss function, and (optionally) metrics to monitor.[42]
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
Optimizers include SGD, RMSprop, Adam, AdamW, Adagrad, Adadelta, Adamax, Nadam, Ftrl, and Lion. Losses cover regression (MeanSquaredError, MeanAbsoluteError, Huber), classification (BinaryCrossentropy, CategoricalCrossentropy, SparseCategoricalCrossentropy, Focal variants), and probabilistic objectives (KLDivergence, Poisson).[4]
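The model.fit() method is the primary training function. For each epoch, it iterates over the training data in batches, computes the loss on each batch, updates the weights through the optimizer, and then evaluates the loss and metrics on any validation data:[42]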
history = model.fit(
    x_train, y_train,
    epochs=20,
    batch_size=32,
    validation_split=0.2,
    callbacks=[...]
)
Keras 3 supports multiple data pipeline formats, including NumPy arrays, tf.data.Dataset, torch.utils.data.DataLoader, Pandas DataFrames, and keras.utils.PyDataset, regardless of which backend is active.[4][38]
model.evaluate() calculates the loss and metrics on a test dataset, providing a measure of the model's performance on unseen data.[42]
test_loss, test_accuracy = model.evaluate(x_test, y_test)
model.predict() generates output predictions for new input data without computing loss or metrics.
predictions = model.predict(new_data)
For more advanced use cases, users can override the train_step() method to customize the training logic while still using model.fit().[4] This is the recommended path for GAN training, self-supervised learning, contrastive losses, and similar setups. Alternatively, Keras components can be used inside fully custom training loops written in native JAX (using jax.grad, jax.jit, and the stateless API), TensorFlow (using tf.GradientTape), or PyTorch (using torch.autograd and optimizer.step()).[4]
Callbacks are objects passed to model.fit() that can perform actions at various stages of training, such as at the start or end of an epoch, or before or after processing a batch. Keras includes several built-in callbacks for common tasks.[43]
| Callback | Purpose |
|---|---|
| EarlyStopping | Stops training when a monitored metric (e.g., validation loss) has stopped improving for a specified number of epochs (patience). Can restore weights from the best epoch via restore_best_weights.[44] |
| ModelCheckpoint | Saves the model or its weights periodically or whenever performance on a monitored metric improves; supports both per-epoch and per-batch save frequencies.[45] |
| ReduceLROnPlateau | Reduces the learning rate when a monitored metric has stopped improving, helping the model escape plateaus. |
| TensorBoard | Logs training metrics, model graphs, histograms, and embeddings for visualization in TensorBoard.[46] |
| LearningRateScheduler | Adjusts the learning rate according to a user-defined schedule function at each epoch. |
| CSVLogger | Streams epoch results (loss, metrics) to a CSV file. |
| ProgbarLogger | Displays a progress bar during training. |
| BackupAndRestore | Periodically saves training state so a job can resume after preemption. |
| RemoteMonitor | Streams events to a remote HTTP endpoint. |
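For LearningRateScheduler, the user-defined part is an ordinary Python function that receives the epoch index and the current learning rate and returns the new rate. The sketch below defines a hypothetical step-decay schedule and simulates it in a plain loop rather than inside a training run; the decay factor, interval, and starting rate are illustrative.

```python
def step_decay(epoch, lr, drop=0.5, every=10):
    """Halve the learning rate every `every` epochs (illustrative values)."""
    if epoch > 0 and epoch % every == 0:
        return lr * drop
    return lr


# Simulate what the callback would do over 25 epochs, starting from 0.01.
lr = 0.01
history = []
for epoch in range(25):
    lr = step_decay(epoch, lr)
    history.append(lr)

print(history[0], history[10], history[20])
```

In a real run the function would be wrapped as `keras.callbacks.LearningRateScheduler(step_decay)` and passed in the `callbacks` list of `model.fit()`.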
Users can also create custom callbacks by subclassing keras.callbacks.Callback and overriding methods like on_epoch_end, on_batch_begin, on_train_end, on_test_batch_end, and on_predict_begin.[43]
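A minimal custom callback might look like the sketch below. The class name `LossHistory` is hypothetical, and the hooks are called by hand here only to show what Keras would pass in; during training, `model.fit()` invokes them automatically.

```python
import keras


class LossHistory(keras.callbacks.Callback):
    """Records the loss reported at the end of every epoch."""

    def on_train_begin(self, logs=None):
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        # `logs` holds the metrics Keras computed for this epoch.
        logs = logs or {}
        self.losses.append(logs.get("loss"))


# The hooks are plain methods, so they can be exercised without a model.
recorder = LossHistory()
recorder.on_train_begin()
recorder.on_epoch_end(0, {"loss": 0.9})
recorder.on_epoch_end(1, {"loss": 0.7})
print(recorder.losses)
```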
Originally, the Keras ecosystem included two separate domain-specific libraries: KerasCV for computer vision and KerasNLP for natural language processing. As AI models increasingly became multimodal (for example, chat-based large language models with image inputs, or vision tasks that leverage text encoders), maintaining separate domain libraries became impractical.[47]
In October 2024, KerasCV and KerasNLP were consolidated into a single unified library called KerasHub.[47][48] KerasHub is a pretrained-modeling library that provides Keras 3 implementations of popular model architectures paired with pretrained checkpoints available on Kaggle Models and the Hugging Face Hub. Models work across the JAX, TensorFlow, and PyTorch backends for both training and inference.[47][49]
KerasHub launched with 37 pretrained models, including:[47]
Features include LoRA and QLoRA fine-tuning for resource-efficient model adaptation, weight quantization (int8 and int4), model publishing to Kaggle and Hugging Face, and large-scale model-parallel retraining.[47][49]
Existing code using keras_nlp imports continues to work; migration only requires updating import statements from keras_nlp to keras_hub. The keras-nlp GitHub repository was renamed keras-hub while preserving backward compatibility.[50]
KerasTuner is a hyperparameter optimization framework for Keras that automates the search for optimal hyperparameter configurations. It exposes a define-by-run search space (using hp.Int, hp.Float, hp.Choice, hp.Boolean) and includes three built-in search algorithms: random search, Bayesian optimization, and Hyperband.[51] Researchers can also implement custom tuners by subclassing keras_tuner.engine.tuner.Tuner.[51]
AutoKeras is an automated machine learning (AutoML) library built on Keras. Developed by the DATA Lab at Texas A&M University and first released in November 2017, AutoKeras automatically searches for the best model architecture and hyperparameters for a given dataset, with task APIs for image classification, image regression, text classification, structured data tasks, and time-series forecasting.[52][53] AutoKeras was published as a journal paper in JMLR in 2023.[52]
The distribution API (keras.distribution), described above under Key features of Keras 3, is part of the core Keras package and provides data and model parallelism with minimal code changes. It was first available on the JAX backend in the Keras 3.0 release and is being extended to other backends in subsequent 3.x versions.[4][39]
The Keras team publishes a benchmark page comparing Keras 3 across backends and against Keras 2 (tf.keras). On an NVIDIA A100 (40 GB) GPU, the team reported the following representative results in late 2023:[21]
| Model and task | Keras 3 + TensorFlow | Keras 3 + JAX | Keras 3 + PyTorch |
|---|---|---|---|
| SegmentAnything inference (ms/step) | 438.50 | **376.34** | 1720.96 |
| Stable Diffusion fit (ms/step) | 392.24 | **391.21** | 823.44 |
| BERT fit (ms/step) | **214.49** | 222.37 | 808.68 |
| BERT predict (ms/step) | 466.01 | **418.72** | 1865.98 |
| Gemma fit (ms/step) | **232.52** | 273.67 | 525.15 |
| Mistral fit (ms/step) | **185.92** | 213.22 | 452.12 |
Lower values are faster; the best result in each row is bold. The Keras team noted that no single backend is best for every workload: JAX often wins on inference (especially on encoder/decoder vision models) and TensorFlow tends to win on large LLM fine-tuning where its compiler is well tuned, while PyTorch through Keras was slower in this set largely because the team had not yet enabled torch.compile integration at the time of testing.[21]
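The per-row winners can be checked mechanically. This short sketch copies the ms/step figures from the table above into a dictionary and picks the backend with the lowest time for each task; the variable names are arbitrary.

```python
# ms/step figures copied from the benchmark table (lower is faster).
benchmarks = {
    "SegmentAnything inference": {"TensorFlow": 438.50, "JAX": 376.34, "PyTorch": 1720.96},
    "Stable Diffusion fit": {"TensorFlow": 392.24, "JAX": 391.21, "PyTorch": 823.44},
    "BERT fit": {"TensorFlow": 214.49, "JAX": 222.37, "PyTorch": 808.68},
    "BERT predict": {"TensorFlow": 466.01, "JAX": 418.72, "PyTorch": 1865.98},
    "Gemma fit": {"TensorFlow": 232.52, "JAX": 273.67, "PyTorch": 525.15},
    "Mistral fit": {"TensorFlow": 185.92, "JAX": 213.22, "PyTorch": 452.12},
}

# Fastest backend per task is the one with the smallest ms/step value.
fastest = {task: min(times, key=times.get) for task, times in benchmarks.items()}

# Count how many rows each backend wins.
wins = {}
for backend in fastest.values():
    wins[backend] = wins.get(backend, 0) + 1

for task, backend in fastest.items():
    print(f"{task}: {backend}")
print(wins)
```

On these six rows JAX and TensorFlow each come out fastest three times, which matches the observation that no single backend dominates.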
The Keras 3 announcement also reported substantial improvements over Keras 2 (tf.keras): SegmentAnything inference roughly 380% faster, Stable Diffusion training roughly 150% faster, and BERT training roughly 100% faster, on the same hardware.[4]
Keras is used across a wide range of deep learning applications:
| Domain | Examples | Common layers / models |
|---|---|---|
| Image classification | Object detection, face recognition, medical imaging | Conv2D, ResNet, EfficientNet |
| Natural language processing | Text classification, sentiment analysis, machine translation | Embedding, LSTM, Transformer, BERT |
| Generative AI | Image synthesis, text generation, data augmentation | GANs, VAEs, Stable Diffusion |
| Speech and audio | Speech recognition, audio classification | Conv1D, Whisper |
| Time series | Forecasting, anomaly detection | LSTM, GRU, Conv1D |
| Reinforcement learning | Game playing, robot control | Dense, custom training loops |
The Keras 3 announcement and documentation list several large production systems that use Keras, including the Waymo self-driving fleet and the YouTube recommendation engine.[4] In published case studies and conference talks, companies such as Netflix, Uber, Square, Yelp, Instacart, and Zocdoc have also described using Keras for various deep-learning workloads.[54] Keras is widely used on Kaggle, where it is a default deep-learning framework option in many notebooks and competitions.[7]
While Keras is consistently praised for usability and rapid prototyping, several documented criticisms have followed the library through its history.
The same high-level abstractions that make Keras easy can obscure low-level computational details, which complicates debugging when something fails deep in the computation graph. Tracebacks from inside compiled or graph-mode code can be harder to interpret than equivalent PyTorch errors, and printing arbitrary intermediate tensors is less ergonomic than in native NumPy/PyTorch code.[55][56] Empirical research has identified concrete bug-localization gaps: a 2021 study of deep neural network faults found that built-in Keras debugging utilities detected most failures but localized very few of them at the layer level, compared with specialized fault-localization tools.[57]
For training procedures that require dynamic gradient updates, multiple optimizers, custom backpropagation, or non-standard data flow (for example, advanced GAN training schedules or some reinforcement-learning algorithms), users must drop down to overriding train_step or writing fully custom training loops. While these escape hatches exist, critics argue the cliff between the easy path and the custom path is steep, and that the API can feel constraining for cutting-edge research compared with PyTorch's "everything is just Python" model.[55][58] This has reinforced the perception, particularly in academic research circles, that Keras is best suited for applied work rather than novel methods.
Industry adoption studies have consistently shown that, since around 2018 to 2019, PyTorch has overtaken TensorFlow/Keras as the preferred framework in academic research, especially among papers at major conferences such as NeurIPS, ICML, and CVPR.[58] Keras 3 partially addresses this gap by allowing users to keep the Keras API while using JAX or PyTorch underneath, but the research community's familiarity with native PyTorch idioms has remained a structural advantage for PyTorch.
Several backends supported in earlier Keras versions have been discontinued or fallen into disuse:[11][12]
Keras and PyTorch are two of the most widely used frameworks for deep learning, but they take different approaches.[58][59]
| Aspect | Keras | PyTorch |
|---|---|---|
| Abstraction level | High-level API | Lower-level framework |
| Ease of use | Beginner-friendly; minimal boilerplate | Requires more code but feels Pythonic |
| Debugging | Relies on backend tools; can be less transparent | Standard Python debugging; eager by default |
| Training loop | Built-in model.fit() handles most cases | Manual training loops offer full control |
| Research adoption | Common in applied ML and industry | Dominant in academic research[58] |
| Cutting-edge models | Available through KerasHub | Most new state-of-the-art models appear first in PyTorch |
| Deployment | Strong TensorFlow ecosystem (TF Serving, TF Lite, TF.js) and ONNX export | TorchServe, TorchScript, ONNX export |
| Backend flexibility | Multi-backend (JAX, TF, PyTorch, OpenVINO) | PyTorch only |
| Performance | Can leverage JAX or XLA for best GPU/TPU performance | Optimized for GPU; mature CUDA support |
When to use Keras: Keras is well suited to rapid prototyping, educational use, and small to mid-scale production projects. It is a strong choice when backend flexibility is important or when deploying to mobile and edge devices through TensorFlow Lite or OpenVINO. Teams that want JAX's performance benefits without learning JAX's functional programming model can use Keras as a familiar interface.[4][21]
When to use PyTorch: PyTorch is preferred for cutting-edge research, when fine-grained control over training dynamics is needed, or when working with models that are primarily published in the PyTorch ecosystem. It is also the standard in most academic labs.[58]
Hybrid approach: Many teams use both frameworks. Keras 3's multi-backend support means that a model written in Keras can run on PyTorch as its backend, and Keras layers can be embedded inside native torch.nn.Module definitions, bridging the gap between the two ecosystems.[4]
Keras is not the only high-level wrapper for deep learning. The two most often compared alternatives are PyTorch Lightning and fastai.
| Aspect | Keras | PyTorch Lightning | fastai |
|---|---|---|---|
| Underlying framework | JAX, TensorFlow, PyTorch, OpenVINO | PyTorch | PyTorch |
| Initial release | 2015 | 2019 | 2018 |
| Created by | François Chollet | William Falcon (PyTorch Lightning Inc., now Lightning AI) | Jeremy Howard and Rachel Thomas |
| Primary design goal | Progressive disclosure, multi-backend portability | Code organization and scalability for research | Maximally fast results with sensible defaults |
| Training abstraction | model.fit() + callbacks | LightningModule + Trainer | Learner + callbacks |
| Hyperparameter tuning | KerasTuner | Native sweeps, optional Optuna/Ray | Built-in learning-rate finder |
| Pretrained model library | KerasHub | TorchVision/TorchHub/torchaudio | fastai pretrained models, plus PyTorch ecosystem |
| Strength | Multi-backend portability, simple fit API | Reduces boilerplate while staying close to PyTorch | High-level "best practice by default" for tabular/CV/NLP |
| Weakness | Customization can require dropping to native backend | PyTorch only; smaller pretrained library than Keras | Opinionated abstractions can be hard to escape |
PyTorch Lightning is often described as "Keras for PyTorch": a framework that wraps PyTorch with a structured training loop and callback system, similar in spirit to model.fit, but without backend flexibility.[60] fastai, developed by Jeremy Howard and Rachel Thomas at fast.ai, is a layered API on top of PyTorch that emphasizes "best-practice defaults" and is widely used in the fast.ai deep-learning courses; it is more opinionated than Keras and can require learning fastai-specific abstractions like DataBunch/DataBlock.[61]
Imagine you want to build something out of LEGO blocks. You could try to make every single tiny brick yourself from scratch, which would take forever. Or you could use a LEGO kit that already has all the special pieces sorted and labeled, with instructions showing you how to snap them together.
Keras is like that LEGO kit, but for building smart computer programs. Deep learning programs are made of building blocks called "layers" that each do a small job, like looking at pictures or reading words. Keras gives you all these building blocks pre-made, so you just pick the ones you need and snap them together in the right order.
Once you put your blocks together, you "train" your creation by showing it lots of examples (like thousands of pictures of cats and dogs) so it learns to tell them apart. Keras handles all the complicated math behind the scenes. You just say "learn from these examples" and it does the rest.
The best part is that Keras works with several different "engines" underneath (called backends), so it is like having one set of LEGO instructions that works with different brands of building blocks.