JAX is a high-performance numerical computing library developed by Google that combines the familiar interface of NumPy with composable function transformations for automatic differentiation, vectorization, parallelization, and just-in-time compilation. Built on top of XLA (Accelerated Linear Algebra), JAX enables researchers and engineers to write Python code that runs efficiently on CPUs, GPUs, and Google TPUs. Originally created by James Bradbury, Roy Frostig, and colleagues at Google Brain, JAX was first released in December 2018 and has since become the framework of choice for training some of the world's largest AI models, including Google's PaLM and Gemini families [1].
JAX's design draws from two primary intellectual lineages. The first is Autograd, a Python library for automatic differentiation of native Python and NumPy code created by Dougal Maclaurin, David Duvenaud, and Ryan Adams at Harvard. Autograd demonstrated that it was possible to differentiate through arbitrary Python control flow, and its design influenced not only JAX but also PyTorch and other frameworks [2].
The second lineage is XLA (Accelerated Linear Algebra), a domain-specific compiler originally developed for TensorFlow that optimizes linear algebra computations for various hardware backends. XLA performs whole-program optimization, including operator fusion, memory layout optimization, and hardware-specific code generation for GPUs and TPUs. By targeting XLA as its compilation backend, JAX inherits powerful optimization capabilities without requiring users to write hardware-specific code [1].
The name "JAX" is sometimes expanded as "Just After eXecution," though the official project describes it simply as "composable transformations of Python+NumPy programs" [1].
A nascent version of JAX, supporting only automatic differentiation and compilation to XLA, was described in a paper that appeared at SysML 2018. The full open-source release came in December 2018. The core team included James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang [2].
Following the merger of Google Brain and DeepMind into Google DeepMind in April 2023, JAX development continued with backing from the combined organization. The project moved its GitHub repository from google/jax to jax-ml/jax to reflect its broader community governance, though Google DeepMind remains the primary contributor.
JAX's central design principle is that powerful capabilities can be expressed as composable transformations of pure Python functions. Rather than building these capabilities into a framework's runtime (as PyTorch does with autograd) or requiring a special graph language (as early TensorFlow did), JAX provides a small set of function transformations that can be freely combined.
The grad transformation takes a Python function and returns a new function that computes its gradient. JAX supports both forward-mode and reverse-mode automatic differentiation, as well as higher-order derivatives. Crucially, grad is itself composable: you can take the gradient of a gradient to compute Hessians, or combine grad with other transformations [1].
```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    predictions = jnp.dot(x, params)
    return jnp.mean((predictions - y) ** 2)

# grad returns a function computing the gradient w.r.t. the first argument
grad_fn = jax.grad(loss_fn)
```
JAX's automatic differentiation works by tracing the function and building a representation of the computation that can be differentiated symbolically. Because it operates on pure functions (functions without side effects), the differentiation is mathematically well-defined and reliable [3].
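Because grad returns an ordinary Python function, it can be applied again to its own output. A minimal sketch of higher-order differentiation (the function here is purely illustrative):

```python
import jax

# f(x) = x**3, so f'(x) = 3x**2 and f''(x) = 6x
f = lambda x: x ** 3

df = jax.grad(f)             # first derivative
d2f = jax.grad(jax.grad(f))  # second derivative, by composing grad with itself

print(float(df(2.0)))   # 12.0  (3 * 2**2)
print(float(d2f(2.0)))  # 12.0  (6 * 2)
```

The same pattern extends to jax.jacobian and jax.hessian, which are built from the forward- and reverse-mode primitives.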
The jit transformation compiles a Python function using XLA, producing optimized machine code for the target hardware. JIT compilation traces the function with abstract values to determine the computation's structure, then hands the resulting program to XLA for optimization and code generation [1].
XLA applies a range of optimizations, including operator fusion (combining multiple element-wise operations into a single kernel), memory layout optimization, constant folding, and hardware-specific code generation.
The compilation overhead is incurred only on the first call (and when input shapes change), with subsequent calls executing the compiled code directly. This means JAX programs can achieve performance close to hand-tuned kernels while being written in pure Python [1].
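A small sketch of jit in practice: the compiled function produces the same values as the eager original, with tracing and compilation paid once per input shape and dtype:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # A small dense computation: matmul followed by a nonlinearity
    return jnp.tanh(x @ w)

# jit traces predict with abstract values and hands the trace to XLA;
# subsequent calls with the same shapes reuse the compiled executable.
predict_jit = jax.jit(predict)

w = jnp.ones((4, 4)) * 0.1
x = jnp.ones((2, 4))

out_eager = predict(w, x)
out_compiled = predict_jit(w, x)
print(bool(jnp.allclose(out_eager, out_compiled)))  # True
```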
The vmap (vectorized map) transformation takes a function that operates on single examples and automatically vectorizes it to operate on batches. This eliminates the need to manually write batched versions of functions, a common source of bugs and complexity in deep learning code [1].
For example, if you have a function that computes the loss for a single data point, vmap can automatically create a version that processes an entire batch in parallel. This is particularly useful for implementing per-sample gradient computations, which are needed for techniques like differentially private training.
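The per-sample-gradient pattern can be sketched by composing vmap with grad; the loss function and data here are illustrative:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Squared error for a single example
    return (jnp.dot(x, params) - y) ** 2

# grad differentiates w.r.t. params; vmap maps over the batch axis of x
# and y while broadcasting params (in_axes=(None, 0, 0)), yielding one
# gradient per example instead of a single averaged gradient.
per_sample_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

params = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
ys = jnp.array([0.0, 1.0])

grads = per_sample_grads(params, xs, ys)
print(grads.shape)  # (2, 2): one gradient vector per example
```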
The pmap (parallel map) transformation distributes computation across multiple devices (such as multiple TPU cores or GPUs). It works by replicating the function across devices and executing it in parallel, with built-in support for collective operations like all-reduce for synchronizing gradients during distributed training [1].
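A hedged sketch of pmap with a collective; the device count varies by machine (a CPU-only host typically exposes a single device, so the example degenerates gracefully):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # 1 on a typical CPU-only machine

# Subtract the cross-device mean: lax.pmean is an all-reduce over the
# mapped axis, the same collective used to average gradients in
# data-parallel training.
def center(x):
    return x - jax.lax.pmean(x, axis_name="devices")

# pmap expects the leading axis to match the number of devices.
xs = jnp.arange(float(n * 3)).reshape(n, 3)
out = jax.pmap(center, axis_name="devices")(xs)
print(out.shape)  # (n, 3)
```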
More recently, JAX has introduced shard_map as a more flexible replacement for pmap. While pmap replicates the computation across devices, shard_map allows explicit control over how data and computation are partitioned. The older xmap function has been removed in favor of shard_map [4].
The critical insight of JAX's design is that these transformations compose freely. You can JIT-compile a vmapped, gradient-computing function that runs in parallel across devices, and each transformation is orthogonal to the others:
```python
# Compose: parallel, vectorized, JIT-compiled gradient computation
parallel_batched_grad = jax.pmap(jax.vmap(jax.jit(jax.grad(loss_fn))))
```
This composability arises from JAX's functional programming paradigm. Because JAX functions are expected to be pure (no side effects, no mutation of external state), each transformation can reason about the function independently and compose predictably with the others [3].
| Transformation | Purpose | Input | Output |
|---|---|---|---|
| grad | Automatic differentiation | Function f(x) | Function f'(x) computing gradient |
| jit | XLA compilation | Function f(x) | Compiled, optimized version of f |
| vmap | Automatic vectorization | Function f(x) for single example | Batched function f(X) |
| pmap | Multi-device parallelism | Function f(x) | Function replicated across devices |
| shard_map | Explicit sharding | Function f(x) | Function with custom device sharding |
JAX takes a deliberately opinionated approach to programming style by requiring (or strongly encouraging) functional purity. In practice, this means:
No in-place mutation. JAX arrays are immutable. Instead of x[i] = v, JAX uses x = x.at[i].set(v), which returns a new array. This requirement stems from the need for transformations like grad and jit to reason about the computation unambiguously.
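The functional update pattern in miniature:

```python
import jax.numpy as jnp

x = jnp.zeros(3)
# x[1] = 5.0 would raise an error: JAX arrays are immutable.
y = x.at[1].set(5.0)  # returns a new array; x is unchanged

print(x)  # [0. 0. 0.]
print(y)  # [0. 5. 0.]
```

Under jit, XLA can often optimize such updates into genuine in-place operations, so the functional style need not cost an extra copy.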
Explicit random state. Unlike NumPy, which uses a global random state, JAX requires explicit passing of pseudo-random number generator (PRNG) keys. This makes random computations reproducible and compatible with transformations like jit and pmap, since there is no hidden state that could differ across devices or compilations.
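A minimal sketch of the explicit-key discipline:

```python
import jax

key = jax.random.PRNGKey(42)         # explicit PRNG state, not a global seed
key, subkey = jax.random.split(key)  # derive independent streams

a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))  # same key -> identical samples

print(bool((a == b).all()))  # True: randomness is a pure function of the key
```

Reusing a key reproduces the same draw; fresh randomness always comes from splitting, which keeps random code compatible with jit and pmap.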
Pure functions. Functions passed to JAX transformations should not read from or write to external state. All inputs should be function arguments, and all outputs should be return values.
This functional discipline has tradeoffs. On one hand, it makes JAX programs easier to reason about, optimize, and parallelize. On the other hand, it imposes a learning curve for developers accustomed to PyTorch's imperative, mutation-heavy style. Managing explicit state (model parameters, optimizer state, random keys) requires patterns that feel unfamiliar to many practitioners [3].
JAX provides jax.numpy, a near-complete reimplementation of the NumPy API that operates on JAX arrays rather than NumPy arrays. Most NumPy code can be converted to JAX simply by replacing import numpy as np with import jax.numpy as jnp. The familiar NumPy functions (array creation, linear algebra, indexing, broadcasting) work identically, but the underlying arrays are backed by XLA and can reside on GPUs or TPUs [1].
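A small sketch of the drop-in correspondence, writing the same expression once against NumPy and once against jax.numpy:

```python
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

result_np = np.sum(np.sin(x_np) ** 2)
result_jnp = jnp.sum(jnp.sin(x_jnp) ** 2)

# Values agree up to precision: note JAX defaults to float32,
# while NumPy defaults to float64.
print(float(result_np), float(result_jnp))
```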
JAX also provides jax.scipy with reimplementations of many SciPy functions, extending the NumPy-compatible surface area to optimization, signal processing, and statistical functions. As of 2025, JAX supports NumPy 2.0 compatibility [4].
JAX intentionally provides only low-level numerical primitives. Neural network abstractions, optimizers, and training loops are provided by separate libraries in the JAX ecosystem.
Flax is the primary neural network library for JAX, developed by Google. It provides modules for defining layers, managing parameters, and organizing model code. Flax has gone through several API iterations: the original flax.nn interface, its successor Linen (flax.linen), and most recently NNX, which adopts a more Pythonic, object-oriented style while remaining compatible with JAX's transformations.
Flax is used in hundreds of projects both in the open-source community (including Hugging Face model implementations) and at Google (including PaLM, Imagen, Scenic, and Big Vision) [5].
Haiku was developed by DeepMind as a JAX-based neural network library with a design inspired by Sonnet (DeepMind's TensorFlow library). It was used extensively within DeepMind and gained significant external adoption. However, as of July 2023, Google DeepMind recommends that new projects adopt Flax instead of Haiku, as Flax has a superset of features, a larger development team, and broader adoption outside of Alphabet [6].
Optax is a gradient processing and optimization library for JAX, developed by Google DeepMind. It provides composable building blocks for constructing optimizers: gradient transformations (scaling, clipping, momentum) can be chained together to build standard optimizers like Adam, SGD, or custom variants. This composable design mirrors JAX's own philosophy [6].
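The composable idea can be sketched in plain JAX. This is not the real Optax API: clip_by_norm, scale, and chain below are hypothetical stand-ins that mirror the pattern of optax.chain(optax.clip_by_global_norm(...), optax.sgd(...)) in spirit only.

```python
import jax.numpy as jnp

# Each block is an `update(grads, state) -> (grads, state)` step, and an
# optimizer is simply a chain of such blocks.

def clip_by_norm(max_norm):
    def update(g, state):
        norm = jnp.sqrt(jnp.sum(g ** 2))
        return g * jnp.minimum(1.0, max_norm / norm), state
    return update

def scale(step_size):
    def update(g, state):
        return g * -step_size, state  # negate so params + update descends
    return update

def chain(*transforms):
    def update(g, state):
        for t in transforms:
            g, state = t(g, state)
        return g, state
    return update

# "SGD with gradient clipping" expressed as a chain of transformations
sgd_clipped = chain(clip_by_norm(1.0), scale(0.1))

params = jnp.array([1.0, 2.0])
grads = jnp.array([3.0, 4.0])  # norm 5.0, clipped down to norm 1.0
updates, _ = sgd_clipped(grads, None)
params = params + updates
print(params)  # each grad scaled by 1/5 (clip) then by -0.1 (step)
```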
Orbax provides checkpointing and persistence utilities for JAX, designed to work at any scale from single-device to large-scale distributed training. It supports asynchronous checkpointing (overlapping I/O with computation), multi-tier checkpointing, and is compatible with various storage backends. Orbax aims to unify what were previously fragmented checkpointing implementations across the JAX ecosystem [5].
Grain is a data loading library for JAX that provides deterministic data pipelines essential for reproducible large-scale training runs. It is designed to work efficiently with JAX's distributed training primitives.
Google maintains the JAX AI Stack, a curated collection of these libraries with pinned versions verified to work together through integration tests. The stack includes Flax for model authoring, Optax for optimization, Grain for data loading, and Orbax for checkpointing [5].
| Library | Purpose | Maintainer |
|---|---|---|
| Flax | Neural network modules and model definition | Google |
| Optax | Gradient processing and optimization | Google DeepMind |
| Orbax | Checkpointing and persistence | Google |
| Grain | Deterministic data loading | Google |
| Haiku | Neural network library (legacy, use Flax for new projects) | Google DeepMind |
| Pallas | Custom kernel authoring for TPUs and GPUs | JAX core team (Google) |
JAX's combination of XLA compilation, functional purity, and multi-device parallelism has made it the framework of choice for training many of the world's largest AI models.
Google DeepMind uses JAX to train its most advanced models. PaLM (Pathways Language Model), a 540-billion parameter large language model announced in April 2022, was trained using JAX on Google's TPU v4 pods. The Gemini family of multimodal models, which succeeded PaLM and powers Google's AI products, is also trained with JAX. The framework's native TPU support and efficient multi-device parallelism make it particularly well-suited for training on Google's custom accelerator hardware [7].
Several major AI companies beyond Google have adopted JAX for their model training infrastructure, including Anthropic and xAI, which rely on it for their most compute-intensive training runs [7].
In the research community, JAX is popular for work requiring custom differentiation (such as physics simulations and scientific computing), meta-learning, and any setting where the functional programming model aligns well with the mathematical structure of the problem [7].
JAX and PyTorch represent fundamentally different philosophies for building ML frameworks, though both are capable of similar tasks.
| Aspect | JAX | PyTorch |
|---|---|---|
| Programming paradigm | Functional (pure functions, immutable arrays) | Imperative (mutation, object-oriented) |
| Execution model | Trace-and-compile (via XLA) | Eager by default; optional compilation (torch.compile) |
| Automatic differentiation | Function transformation (grad) | Tape-based (autograd) |
| Vectorization | Built-in (vmap) | Manual batching or torch.vmap (added later) |
| Multi-device parallelism | Built-in (pmap, shard_map) | DistributedDataParallel, FSDP |
| Random number generation | Explicit PRNG keys | Global state (torch.manual_seed) |
| TPU support | Native, first-class | Via PyTorch/XLA (experimental) |
| GPU support | Via XLA (CUDA, ROCm) | Native CUDA, ROCm |
| Neural network library | Separate (Flax, Haiku) | Built-in (torch.nn) |
| Debugging | Harder (JIT compilation obscures errors) | Easier (eager execution, standard debuggers) |
| Community size | Smaller but growing | Much larger |
| Research paper adoption | ~5-10% of papers with code | ~60% of papers with code |
| Production deployment | Less mature tooling | TorchServe, ExecuTorch, ONNX |
JAX is a strong choice in several scenarios:

- Large-scale training on TPUs, where JAX's native support and multi-device parallelism are first-class.
- Research involving custom or higher-order differentiation: the grad transformation makes it straightforward to compute Hessians, Jacobians, and other higher-order derivatives.
- Problems whose mathematical structure fits the functional paradigm, such as physics simulation and scientific computing.

PyTorch is typically the better choice for:

- Production deployment, where tooling such as TorchServe, ExecuTorch, and ONNX export is more mature.
- Teams that depend on the much larger ecosystem of pre-trained models, tutorials, and community support.
- Workflows that benefit from eager execution and standard Python debuggers.
Pallas is JAX's framework for writing custom kernels that run on TPUs and GPUs. It provides a high-level interface for defining kernels using blocked (tiled) computation patterns, similar in spirit to Triton for NVIDIA GPUs. As of 2025, Pallas uses XLA instead of Triton's Python APIs to compile GPU kernels, providing a unified kernel authoring experience across hardware platforms [4].
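A minimal Pallas sketch, assuming the jax.experimental.pallas module: an element-wise addition kernel. Kernels write output through a mutable ref, and pallas_call wires up the memory; interpret=True runs the kernel in a portable interpreter so the sketch works without a TPU or GPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs behave like mutable array views inside the kernel body
    o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.arange(8.0)
y = jnp.ones(8)

# interpret=True executes the kernel in Python for portability; on
# TPU/GPU you would drop it to get a real compiled kernel.
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,
)(x, y)

print(out)  # [1. 2. 3. 4. 5. 6. 7. 8.]
```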
JAX now includes the Shardy partitioner as the default for automatic sharding of computations across devices. Shardy determines how to distribute tensors and computations across a mesh of devices, replacing earlier partitioning systems. Users can provide sharding annotations (via NamedSharding and abstract mesh specifications) to guide the partitioner, and JAX's export system now supports serializing these sharding specifications [4].
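A hedged sketch of sharding annotations with NamedSharding; the mesh shape here is the trivial one-axis case, which degenerates to a single device on a CPU-only host:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a one-axis mesh over the available devices.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Annotate that the array's leading axis should be split along the
# mesh's "data" axis; the partitioner propagates this annotation
# through the program's computations.
sharding = NamedSharding(mesh, P("data"))

n = len(devices)
x = jax.device_put(jnp.arange(float(4 * n)), sharding)

print(x.shape)  # (4 * n,)
print(type(x.sharding).__name__)
```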
JAX added jax.thread_guard, a context manager that detects when devices are used by multiple threads in multi-controller JAX setups. This helps catch concurrency bugs in complex distributed training configurations [4].
As of early 2026, JAX continues to mature as a framework for high-performance numerical computing and large-scale model training.
Version progress. Different distributions pin different JAX versions: NVIDIA's container releases ship JAX 0.7.x, while the JAX AI Stack pins the core library at 0.4.x. The minimum supported Python version is 3.10, and the framework requires CUDA 12.1 or newer (CUDA 11.8 support has been dropped) [4].
Google I/O 2025. JAX was featured prominently at Google I/O 2025, with technical sessions demonstrating its use across Google's AI product stack. Google continues to position JAX as its primary framework for AI research and production model training [8].
Growing ecosystem. The JAX ecosystem continues to expand. Google DeepMind released Simply, a minimal and scalable research codebase in JAX designed for rapid iteration on frontier LLM research by both humans and AI agents. Jeo provides JAX-based training for geospatial and Earth observation models. The broader research community continues to produce JAX-based libraries for reinforcement learning, robotics, scientific simulation, and other domains [6].
Adoption trajectory. While JAX's overall adoption remains smaller than PyTorch's, its influence is disproportionate in the large-scale training space. Multiple frontier AI labs (Google DeepMind, Anthropic, xAI) rely on JAX for their most compute-intensive training runs. The framework's focus on composability and performance at scale positions it well for the continuing trend toward ever-larger models trained on ever-larger clusters [7].
Challenges. JAX's primary challenges remain its steeper learning curve (particularly the functional programming paradigm and explicit state management), its smaller community relative to PyTorch, and fewer available pre-trained models and tutorials. The ecosystem fragmentation between Flax and Haiku (now being consolidated toward Flax) has also been a historical friction point, though this is being resolved. Debugging JIT-compiled code remains harder than debugging PyTorch's eager execution, although improvements to error messages and tracing tools have helped.
JAX occupies a distinctive niche in the ML framework landscape. It may never match PyTorch's breadth of adoption, but its principled design, powerful compilation stack, and suitability for large-scale training ensure that it will remain a critical piece of the AI infrastructure for years to come.