JAX
Last reviewed
May 24, 2026
Sources
No citations yet
Review status
Needs citations
Revision
v3 · 6,567 words
JAX is a high-performance numerical computing library developed by Google that combines a NumPy-compatible interface with composable function transformations for automatic differentiation, vectorization, parallelization, and just-in-time compilation. Built on top of XLA (Accelerated Linear Algebra), JAX enables Python code to run efficiently on CPUs, GPUs, and Google Tensor Processing Units. Originally created by James Bradbury, Roy Frostig, Matthew Johnson, and colleagues at Google Brain, a nascent version of JAX was described at SysML 2018 and the full open-source release came in December of that year[1][2]. By 2026 JAX had become the framework of choice for training many of the world's largest AI models, including Google's PaLM, Gemini, Gemma, Imagen, and Veo families, as well as Anthropic's Claude models and xAI's Grok systems[3][4][5].
JAX's design draws directly from Autograd, a Python library for automatic differentiation of native Python and NumPy code created by Dougal Maclaurin, David Duvenaud, Matt Johnson, and Ryan Adams within the Harvard Intelligent Probabilistic Systems Group (HIPS). Maclaurin et al. first published "Autograd: Effortless gradients in numpy" at the ICML 2015 AutoML Workshop, demonstrating that it was possible to differentiate through arbitrary Python control flow, including loops, branches, recursion, and closures, while supporting both reverse-mode and forward-mode differentiation composed to arbitrary order[6][7]. Several of Autograd's principal authors, including Maclaurin and Johnson, later moved to Google and became core JAX developers, and JAX is often described as Autograd's successor[7].
The second intellectual lineage is XLA, a domain-specific compiler originally developed for TensorFlow that optimizes linear algebra computations for various hardware backends. XLA performs whole-program optimization, including operator fusion, memory layout optimization, and hardware-specific code generation for GPUs and TPUs. By targeting XLA as its compilation backend, JAX inherits powerful optimization capabilities without requiring users to write hardware-specific code[1][8].
The name "JAX" is sometimes expanded as "Just After eXecution," though the official project describes it simply as "composable transformations of Python+NumPy programs"[2].
The earliest published description of JAX is a short paper by Roy Frostig, Matthew Johnson, and Chris Leary, "Compiling machine learning programs via high-level tracing," presented at the inaugural SysML conference (later renamed MLSys) in February 2018. The paper described JAX as "a domain-specific tracing JIT compiler for generating high-performance accelerator code from pure Python and Numpy machine learning programs"[9]. The key insight was that tracing in JAX is high-level: it is implemented as user-level code within the source language rather than as part of the source language's implementation, with the trace primitives being library-level numerical functions on array-level data, like matrix multiplications, convolutions, and reductions[9].
The full open-source release came in December 2018. The core team grew to include James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang, with broader contributions from across Google and the open-source community[2]. Following the April 2023 merger of Google Brain and DeepMind into Google DeepMind, JAX development continued with backing from the combined organization.
In September 2024 the project moved its GitHub repository from google/jax to jax-ml/jax. The maintainers explained that the move was infrastructural rather than governance-related: the jax-ml organization, which had previously hosted related projects such as ml_dtypes and jax-triton, offered better continuous-integration tooling, org-wide self-hosted runners, broader hardware test coverage including more TPU and GPU configurations, and improved security features. Existing forks and URLs were redirected automatically, and the maintainers noted that the move did not reflect any change in how JAX was developed[10].
By May 2026 the canonical JAX repository at github.com/jax-ml/jax recorded roughly 35,700 stars and 3,600 forks, and the package was averaging on the order of 17.8 million PyPI downloads per month[2][11].
| Year | Milestone |
|---|---|
| 2015 | Autograd published at ICML AutoML Workshop[6] |
| Feb 2018 | "Compiling machine learning programs via high-level tracing" presented at SysML[9] |
| Dec 2018 | Initial open-source release of JAX[2] |
| Dec 2020 | DeepMind blog post announces broad internal adoption and the Haiku/Optax/RLax/Chex/Jraph ecosystem[4] |
| Apr 2022 | PaLM (540B parameters) trained with JAX on 6,144 TPU v4 chips via Pathways[3] |
| 2023 | Gemini 1 launched, trained in JAX on TPUs[5] |
| Sep 2024 | GitHub repository moves from google/jax to jax-ml/jax[10] |
| 2024 | Flax NNX API released; Penzai and Treescope released by Google DeepMind[12][13] |
| Jul 2025 | JAX 0.7.0 ships, with the Shardy partitioner becoming the default[14] |
| 2025 | Gemini 3 trained entirely on JAX and TPUs, per Google DeepMind chief scientist Jeff Dean[5] |
| Oct 2025 | Anthropic announces expanded Google Cloud TPU usage of up to one million chips[15] |
| Apr 2026 | JAX 0.10.0 removes the legacy C++ pmap infrastructure[14] |
| May 2026 | JAX 0.10.1 ships with new linear algebra constructors and reorganized RNG dtypes[14] |
JAX's central design principle is that powerful capabilities can be expressed as composable transformations of pure Python functions. Rather than building these capabilities into a framework's runtime (as PyTorch does with autograd) or requiring a special graph language (as early TensorFlow did), JAX provides a small set of function transformations that can be freely combined[1][9].
The grad transformation takes a Python function and returns a new function that computes its gradient. JAX supports both forward-mode and reverse-mode automatic differentiation, as well as higher-order derivatives. Crucially, grad is itself composable: developers can take the gradient of a gradient to compute Hessians, or combine grad with other transformations to express Jacobians, Hessian-vector products, and other higher-order quantities[1][16].
```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    predictions = jnp.dot(x, params)
    return jnp.mean((predictions - y) ** 2)

# grad returns a function computing the gradient w.r.t. the first argument
grad_fn = jax.grad(loss_fn)

# Hessian via composed differentiation (reverse-mode, then forward-mode)
hessian_fn = jax.jacfwd(jax.jacrev(loss_fn))
```
JAX's automatic differentiation works by tracing the function into a representation of the computation, called a jaxpr, and applying a differentiation rule to each primitive operation recorded in it. Because it operates on pure functions (functions without side effects), the differentiation is mathematically well-defined and reliable[16][17]. In December 2023 a research paper by Min Lin extended this to automatic functional differentiation, allowing JAX to differentiate higher-order functions (functionals and operators) in a manner reminiscent of the calculus of variations[18].
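As a small illustration of the tracing step, the sketch below prints the jaxpr of a toy function and checks its first and second derivatives against hand-derived formulas. The function f and the evaluation point are made up for this example and do not come from the JAX documentation.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function chosen for illustration: f(x) = x^2 * sin(x)
    return x ** 2 * jnp.sin(x)

# make_jaxpr traces f with an abstract value and returns the recorded primitives.
print(jax.make_jaxpr(f)(1.0))

df = jax.grad(f)              # first derivative
d2f = jax.grad(jax.grad(f))   # second derivative, by composing grad with itself

x = 1.5
# Hand-derived derivatives of x^2 * sin(x), used only as a cross-check.
expected_df = 2 * x * jnp.sin(x) + x ** 2 * jnp.cos(x)
expected_d2f = 2 * jnp.sin(x) + 4 * x * jnp.cos(x) - x ** 2 * jnp.sin(x)
print(df(x), expected_df)
print(d2f(x), expected_d2f)
```

The printed jaxpr lists primitives such as the power, sine, and multiplication operations; those primitives are the units to which differentiation rules are applied.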
The jit transformation compiles a Python function using XLA, producing optimized machine code for the target hardware. JIT compilation traces the function with abstract values (called tracers) to determine the computation's structure, then hands the resulting program to XLA for optimization and code generation[1][9].
Internally the pipeline proceeds in several stages. JAX first traces the Python function to produce a jaxpr, an intermediate representation that captures the sequence of primitive operations. The jaxpr is then lowered into the StableHLO MLIR dialect, the input language to the OpenXLA compiler. XLA in turn performs hardware-independent optimizations on HLO, targets a specific backend, and generates low-level machine code, using LLVM for CPUs and GPUs and a dedicated compiler for TPUs[8][19].
XLA applies a range of optimizations, including operator fusion, memory layout optimization, and hardware-specific code generation.
The compilation overhead is incurred only on the first call (and again whenever input shapes or dtypes change), with subsequent calls executing the compiled code directly[1].
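The caching behavior can be observed with a deliberately impure counter, as in the following sketch. The function normalize and the global trace_count are hypothetical names used only for this demonstration; relying on side effects like this is not recommended in real code.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented by a Python side effect, purely for demonstration

@jax.jit
def normalize(x):
    global trace_count
    trace_count += 1  # runs only while JAX traces the function, not on cached calls
    return (x - x.mean()) / x.std()

a = normalize(jnp.arange(4.0))          # first call: trace and compile
b = normalize(jnp.arange(4.0) + 10.0)   # same shape and dtype: reuses compiled code
c = normalize(jnp.arange(8.0))          # new shape: traced and compiled again
print(trace_count)
```

The counter advances once per distinct input signature rather than once per call, which is why Python-level side effects inside a jitted function are unreliable.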
The vmap (vectorized map) transformation takes a function that operates on single examples and automatically vectorizes it to operate on batches. It works by pushing the loop down onto the function's primitive operations rather than executing it in Python, yielding the same machine code as a hand-written batched implementation[2]. This eliminates a common source of bugs and complexity in deep learning code, and it is particularly useful for per-sample gradient computations needed by techniques such as differentially private training and influence functions[20].
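A minimal sketch of the per-sample gradient use case follows. The loss function, shapes, and variable names are assumptions made for illustration; the only JAX calls used are jax.vmap and jax.grad.

```python
import jax
import jax.numpy as jnp

def example_loss(w, x, y):
    # Loss for ONE example: x has shape (3,), y is a scalar.
    return (jnp.dot(x, w) - y) ** 2

# Map over the leading axis of x and y while sharing w across the batch.
per_example_grad = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))

w = jnp.array([1.0, -2.0, 0.5])
xs = jnp.arange(12.0).reshape(4, 3)   # batch of 4 examples
ys = jnp.ones(4)

grads = per_example_grad(w, xs, ys)   # one gradient row per example
print(grads.shape)
```

Each row of grads is the gradient for one example; averaging the rows recovers the gradient of the mean loss over the batch.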
The pmap (parallel map) transformation distributes computation across multiple devices, such as multiple TPU cores or GPUs. It works by replicating the function across devices and executing it in parallel, with built-in support for collective operations such as psum and pmean, which perform all-reduce sums and means for synchronizing gradients during distributed training[1].
More recently JAX has emphasized shard_map (sometimes called shmap) as a more flexible and lower-level alternative. While pmap replicates the computation across devices, shard_map allows explicit control over how data and computation are partitioned: developers write a function that handles a single shard of data, and shard_map constructs the full multi-device function automatically. Each shard sees only the device-local part of the data, with reduced shape relative to the global array[21][22]. Starting in JAX 0.8.0 the default jax.pmap implementation was reimplemented in terms of jit and shard_map; users could temporarily retain the legacy behavior via the JAX_PMAP_SHMAP_MERGE environment variable through January 2026, and JAX 0.10.0 in April 2026 removed the legacy C++ pmap infrastructure entirely[14][23]. The earlier xmap experimental API was likewise retired in favor of shard_map[21].
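The per-shard programming model can be sketched as follows. This is an illustrative example rather than a reference implementation: the import location of shard_map has differed between JAX releases, so the fallback import is an assumption about the installed version, and the mesh simply uses whatever devices are available (a single CPU device also works).

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                          # newer releases
except ImportError:
    from jax.experimental.shard_map import shard_map   # older releases

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

def local_sum(x_shard):
    # x_shard is only this device's slice of the global array.
    partial = jnp.sum(x_shard)
    # psum adds the partial results from every device along the "data" axis.
    return jax.lax.psum(partial, axis_name="data")

global_sum = shard_map(local_sum, mesh=mesh, in_specs=P("data"), out_specs=P())

x = jnp.arange(8.0 * len(devices))   # length divisible by the device count
total = global_sum(x)
print(total)
```

The function body never sees the global array; the collective is what turns per-device partial sums into a replicated result.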
The critical insight of JAX's design is that these transformations compose freely. Developers can JIT-compile a vmapped, gradient-computing function that runs in parallel across devices, and each transformation is orthogonal to the others:
```python
# Compose: parallel, vectorized, JIT-compiled gradient computation
parallel_batched_grad = jax.pmap(jax.vmap(jax.jit(jax.grad(loss_fn))))
```
This composability arises from JAX's functional programming paradigm. Because JAX functions are expected to be pure, each transformation can reason about the function independently and compose predictably with the others[17]. The typical composition order has vmap, grad, and jit as inner transformations defining the core computation logic, with pmap or shard_map as the outermost transformation orchestrating execution across multiple devices[24].
| Transformation | Purpose | Input | Output |
|---|---|---|---|
| grad | Automatic differentiation | Function f(x) | Function f'(x) computing gradient |
| jit | XLA compilation | Function f(x) | Compiled, optimized version of f |
| vmap | Automatic vectorization | Function f(x) for single example | Batched function f(X) |
| pmap | Multi-device parallelism (legacy) | Function f(x) | Function replicated across devices |
| shard_map | Explicit per-device sharding | Function f(x) | Function with custom device sharding |
| jacfwd, jacrev | Forward/reverse Jacobian | Function f(x) | Function computing Jacobian |
| hessian | Second derivative | Scalar function f(x) | Function computing Hessian matrix |
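The Jacobian and Hessian entries in the table can be exercised with a short sketch. The functions g and h are toy examples invented here so that the results can be checked by hand.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Vector-valued toy function from R^2 to R^2.
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

def h(x):
    # Scalar-valued toy function whose Hessian is diag(6 * x).
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0])

J_fwd = jax.jacfwd(g)(x)   # forward-mode Jacobian
J_rev = jax.jacrev(g)(x)   # reverse-mode Jacobian, same values
H = jax.hessian(h)(x)      # matrix of second derivatives

print(J_fwd)
print(H)
```

Forward mode tends to suit functions with few inputs and many outputs, and reverse mode the opposite; for a square case like g both give the same matrix.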
JAX takes a deliberately opinionated approach to programming style by requiring (or strongly encouraging) functional purity. In practice, this means:
No in-place mutation. JAX arrays are immutable. Instead of x[i] = v, JAX uses x = x.at[i].set(v), which returns a new array. This requirement stems from the need for transformations like grad and jit to reason about the computation unambiguously[17].
Explicit random state. Unlike NumPy, which uses a global random state, JAX requires explicit passing of pseudo-random number generator (PRNG) keys. This makes random computations reproducible and compatible with transformations like jit and pmap, since there is no hidden state that could differ across devices or compilations[17].
Pure functions. Functions passed to JAX transformations should not read from or write to external state. All inputs should be function arguments, and all outputs should be return values. A pure function in JAX is one whose outputs depend only on its inputs and which has no side effects observable to the rest of the program[25].
PyTrees. JAX generalizes the input/output structure of functions through PyTrees, an abstraction that lets JAX handle tuples, lists, dictionaries, and other nested containers of array values uniformly. All JAX transformations can be applied to functions that accept as input and produce as output PyTrees of arrays, and jax.tree_util.tree_map allows any Python callable to be applied uniformly across the leaves of one or more PyTrees while preserving container structure[26].
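The conventions listed above can be seen together in a short sketch. The array values and the params dictionary are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Immutable update: .at[...].set(...) returns a new array and leaves x unchanged.
x = jnp.zeros(4)
y = x.at[1].set(5.0)

# Explicit randomness: the same key always yields the same numbers,
# so independent samples require splitting the key.
key = jax.random.key(0)
key_a, key_b = jax.random.split(key)
sample_a = jax.random.normal(key_a, (3,))
sample_a_again = jax.random.normal(key_a, (3,))
sample_b = jax.random.normal(key_b, (3,))

# PyTrees: tree_map applies a function to every array leaf of a nested container.
params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}, "scale": jnp.array(2.0)}
halved = jax.tree_util.tree_map(lambda leaf: leaf * 0.5, params)

print(x, y)
print(sample_a, sample_b)
print(halved)
```

Reusing key_a reproduces the same sample exactly, which is the property that makes random code behave consistently under jit and across devices.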
This functional discipline has tradeoffs. On one hand, it makes JAX programs easier to reason about, optimize, and parallelize. On the other hand, it imposes a learning curve for developers accustomed to PyTorch's imperative, mutation-heavy style. Managing explicit state (model parameters, optimizer state, random keys) requires patterns that feel unfamiliar to many practitioners[17][27].
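One common pattern for explicit state is a training step that receives the parameters and returns updated ones instead of mutating them. The sketch below uses plain gradient descent on synthetic data; the model, learning rate, and names such as train_step are assumptions for illustration, and real projects would typically use a library such as Optax for the update rule.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr):
    # State goes in as an argument and comes back as a return value.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Synthetic regression data generated from known weights.
x = jax.random.normal(jax.random.key(0), (64, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
losses = []
for _ in range(200):
    params, loss = train_step(params, x, y, 0.1)
    losses.append(float(loss))

print(losses[0], losses[-1])
```

Because train_step has no hidden state, it can be jitted, differentiated, or vmapped over several parameter sets without modification.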
JAX provides jax.numpy, a near-complete reimplementation of the NumPy API that operates on JAX arrays rather than NumPy arrays. Much NumPy code can be converted to JAX by replacing import numpy as np with import jax.numpy as jnp. The familiar NumPy functions for array creation, linear algebra, indexing, and broadcasting behave largely the same way, with the notable exception that JAX arrays are immutable, and the underlying arrays are backed by XLA and can reside on GPUs or TPUs[1].
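The following sketch runs one function body against both modules and then shows the main behavioral difference. The softmax helper and its xp parameter are hypothetical names introduced for this comparison.

```python
import numpy as np
import jax.numpy as jnp

def softmax(xp, logits):
    # xp is either the numpy or the jax.numpy module; the body is the same for both.
    shifted = logits - xp.max(logits, axis=-1, keepdims=True)
    exp = xp.exp(shifted)
    return exp / xp.sum(exp, axis=-1, keepdims=True)

logits_np = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], dtype=np.float32)
out_np = softmax(np, logits_np)
out_jnp = softmax(jnp, jnp.asarray(logits_np))

# One visible difference: JAX arrays reject in-place item assignment.
try:
    out_jnp[0, 0] = 0.0
    assignment_failed = False
except TypeError:
    assignment_failed = True

print(np.asarray(out_jnp))
print(assignment_failed)
```

The numerical results agree to floating-point tolerance, while the attempted assignment fails and would need to be written with the .at[...].set(...) form instead.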
JAX also provides jax.scipy with reimplementations of many SciPy functions, extending the NumPy-compatible surface area to optimization, signal processing, linear algebra, and statistics. JAX 0.10.1 (May 2026) added new specialized matrix constructors including Hadamard, circulant, DFT, Leslie, companion, Fiedler, and Helmert matrices[14]. As of 2025, JAX is compatible with NumPy 2.x[14].
JAX's performance depends fundamentally on its compiler stack. JIT-compiled JAX programs are lowered to StableHLO, the MLIR dialect that serves as the input language of XLA and the OpenXLA compiler family. StableHLO is an operation set of approximately 100 statically shaped primitives such as addition, subtraction, matrix multiplication, and reductions, and it functions as a portability layer between ML frameworks (JAX, TensorFlow, PyTorch via PyTorch/XLA) and ML compilers (XLA, IREE)[19][28].
The OpenXLA project, which became the umbrella for XLA after Google open-sourced its compiler infrastructure, formalizes StableHLO with backward and forward compatibility guarantees: an exported StableHLO module remains consumable by compilers up to six months newer than the version of JAX that produced it[28][29]. From a JAX user's perspective the pipeline proceeds as follows: a Python function is traced into a jaxpr; the jaxpr is lowered into a StableHLO MLIR module; the module is handed to XLA, which lowers it further to HLO, optimizes the graph, and finally generates target-specific machine code (LLVM for CPUs and GPUs, dedicated codegen for TPUs)[19].
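The stages of this pipeline can be inspected from Python using JAX's ahead-of-time lowering methods. The sketch below is illustrative; the function f is made up, and the exact text of the printed module varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

x = jnp.arange(4.0)

jaxpr = jax.make_jaxpr(f)(x)          # stage 1: trace to a jaxpr
lowered = jax.jit(f).lower(x)         # stage 2: lower to a StableHLO MLIR module
stablehlo_text = lowered.as_text()
compiled = lowered.compile()          # stage 3: XLA compiles for the current backend
result = compiled(x)

print(jaxpr)
print(stablehlo_text)
print(result)
```

Reading the StableHLO text is often a practical way to confirm which operations a transformation actually produced before XLA optimizes them.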
JAX also offers an explicit export and serialization system. Developers can lower a JAX function ahead-of-time, serialize the resulting StableHLO MLIR module, and ship it for compilation and execution in a separate process, on different hardware, or by an entirely different runtime such as IREE. The export system supports use_shardy_partitioner for explicit sharding specifications, and as of JAX 0.7.0 (July 2025) the Shardy partitioner is the default[14][29].
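A minimal round trip through the export system might look like the sketch below. It assumes a JAX version that provides the jax.export module, and it deserializes in the same process purely for demonstration; the function f and the input shape are invented for the example.

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.tanh(x) + 1.0

# Describe the input by shape and dtype only; no concrete data is needed to export.
spec = jax.ShapeDtypeStruct((4,), jnp.float32)
exported = export.export(jax.jit(f))(spec)

blob = exported.serialize()            # bytes that could be written to disk or shipped
restored = export.deserialize(blob)    # normally done in a different process

x = jnp.linspace(-1.0, 1.0, 4)
y = restored.call(x)
print(y)
```

The serialized artifact carries the StableHLO module rather than Python code, so the consumer does not need the original function definition.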
JAX provides a layered model for distributing computation across many accelerators. The atomic concepts are the device mesh and the sharding annotation.
A Mesh is a multidimensional NumPy-style array of JAX devices in which each axis carries a name, such as 'data' or 'model'. A NamedSharding pairs a mesh with a PartitionSpec that describes, for each input dimension of an array, which mesh axes it is partitioned across. For example PartitionSpec('data', 'model') says that an array's first dimension is sharded across the data axis of the mesh and its second across the model axis. None in a PartitionSpec denotes replication along that dimension[22][30].
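The sketch below builds a two-axis mesh from whatever devices are available and places an array on it. Putting every device on the 'data' axis and giving the 'model' axis size 1 is an assumption that keeps the example runnable on a single-device machine.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all devices along 'data'; the 'model' axis has size 1 in this sketch.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))
n_data = devices.shape[0]

x = jnp.arange(8 * n_data * 4, dtype=jnp.float32).reshape(8 * n_data, 4)

# First dimension split over 'data', second over 'model'.
sharded = jax.device_put(x, NamedSharding(mesh, P("data", "model")))
# None replicates the second dimension instead of partitioning it.
rows_only = jax.device_put(x, NamedSharding(mesh, P("data", None)))

shard_shape = sharded.addressable_shards[0].data.shape
print(sharded.sharding)
print(shard_shape)
```

Each device holds a block of 8 rows regardless of how many devices exist, while the global array keeps its full logical shape.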
jax.jit accepts sharding annotations on inputs and outputs and, in concert with the Shardy partitioner introduced as the default in 2025, automatically determines how to partition the rest of the computation, inserting collectives such as all-gather and reduce-scatter where necessary. This automatic sharding system effectively replaced the earlier GSPMD partitioner[14][22]. For workloads where the user wants per-shard control, jax.shard_map allows writing the device-local computation directly[21].
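As an illustration of sharding-aware jit, the sketch below passes a sharded array to an ordinary jitted function and requests a replicated output for a reduction. The function names are hypothetical, and on a single device the partitioning is trivial, so this shows the API shape rather than a performance effect.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n = len(jax.devices())

x = jnp.arange(8.0 * n * 3).reshape(8 * n, 3)
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))

@jax.jit
def row_norms(a):
    # Written as if for one device; the compiler partitions it to match the input.
    return jnp.sqrt(jnp.sum(a ** 2, axis=1))

norms = row_norms(x_sharded)

# Requesting a replicated result makes the compiler insert the needed collective.
total = jax.jit(jnp.sum, out_shardings=NamedSharding(mesh, P()))(x_sharded)

print(norms.sharding)
print(total)
```

The function bodies contain no device logic; the input and output annotations alone determine how the work is split.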
JAX is closely integrated with Google's Pathways system, which enables orchestrating large-scale training across many TPU pods from a single Python controller. The Pathways architecture supports two-way pod-level data parallelism: a single Python client constructs a sharded dataflow program that launches JAX/XLA work on remote servers, each comprising a TPU pod. PaLM (540B parameters) was the first large-scale demonstration of this system, training over 6,144 TPU v4 chips spread across two pods and achieving 57.8% hardware FLOPs utilization, the highest reported for large language models at the time[3].
Pallas is JAX's framework for writing custom kernels that run on TPUs and GPUs. It exposes an array-oriented Python programming model similar in spirit to OpenAI's Triton for NVIDIA GPUs, while offering cross-platform portability that Triton itself does not provide[31][32].
Pallas adds three principal extensions to JAX. Reference types (Refs) give users fine-grained control over memory access patterns and layout, enabling in-place operations rather than JAX's normal immutable-array model. Pallas primitives include new operations such as pallas.load, pallas.store, pallas.program_id, and pallas.num_programs that are not present in standard JAX. And pallas_call is a higher-order function that executes kernels across a grid, analogous to pmap or jit[31][32].
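A very small kernel in the style of the Pallas quickstart is sketched below. It passes interpret=True so that the kernel runs through the Pallas interpreter without a GPU or TPU, which is an assumption made for portability; a production kernel would normally also specify a grid and block specifications, which are omitted here.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs are mutable views: read both inputs and write the sum into the output.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,   # assumption: interpreter mode, so no accelerator is required
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
out = add(x, y)
print(out)
```

Unlike ordinary JAX code, the kernel body assigns into o_ref, which is the in-place style that Refs make possible.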
The same Pallas kernel definition can be lowered to different compilation targets. On GPUs Pallas historically lowered to Triton IR, which then compiled to NVIDIA and other GPU architectures; in newer JAX releases the recommended GPU backend is Mosaic GPU, with the legacy Triton backend retained only on a best-effort basis for Ampere and later[31][32]. On TPUs Pallas lowers to Mosaic, Google's internal TPU compiler, which handles operator fusion, tiling, and software pipelining to overlap data transfers with computation[32][33].
Pallas has become a foundation for a growing kernel ecosystem. Production libraries such as ejKernel package highly optimized Pallas (TPU) and Triton (GPU) implementations of common deep-learning operations behind a unified JAX interface[31]. Pallas is also reachable from outside JAX: PyTorch/XLA exposes Pallas as a way to write custom kernels for the TPU backend[32].
JAX intentionally provides only low-level numerical primitives. Neural network abstractions, optimizers, training loops, and domain-specific utilities are provided by a broad set of libraries.
Flax is the primary neural network library for JAX, developed by Google in close collaboration with the JAX team and released in 2020. It provides modules for defining layers, managing parameters, and organizing model code. Flax has gone through several API iterations, including the Linen API and the NNX API released in 2024[12].
Flax is used in hundreds of projects both in the open-source community (including Hugging Face model implementations) and at Google, including PaLM, Imagen, Scenic, and Big Vision[12].
Haiku was developed by DeepMind as a JAX-based neural network library with a design inspired by Sonnet, DeepMind's TensorFlow library. Haiku uses hk.transform to manage parameter state implicitly within an apply function and was used extensively within DeepMind. As of July 2023 Google DeepMind recommended that new projects adopt Flax instead of Haiku, citing Flax's larger development team, broader adoption outside Alphabet, and a superset of features[4].
Optax is a gradient processing and optimization library for JAX, developed by Google DeepMind. It provides composable building blocks for constructing optimizers: gradient transformations (scaling, clipping, momentum) can be chained together to build standard optimizers like Adam, SGD, AdaGrad, or custom variants. This composable design mirrors JAX's own philosophy[4][34].
Orbax provides checkpointing and persistence utilities for JAX, designed to scale from a single device to large-scale distributed training. It supports asynchronous checkpointing (overlapping I/O with computation), multi-tier checkpointing, and a variety of storage backends. Orbax aims to unify what were previously fragmented checkpointing implementations across the JAX ecosystem[35].
Grain is a data loading library for JAX that provides deterministic data pipelines essential for reproducible large-scale training runs. It is designed to work efficiently with JAX's distributed training primitives[35].
Equinox, introduced by Patrick Kidger and Cristian Garcia at the Differentiable Programming workshop at NeurIPS 2021, takes a different design philosophy from Flax. It demonstrates that a PyTorch-like class-based approach to neural networks can be admitted in JAX without sacrificing functional programming, by representing parameterized functions as callable PyTrees and using "filtered" transformations to separate parameters from non-array attributes. Equinox is widely used in scientific computing and forms the foundation of related libraries such as Optimistix (nonlinear optimization) and Diffrax (differential equations)[36][37].
Penzai is a JAX library released by Google DeepMind in April 2024 for writing models as legible, functional PyTree data structures, with tools for visualizing, modifying, and analyzing them. It is designed for research involving model interpretation, ablating components, probing internal activations, and performing model surgery on pretrained networks[13][38].
Treescope, originally part of Penzai and later spun off as a standalone package, is a drop-in replacement for the ordinary IPython/Colab pretty-printer designed to visualize deeply-nested JAX PyTrees and arbitrary-dimensional NDArrays. Treescope supports rendering Equinox, Flax NNX, and PyTorch models, making it usable beyond the strict JAX ecosystem[13][38].
A December 2020 DeepMind blog post laid out the lab's broader JAX ecosystem. Beyond Haiku and Optax it includes RLax, a library of reinforcement learning building blocks; Chex, a collection of testing and assertion utilities used by library authors to verify JAX code; and Jraph, a lightweight library for graph neural networks with utilities for graph data structures and a "zoo" of reference models[4][39]. The libraries follow an "incremental buy-in" philosophy, where each can be adopted independently of the others[4].
Several higher-level frameworks target large language model training on top of JAX. MaxText, maintained under the AI-Hypercomputer GitHub organization, is a high-performance, scalable open-source LLM library written in pure Python and JAX, targeting Google Cloud TPUs and GPUs and providing reference implementations of Gemma, Llama, DeepSeek, Qwen, and Mistral with support for pre-training, supervised fine-tuning, and reinforcement learning post-training[40]. T5X is an earlier modular framework for high-performance training, evaluation, and inference of sequence models, while Pax is a JAX-based framework Google has used for very large-scale model training[40][41]. In November 2025 Google DeepMind released Simply, a minimal and scalable research codebase in JAX designed for both human and AI agent iteration on frontier LLM research, supporting Gemma, Qwen, and DeepSeek families with multi-host distributed training[42].
JAX's automatic differentiation through XLA-compiled code has made it especially attractive for scientific computing and reinforcement learning. JAX MD, introduced by Schoenholz, Cubuk, and colleagues at NeurIPS 2020, is an end-to-end differentiable molecular dynamics package that can be JIT-compiled to CPU, GPU, or TPU; it allows entire simulation trajectories to be differentiated for meta-optimization[43]. Brax is a fully differentiable rigid-body physics engine written in JAX that simulates environments at millions of physics steps per second on TPU and includes baseline reinforcement learning algorithms such as PPO, SAC, ARS, and evolutionary strategies[44]. MuJoCo XLA (MJX) is a JAX reimplementation of the MuJoCo physics engine that runs on NVIDIA and AMD GPUs, Apple Silicon, and Google Cloud TPUs[45]. JAX-Privacy, developed by Google DeepMind, is a library for differentially private machine learning that has been used to train VaultGemma, a notable differentially private LLM[46].
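The idea of differentiating through a whole simulation can be shown with a toy integrator. The sketch below does not use JAX MD, Brax, or MJX; it is a hand-written harmonic oscillator stepped with jax.lax.scan, and the step size, spring constant, and function names are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def final_position(v0, k=1.0, dt=0.01, steps=100):
    # Simulate x'' = -k * x from x = 0 with initial velocity v0.
    def step(state, _):
        x, v = state
        v = v - k * x * dt   # semi-implicit Euler: update velocity first
        x = x + v * dt       # then position, using the new velocity
        return (x, v), None

    init = (jnp.zeros((), jnp.float32), jnp.asarray(v0, jnp.float32))
    (x_final, _), _ = jax.lax.scan(step, init, None, length=steps)
    return x_final

# Gradient of the end state with respect to the initial velocity,
# obtained by differentiating through all 100 integration steps.
sensitivity = jax.grad(final_position)(0.5)
print(sensitivity)
```

For this system the exact sensitivity at time t = 1 is sin(1), so the result can be compared against a known value; in realistic simulators no closed form exists and the traced gradient is the practical route.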
Google maintains the JAX AI Stack, a curated collection of these libraries with pinned versions verified to work together through integration tests. The stack includes Flax for model authoring, Optax for optimization, Grain for data loading, and Orbax for checkpointing[35].
| Library | Purpose | Maintainer |
|---|---|---|
| Flax | Neural network modules and model definition | Google |
| Optax | Gradient processing and optimization | Google DeepMind |
| Orbax | Checkpointing and persistence | Google |
| Grain | Deterministic data loading | Google |
| Haiku | Neural network library (legacy; use Flax for new projects) | Google DeepMind |
| Pallas | Custom kernel authoring for TPUs and GPUs | Google |
| RLax | Reinforcement learning building blocks | Google DeepMind |
| Jraph | Graph neural networks | Google DeepMind |
| Chex | Testing utilities | Google DeepMind |
| Equinox | Class-based NN library via callable PyTrees | Patrick Kidger and contributors |
| Penzai | Model inspection, editing, and visualization | Google DeepMind |
| Treescope | Pretty-printing for JAX PyTrees and NDArrays | Google DeepMind |
| JAX-Privacy | Differentially private ML | Google DeepMind |
| Brax | Differentiable physics for RL | Google |
| MJX | JAX reimplementation of MuJoCo physics | Google DeepMind |
| MaxText | LLM training reference implementation | Google (AI-Hypercomputer) |
| Simply | Minimal LLM research codebase | Google DeepMind |
JAX's combination of XLA compilation, functional purity, and multi-device parallelism has made it the framework of choice for training many of the world's largest AI models.
Google DeepMind uses JAX to train its most advanced models. PaLM, a 540-billion parameter large language model announced in April 2022, was trained using JAX on 6,144 Google TPU v4 chips spread across two pods, achieving 57.8% hardware FLOPs utilization through the Pathways system[3]. The Gemini family of multimodal models, which succeeded PaLM and powers Google's AI products, is also trained with JAX. Gemini Ultra (the original Gemini 1) used TPUv4 super pods of 4,096 chips each, with training distributed across multiple super pods in different data centers[47]. In December 2025 Google's chief scientist Jeff Dean confirmed that Gemini 3 was trained entirely using JAX on Google's TPUs[5]. Other production JAX-trained Google systems include the Gemma family of open-weight models, Imagen text-to-image models, and the Veo family of video generation models[5][12].
Several major AI companies outside Google, among them Anthropic and xAI, have adopted JAX for their model training infrastructure[15][49].
In the research community JAX is popular for work requiring custom differentiation (such as physics simulations and scientific computing), meta-learning, and any setting where the functional programming model aligns well with the mathematical structure of the problem[4][36][43].
JAX and PyTorch represent fundamentally different philosophies for building ML frameworks, though both are capable of similar tasks.
| Aspect | JAX | PyTorch |
|---|---|---|
| Programming paradigm | Functional (pure functions, immutable arrays) | Imperative (mutation, object-oriented) |
| Execution model | Trace-and-compile (via XLA) | Eager by default; optional compilation (torch.compile) |
| Automatic differentiation | Function transformation (grad) | Tape-based (autograd) |
| Vectorization | Built-in (vmap) | Manual batching or torch.vmap (added later) |
| Multi-device parallelism | Built-in (shard_map, jit with sharding) | DistributedDataParallel, FSDP |
| Random number generation | Explicit PRNG keys | Global state (torch.manual_seed) |
| TPU support | Native, first-class | Via PyTorch/XLA (experimental) |
| GPU support | Via XLA (CUDA, ROCm) | Native CUDA, ROCm |
| Neural network library | Separate (Flax NNX, Equinox, Haiku) | Built-in (torch.nn) |
| Debugging | Harder (JIT compilation obscures errors; jax.debug.print required for traced code) | Easier (eager execution, standard debuggers) |
| Community size | Smaller but growing (~35.7k GitHub stars in May 2026) | Much larger (~80k+ GitHub stars) |
| Research paper adoption | ~5-10% of papers with code | ~75% of NeurIPS 2024 papers used PyTorch[51] |
| Production deployment | Less mature tooling; export via StableHLO/IREE | TorchServe, ExecuTorch, ONNX export |
PyTorch 2.x's torch.compile has narrowed the eager-versus-compiled gap on GPUs, while JAX continues to lead on TPUs and on workloads that benefit from XLA's whole-program optimization and vmap-style vectorization[51][52].
JAX's composable grad transformation makes it straightforward to compute Hessians, Jacobians, and other higher-order derivatives[1][18].
JAX's strengths are accompanied by well-documented drawbacks.
Learning curve. The functional programming model, the requirement for pure functions, the explicit PRNG key API, and the discipline of immutable arrays all impose a substantial learning curve for developers coming from PyTorch or TensorFlow eager execution. Practitioners commonly report several months of adjustment before becoming productive[27][53].
Debugging traced code. When a function is transformed with jax.jit, the Python code is executed with abstract tracers in place of concrete arrays, and a Python print statement will simply print the tracer rather than a runtime value. JAX provides jax.debug.print and jax.debug.breakpoint, the jax_debug_nans flag for catching NaNs, and jax_disable_jit for falling back to eager execution, but debugging is more involved than in PyTorch eager mode. A common source of confusion is the TracerBoolConversionError raised when a Python if statement is given a dynamic tracer[54].
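The tracer pitfall and two of the debugging tools can be reproduced in a few lines. The functions bad_relu and relu are hypothetical examples written for this sketch.

```python
import jax
import jax.numpy as jnp

@jax.jit
def bad_relu(x):
    if x > 0:                  # a Python `if` needs a concrete bool, but x is a tracer
        return x
    return jnp.zeros_like(x)

try:
    bad_relu(jnp.float32(1.0))
    error_name = None
except jax.errors.TracerBoolConversionError as err:
    error_name = type(err).__name__

@jax.jit
def relu(x):
    jax.debug.print("x = {x}", x=x)    # prints the runtime value, even under jit
    return jnp.where(x > 0, x, 0.0)    # the branch expressed as an array operation

print(error_name)
print(relu(jnp.float32(-2.0)), relu(jnp.float32(3.0)))
```

Replacing the Python branch with jnp.where (or jax.lax.cond for more expensive branches) keeps the decision inside the traced computation, which is what allows it to compile.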
Smaller ecosystem. Although the JAX ecosystem is growing, it remains substantially smaller than PyTorch's. The Hugging Face Transformers library, the largest single repository of pre-trained models, has historically been PyTorch-first, and although Flax checkpoints exist for some models, the breadth and pace of new model support lags PyTorch[51].
Production deployment tooling. While the StableHLO export path enables deployment via IREE and other compilers, JAX has nothing equivalent to the maturity of TorchServe, ExecuTorch, or the established ONNX-based deployment paths of PyTorch[28][52].
Library fragmentation. Historically, the JAX neural network space split between Flax, Haiku, Equinox, Stax, and others. Google DeepMind's 2023 recommendation that new projects adopt Flax over Haiku has helped, but several active libraries with different philosophies coexist, which can confuse newcomers[4][36].
In 2024 Apple shipped a Metal plug-in that uses PJRT (the device-portability layer introduced under OpenXLA) to enable JAX execution on Mac GPUs. The plug-in lowers JAX programs to StableHLO and converts them to Metal Performance Shaders Graph executables. On an M2 Max MacBook Pro, training common networks in JAX achieved up to 28x speedups, with an average of 10x over a CPU baseline. The plug-in is officially experimental and does not support every JAX feature[50]. Community projects such as jax-mps and applejax provide alternative MLX-based or MPS-based backends, with reported speedups around 3-4x for ResNet18 training on CIFAR-10 on an M4 MacBook Air[50].
JAX continued to mature through 2025 and 2026 as the canonical framework for large-scale model training at Google and several major external labs.
Version progress. JAX shipped 0.7.0 in July 2025 (making the Shardy partitioner the default), 0.8.0 in late 2025 (reimplementing pmap on top of jit and shard_map), 0.9.0 in January 2026 (introducing jax.thread_guard and new export sharding serialization), and 0.10.0 in April 2026 (removing the legacy C++ pmap infrastructure and adding ResizeMethod.CUBIC_PYTORCH and richer LAPACK batch parallelism). JAX 0.10.1 (May 20, 2026) added new linear algebra constructors (Hadamard, circulant, DFT, Leslie, companion, Fiedler, Helmert) and reorganized the random number generation API[14].
Hardware support. The Cloud TPU lineup that JAX targets natively expanded to include TPU v5e and v5p (2023), Trillium / TPU v6e (2024), and Ironwood / TPU7x (2025) with 4,614 FP8 TFLOPS per chip, 192 GB of HBM3E, 7.37 TB/s memory bandwidth, and 9.6 Tb/s ICI bandwidth. JAX is supported as a first-class framework on every TPU generation[55].
Google I/O 2025. JAX was featured prominently at Google I/O 2025, with technical sessions demonstrating its use across Google's AI product stack[56].
Gemini 3 confirmation. In late 2025 Google DeepMind's chief scientist Jeff Dean publicly confirmed that Gemini 3 was trained entirely on JAX and TPUs, reinforcing JAX's position as Google's primary frontier training framework[5].
Anthropic expansion. Anthropic's October 2025 announcement that it would expand to up to one million TPU chips for Claude training (with a gigawatt of capacity coming online in 2026) further entrenched JAX as the framework underpinning the largest training runs outside Google[15].
Pallas ecosystem. The Pallas kernel ecosystem matured significantly. Mosaic GPU became the recommended GPU backend for Pallas, while Mosaic remained the TPU backend. Production kernel libraries such as ejKernel packaged optimized Pallas (TPU) and Triton (GPU) implementations of attention, normalization, and other common operations behind a unified JAX API[31][32].
Adoption trajectory. JAX's overall adoption remained smaller than PyTorch's, but its influence in the large-scale training space continued to grow. Multiple frontier AI labs (Google DeepMind, Anthropic, xAI) rely on JAX for their most compute-intensive training runs, and the framework's focus on composability and performance at scale positioned it well for the continuing trend toward ever-larger models trained on ever-larger clusters[5][49]. PyTorch still dominated NeurIPS-style research-paper adoption (roughly three-quarters of papers with code), but JAX retained a disproportionately strong presence in scientific computing, reinforcement learning research, and frontier LLM pretraining[51].