# JAX

> Source: https://aiwiki.ai/wiki/jax
> Updated: 2026-06-20
> Categories: Deep Learning, Machine Learning
> From AI Wiki (https://aiwiki.ai), a free encyclopedia of artificial intelligence. Quote with attribution.

JAX is a high-performance numerical computing library developed by Google that combines a [NumPy](/wiki/numpy)-compatible interface with composable function transformations for automatic differentiation, vectorization, parallelization, and just-in-time compilation. Built on top of [XLA](/wiki/xla) (Accelerated Linear Algebra), JAX enables Python code to run efficiently on CPUs, GPUs, and Google [Tensor Processing Units](/wiki/tpu). JAX was originally created by James Bradbury, Roy Frostig, Matthew Johnson, and colleagues at [Google Brain](/wiki/google_brain); a nascent version was described at SysML 2018, and the full open-source release came in December of that year[^1][^2]. By 2026 JAX had become the framework of choice for training many of the world's largest AI models, including Google's [PaLM](/wiki/palm), [Gemini](/wiki/gemini), [Gemma](/wiki/gemma), [Imagen](/wiki/imagen), and [Veo](/wiki/veo) families, as well as [Anthropic](/wiki/anthropic)'s [Claude](/wiki/claude) models and [xAI](/wiki/xai)'s [Grok](/wiki/grok) systems[^3][^4][^5]. JAX is free and open source under the Apache 2.0 license, is hosted at github.com/jax-ml/jax, where it had roughly 35,800 stars by mid-2026, and is one of the most downloaded scientific Python packages, averaging about 17.8 million PyPI downloads per month[^2][^11].

## What is JAX used for?

JAX is used primarily for two things: training and running large-scale machine learning models on accelerators, and high-performance scientific and numerical computing that benefits from automatic differentiation. The DeepMind team summarized its appeal in 2020, writing that "JAX resonates well with our engineering philosophy and has been widely adopted by our research community over the last year" and that an increasing number of projects were "well served by JAX"[^4]. In practice JAX underpins frontier [large language model](/wiki/large_language_model) pretraining (Gemini, Claude, Grok), open-weight model families (Gemma), text-to-image and video generation (Imagen, Veo), and a large body of reinforcement learning and differentiable-physics research where the functional, composable programming model matches the mathematical structure of the problem[^4][^5][^43].

## History and origins

### Intellectual roots in Autograd

JAX's design draws directly from Autograd, a Python library for automatic differentiation of native Python and NumPy code created by Dougal Maclaurin, David Duvenaud, Matt Johnson, and Ryan Adams within the Harvard Intelligent Probabilistic Systems Group (HIPS). Maclaurin et al. first published "Autograd: Effortless gradients in numpy" at the ICML 2015 AutoML Workshop, demonstrating that it was possible to differentiate through arbitrary Python control flow, including loops, branches, recursion, and closures, while supporting both reverse-mode and forward-mode differentiation composed to arbitrary order[^6][^7]. Several of Autograd's principal authors, including Maclaurin and Johnson, later moved to Google and became core JAX developers, and JAX is often described as Autograd's successor[^7].

The second intellectual lineage is XLA, a domain-specific compiler originally developed for [TensorFlow](/wiki/tensorflow) that optimizes linear algebra computations for various hardware backends. XLA performs whole-program optimization, including operator fusion, memory layout optimization, and hardware-specific code generation for GPUs and TPUs. By targeting XLA as its compilation backend, JAX inherits powerful optimization capabilities without requiring users to write hardware-specific code[^1][^8].

The name "JAX" is sometimes expanded as "Just After eXecution," though the official project describes it simply as "composable transformations of Python+NumPy programs"[^2].

### The 2018 SysML paper

The earliest published description of JAX is a short paper by Roy Frostig, Matthew Johnson, and Chris Leary, "Compiling machine learning programs via high-level tracing," presented at the inaugural SysML conference (later renamed MLSys) in February 2018. The paper described JAX as "a domain-specific tracing JIT compiler for generating high-performance accelerator code from pure Python and Numpy machine learning programs"[^9]. The key insight was that tracing in JAX is high-level: it is implemented as user-level code within the source language rather than as part of the source language's implementation, with the trace primitives being library-level numerical functions on array-level data, like matrix multiplications, convolutions, and reductions[^9].

### Development at Google Brain and Google DeepMind

The full open-source release came in December 2018. The core team grew to include James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang, with broader contributions from across Google and the open-source community[^2]. Following the April 2023 merger of Google Brain and DeepMind into [Google DeepMind](/wiki/google_deepmind), JAX development continued with backing from the combined organization.

In September 2024 the project moved its GitHub repository from `google/jax` to `jax-ml/jax`. The maintainers explained that the move was infrastructural rather than governance-related: the `jax-ml` organization, which had previously hosted related projects such as `ml_dtypes` and `jax-triton`, offered better continuous-integration tooling, org-wide self-hosted runners, broader hardware test coverage including more TPU and GPU configurations, and improved security features. Existing forks and URLs were redirected automatically, and the maintainers noted that the move did not reflect any change in how JAX was developed[^10].

By May 2026 the canonical JAX repository at `github.com/jax-ml/jax` recorded roughly 35,800 stars and 3,600 forks, and the package was averaging about 17.8 million PyPI downloads per month[^2][^11].

| Year | Milestone |
|---|---|
| 2015 | Autograd published at ICML AutoML Workshop[^6] |
| Feb 2018 | "Compiling machine learning programs via high-level tracing" presented at SysML[^9] |
| Dec 2018 | Initial open-source release of JAX[^2] |
| Dec 2020 | DeepMind blog post announces broad internal adoption and the Haiku/Optax/RLax/Chex/Jraph ecosystem[^4] |
| Apr 2022 | PaLM (540B parameters) trained with JAX on 6,144 TPU v4 chips via Pathways[^3] |
| 2023 | Gemini 1 launched, trained in JAX on TPUs[^5] |
| Sep 2024 | GitHub repository moves from `google/jax` to `jax-ml/jax`[^10] |
| 2024 | Flax NNX API released; Penzai and Treescope released by Google DeepMind[^12][^13] |
| Feb 2025 | Jeff Dean states Gemini training relies heavily on the JAX software stack and TPUs[^57] |
| Jul 2025 | JAX 0.7.0 ships, with the Shardy partitioner becoming the default[^14] |
| 2025 | Gemini 3 trained entirely on JAX and TPUs, per Google DeepMind chief scientist Jeff Dean[^5] |
| Oct 2025 | Anthropic announces expanded Google Cloud TPU usage of up to one million chips[^15] |
| Apr 2026 | JAX 0.10.0 removes the legacy C++ `pmap` infrastructure[^14] |
| May 2026 | JAX 0.10.1 ships with new linear algebra constructors and reorganized RNG dtypes[^14] |

## Core design: composable function transformations

JAX's central design principle is that powerful capabilities can be expressed as composable transformations of pure Python functions. Rather than building these capabilities into a framework's runtime (as PyTorch does with autograd) or requiring a special graph language (as early TensorFlow did), JAX provides a small set of function transformations that can be freely combined[^1][^9].

### grad: automatic differentiation

The `grad` transformation takes a Python function and returns a new function that computes its gradient using reverse-mode [automatic differentiation](/wiki/automatic_differentiation). JAX also supports forward-mode differentiation (for example through `jax.jvp` and `jax.jacfwd`) as well as higher-order derivatives. Crucially, `grad` is itself composable: developers can take the gradient of a gradient to compute Hessians, or combine `grad` with other transformations to express Jacobians, Hessian-vector products, and other higher-order quantities[^1][^16].

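```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    predictions = jnp.dot(x, params)
    return jnp.mean((predictions - y) ** 2)

# grad returns a function computing the gradient w.r.t. the first argument (params)
grad_fn = jax.grad(loss_fn)

# Hessian via composed differentiation (forward-over-reverse),
# the same composition that jax.hessian uses
hessian_fn = jax.jacfwd(jax.jacrev(loss_fn))

# Example call on small synthetic data
x = jnp.ones((4, 3))
y = jnp.zeros(4)
params = jnp.zeros(3)
g = grad_fn(params, x, y)     # shape (3,)
H = hessian_fn(params, x, y)  # shape (3, 3)
```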

JAX's automatic differentiation works by tracing the function into a sequence of primitive operations, represented in an intermediate form called a jaxpr, and applying a differentiation rule for each primitive; reverse mode is obtained by linearizing the computation and then transposing it. Because it operates on pure functions (functions without side effects), the differentiation is mathematically well-defined and composes reliably with other transformations[^16][^17]. In December 2023 a research paper by Min Lin extended this approach to automatic functional differentiation, allowing JAX to differentiate higher-order functions (functionals and operators) in a manner reminiscent of the calculus of variations[^18].
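Both differentiation modes are also exposed as lower-level transformations. The minimal sketch below uses `jax.jvp` (forward mode, a Jacobian-vector product) and `jax.vjp` (reverse mode, a vector-Jacobian product) on a hypothetical scalar function `f`; for a scalar-to-scalar function both yield the same derivative that `jax.grad` returns.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example function: f(x) = x * sin(x)
    return jnp.sin(x) * x

x = jnp.array(2.0)

# Forward mode: push a unit tangent through f
y, y_dot = jax.jvp(f, (x,), (jnp.array(1.0),))

# Reverse mode: pull a unit cotangent back through f
y2, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(jnp.array(1.0))

# Both equal the analytic derivative cos(x) * x + sin(x)
print(y_dot, x_bar, jax.grad(f)(x))
```

Forward mode is typically cheaper when a function has few inputs and many outputs, reverse mode when it has many inputs and few outputs (the usual case for a scalar loss).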

### jit: just-in-time compilation

The `jit` transformation compiles a Python function using XLA, producing optimized machine code for the target hardware. JIT compilation traces the function with abstract values (called tracers) to determine the computation's structure, then hands the resulting program to XLA for optimization and code generation[^1][^9].

Internally the pipeline proceeds in several stages. JAX first traces the Python function to produce a jaxpr, an intermediate representation that captures the sequence of primitive operations. The jaxpr is then lowered into the StableHLO MLIR dialect, the input language to the [OpenXLA](https://openxla.org) compiler. XLA in turn performs hardware-independent optimizations on HLO, targets a specific backend, and generates low-level machine code, using LLVM for CPUs and GPUs and a dedicated compiler for TPUs[^8][^19].
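These stages can be inspected from Python. The following is a minimal sketch, assuming a recent JAX release: `jax.make_jaxpr` shows the traced jaxpr, `jax.jit(f).lower(x)` produces the StableHLO module, and `.compile()` hands it to XLA. The function `f` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A small example function; any jnp-based function works the same way
    return jnp.sin(x) * 2.0

x = jnp.arange(3.0)

# Stage 1: trace the Python function into a jaxpr
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# Stage 2: lower to a StableHLO MLIR module, printed as text
lowered = jax.jit(f).lower(x)
stablehlo_text = lowered.as_text(dialect="stablehlo")
print(stablehlo_text)

# Stage 3: XLA optimizes the module and emits machine code for the backend
compiled = lowered.compile()
print(compiled(x))
```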

XLA applies a range of optimizations, including:

- Operator fusion, combining multiple element-wise operations into a single kernel to reduce memory traffic
- Memory layout optimization, choosing data layouts that maximize hardware utilization
- Buffer reuse, minimizing memory allocations by reusing buffers when possible
- Hardware-specific code generation for NVIDIA GPUs, Google TPUs, AMD GPUs, or CPUs

The compilation overhead is incurred only on the first call (and again whenever input shapes or dtypes change, or a static argument takes a new value), with subsequent calls executing the cached compiled code directly[^1].
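One way to observe this caching is to record each time the Python body actually runs. The sketch below is a diagnostic trick rather than a recommended pattern: it deliberately appends to a Python list (a side effect), which happens only while JAX is tracing, so the list grows on cache misses but not on cache hits.

```python
import jax
import jax.numpy as jnp

traced_shapes = []

@jax.jit
def scaled_sum(x):
    traced_shapes.append(x.shape)  # runs only during tracing, not on cached calls
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3))  # first call: trace and compile
scaled_sum(jnp.ones(3))  # same shape and dtype: cached executable is reused
scaled_sum(jnp.ones(5))  # new shape: retrace and recompile
print(traced_shapes)     # [(3,), (5,)]
```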

### vmap: automatic vectorization

The `vmap` (vectorized map) transformation takes a function that operates on single examples and automatically vectorizes it to operate on batches. It works by pushing the loop down onto the function's primitive operations rather than executing it in Python, yielding the same machine code as a hand-written batched implementation[^2]. This eliminates a common source of bugs and complexity in [deep learning](/wiki/deep_learning) code, and it is particularly useful for per-sample gradient computations needed by techniques such as differentially private training and influence functions[^20].
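A common use is computing per-example gradients. The sketch below assumes a hypothetical single-example squared-error loss; `in_axes=(None, 0, 0)` tells `vmap` to share `params` across the batch while mapping over the leading axis of the inputs and targets.

```python
import jax
import jax.numpy as jnp

def example_loss(params, x, y):
    # Squared error for ONE example: x has shape (3,), y is a scalar
    return (jnp.dot(x, params) - y) ** 2

params = jnp.array([0.5, -1.0, 2.0])
xs = jnp.arange(12.0).reshape(4, 3)    # batch of 4 examples
ys = jnp.array([1.0, 0.0, 2.0, -1.0])

# Share params (None), map over axis 0 of xs and ys
per_example_grads = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))(params, xs, ys)
print(per_example_grads.shape)  # (4, 3)
```

The result matches a Python loop that calls `jax.grad(example_loss)` once per example, but the batching happens inside a single traced computation.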

### pmap and shard_map: parallelization

The `pmap` (parallel map) transformation distributes computation across multiple devices, such as multiple TPU cores or GPUs. It compiles the function with XLA and runs one copy per device in single-program, multiple-data (SPMD) fashion, each copy operating on its own slice of the input along the mapped axis, with built-in support for collective operations such as `psum` and `pmean` (all-reduce sums and means) for synchronizing gradients during [distributed training](/wiki/distributed_training)[^1].

More recently JAX has emphasized `shard_map` (sometimes called shmap) as a more flexible and lower-level alternative. While `pmap` replicates the computation across devices, `shard_map` allows explicit control over how data and computation are partitioned: developers write a function that handles a single shard of data, and `shard_map` constructs the full multi-device function automatically. Each shard sees only the device-local part of the data, with reduced shape relative to the global array[^21][^22]. Starting in JAX 0.8.0 the default `jax.pmap` implementation was reimplemented in terms of `jit` and `shard_map`; users could temporarily retain the legacy behavior via the `JAX_PMAP_SHMAP_MERGE` environment variable through January 2026, and JAX 0.10.0 in April 2026 removed the legacy C++ `pmap` infrastructure entirely[^14][^23]. The earlier `xmap` experimental API was likewise retired in favor of `shard_map`[^21].
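The following is a minimal `shard_map` sketch that sums a 1-D array by letting each device reduce its own shard and then all-reducing the partial sums with `psum`. It assumes the import location varies by version (`jax.shard_map` in newer releases, `jax.experimental.shard_map.shard_map` in older ones) and builds a one-axis mesh over whatever devices are present, so it also runs on a single CPU.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def block_sum(x_block):
    # x_block is the device-local shard: shape (8,) when the global shape is (8 * n,)
    partial = jnp.sum(x_block)
    # All-reduce the per-device partial sums over the 'data' mesh axis
    return jax.lax.psum(partial, axis_name="data")

global_sum = jax.jit(
    shard_map(block_sum, mesh=mesh, in_specs=P("data"), out_specs=P())
)

x = jnp.arange(8.0 * n)
print(global_sum(x))  # equals jnp.sum(x)
```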

### Composability

The critical insight of JAX's design is that these transformations compose freely. Developers can JIT-compile a vmapped, gradient-computing function that runs in parallel across devices, and each transformation is orthogonal to the others:

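```python
# Compose: JIT-compiled per-example gradients (vmap of grad, params shared via
# in_axes=None), mapped across devices with pmap (which also compiles via XLA)
parallel_batched_grad = jax.pmap(
    jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))), in_axes=(None, 0, 0)
)
```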

This composability arises from JAX's functional programming paradigm. Because JAX functions are expected to be pure, each transformation can reason about the function independently and compose predictably with the others[^17]. The typical composition order has `vmap`, `grad`, and `jit` as inner transformations defining the core computation logic, with `pmap` or `shard_map` as the outermost transformation orchestrating execution across multiple devices[^24].

| Transformation | Purpose | Input | Output |
|---|---|---|---|
| `grad` | Automatic differentiation | Function f(x) | Function f'(x) computing gradient |
| `jit` | XLA compilation | Function f(x) | Compiled, optimized version of f |
| `vmap` | Automatic vectorization | Function f(x) for single example | Batched function f(X) |
| `pmap` | Multi-device parallelism (legacy) | Function f(x) | Function replicated across devices |
| `shard_map` | Explicit per-device sharding | Function f(x) | Function with custom device sharding |
| `jacfwd`, `jacrev` | Forward/reverse Jacobian | Function f(x) | Function computing Jacobian |
| `hessian` | Second derivative | Scalar function f(x) | Function computing Hessian matrix |

## Functional programming paradigm

JAX takes a deliberately opinionated approach to programming style by requiring (or strongly encouraging) functional purity. In practice, this means:

**No in-place mutation.** JAX arrays are immutable. Instead of `x[i] = v`, JAX uses `x = x.at[i].set(v)`, which returns a new array. This requirement stems from the need for transformations like `grad` and `jit` to reason about the computation unambiguously[^17].
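A minimal sketch of the functional update syntax; the array values are arbitrary:

```python
import jax.numpy as jnp

x = jnp.zeros(5)
y = x.at[2].set(7.0)     # functional update: returns a new array
z = y.at[1:3].add(1.0)   # other update methods include add, multiply, min, max

try:
    x[2] = 7.0           # in-place item assignment is rejected
except TypeError as err:
    print("TypeError:", err)

print(x)  # [0. 0. 0. 0. 0.]  (unchanged)
print(y)  # [0. 0. 7. 0. 0.]
print(z)  # [0. 1. 8. 0. 0.]
```

Under `jit`, XLA can often turn such updates into true in-place buffer writes when the original array is no longer needed, so the functional style does not necessarily cost an extra copy.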

**Explicit random state.** Unlike NumPy, which uses a global random state, JAX requires explicit passing of pseudo-random number generator (PRNG) keys. This makes random computations reproducible and compatible with transformations like `jit` and `pmap`, since there is no hidden state that could differ across devices or compilations[^17].
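A short sketch of explicit key handling. It uses `jax.random.key`, the newer typed-key constructor (older code commonly uses `jax.random.PRNGKey`), and shows that reusing a key reproduces the same samples while splitting yields fresh ones.

```python
import jax

key = jax.random.key(0)               # explicit PRNG state, no global seed
key, subkey = jax.random.split(key)   # derive a fresh key instead of mutating state

a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))   # same key: identical samples

key, subkey2 = jax.random.split(key)
c = jax.random.normal(subkey2, (3,))  # new key: different samples
```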

**Pure functions.** Functions passed to JAX transformations should not read from or write to external state. All inputs should be function arguments, and all outputs should be return values. A pure function in JAX is one whose outputs depend only on its inputs and which has no side effects observable to the rest of the program[^25].
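The sketch below illustrates why purity matters under `jit`: a hypothetical function that reads a Python global has that value captured as a constant when it is first traced, so later changes to the global are silently ignored until something forces a retrace. Passing the value as an argument avoids the problem.

```python
import jax
import jax.numpy as jnp

scale = 2.0

@jax.jit
def impure_scale(x):
    return x * scale  # reads a Python global: captured as a constant at trace time

first = impure_scale(jnp.float32(1.0))   # 2.0
scale = 10.0
second = impure_scale(jnp.float32(1.0))  # still 2.0: the cached trace is reused

@jax.jit
def pure_scale(x, s):
    return x * s      # every input is an explicit argument

third = pure_scale(jnp.float32(1.0), 10.0)  # 10.0
```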

**PyTrees.** JAX generalizes the input/output structure of functions through PyTrees, an abstraction that lets JAX handle tuples, lists, dictionaries, and other nested containers of array values uniformly. All JAX transformations can be applied to functions that accept as input and produce as output PyTrees of arrays, and `jax.tree_util.tree_map` allows any Python callable to be applied uniformly across the leaves of one or more PyTrees while preserving container structure[^26].
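A minimal sketch of `tree_map` over nested dictionaries, using a hypothetical parameter tree and all-ones stand-ins for gradients to perform one plain SGD step leaf by leaf:

```python
import jax
import jax.numpy as jnp

# A nested PyTree of parameters: dicts of arrays
params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}, "scale": jnp.array(1.0)}

# Stand-in "gradients" with the same structure (all-ones leaves)
grads = jax.tree_util.tree_map(jnp.ones_like, params)

# One SGD step applied leaf by leaf across two PyTrees with matching structure
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

print(jax.tree_util.tree_structure(new_params) == jax.tree_util.tree_structure(params))  # True
print(len(jax.tree_util.tree_leaves(new_params)))  # 3
```

In real training code, `grads` would come from `jax.grad` applied to a loss over `params`, which returns a PyTree with the same structure.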

This functional discipline has tradeoffs. On one hand, it makes JAX programs easier to reason about, optimize, and parallelize. On the other hand, it imposes a learning curve for developers accustomed to PyTorch's imperative, mutation-heavy style. Managing explicit state (model parameters, optimizer state, random keys) requires patterns that feel unfamiliar to many practitioners[^17][^27].

## NumPy-compatible API and the scientific stack

JAX provides `jax.numpy`, a near-complete reimplementation of the NumPy API that operates on JAX arrays rather than NumPy arrays. Most NumPy code can be converted to JAX simply by replacing `import numpy as np` with `import jax.numpy as jnp`. The familiar NumPy functions for array creation, linear algebra, indexing, and broadcasting behave largely as in NumPy, with a few documented differences (arrays are immutable, and floating-point values default to 32 bits); the underlying arrays are backed by XLA and can reside on GPUs or TPUs[^1].
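The drop-in compatibility can be seen by running one function body against both modules. The hypothetical `softmax` below takes the array module as an argument; the results agree numerically, but the output dtype differs because JAX defaults to 32-bit floats unless 64-bit mode is enabled.

```python
import numpy as np
import jax.numpy as jnp

def softmax(xp, x):
    # xp is either numpy or jax.numpy; the body is identical for both
    e = xp.exp(x - xp.max(x))
    return e / xp.sum(e)

np_out = softmax(np, np.array([1.0, 2.0, 3.0]))
jnp_out = softmax(jnp, jnp.array([1.0, 2.0, 3.0]))

print(np_out.dtype)   # float64
print(jnp_out.dtype)  # float32 (JAX's default unless 64-bit mode is enabled)
```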

JAX also provides `jax.scipy` with reimplementations of many SciPy functions, extending the NumPy-compatible surface area to optimization, signal processing, linear algebra, and statistics. JAX 0.10.1 (May 2026) added new specialized matrix constructors including Hadamard, circulant, DFT, Leslie, companion, Fiedler, and Helmert matrices[^14]. As of 2025, JAX is compatible with NumPy 2.x[^14].
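Because `jax.scipy` functions are written in JAX, they compose with transformations such as `grad`. A small sketch using `jax.scipy.special.logsumexp`, whose gradient is the softmax of its input:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1.0, 2.0, 3.0])

# jax.scipy functions are ordinary JAX code, so they can be differentiated and jitted
grad_lse = jax.grad(logsumexp)(x)

# The gradient of log-sum-exp is the softmax of its input
print(jnp.allclose(grad_lse, jax.nn.softmax(x)))  # True
```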

## The compiler stack: XLA, HLO, StableHLO

JAX's performance depends fundamentally on its compiler stack. JIT-compiled JAX programs are lowered to StableHLO, the MLIR dialect that serves as the input language of XLA and the OpenXLA compiler family. StableHLO is an operation set of approximately 100 statically shaped primitives such as addition, subtraction, matrix multiplication, and reductions, and it functions as a portability layer between ML frameworks (JAX, TensorFlow, PyTorch via PyTorch/XLA) and ML compilers (XLA, IREE)[^19][^28].

The OpenXLA project, which became the umbrella for XLA after Google open-sourced its compiler infrastructure, formalizes StableHLO with backward and forward compatibility guarantees: an exported StableHLO module remains consumable by compilers up to six months newer than the version of JAX that produced it[^28][^29]. From a JAX user's perspective the pipeline proceeds as follows: a Python function is traced into a jaxpr; the jaxpr is lowered into a StableHLO MLIR module; the module is handed to XLA, which lowers it further to HLO, optimizes the graph, and finally generates target-specific machine code (LLVM for CPUs and GPUs, dedicated codegen for TPUs)[^19].

JAX also offers an explicit export and serialization system. Developers can lower a JAX function ahead of time, serialize the resulting StableHLO MLIR module, and ship it for compilation and execution in a separate process, on different hardware, or by an entirely different runtime such as IREE. Exported modules can carry explicit sharding specifications, and the choice of partitioner is controlled by the `jax_use_shardy_partitioner` configuration option; as of JAX 0.7.0 (July 2025) the Shardy partitioner is the default[^14][^29].
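A minimal sketch of the export round trip, assuming JAX 0.4.30 or later, where the `jax.export` module is public. The function `f` is hypothetical, and the serialized bytes would normally be written to disk and loaded in another process.

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sin(x) * 2.0

x = jnp.arange(4.0)

# Lower ahead of time for a given input shape and dtype; no real data is needed
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct(x.shape, x.dtype))
print(exported.mlir_module()[:200])  # beginning of the StableHLO module text

blob = exported.serialize()          # a bytearray suitable for storage
restored = export.deserialize(blob)  # e.g. in a different process
print(restored.call(x))
```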

## Sharding and large-scale parallelism

JAX provides a layered model for distributing computation across many accelerators. The atomic concepts are the device mesh and the sharding annotation.

A `Mesh` is a multidimensional NumPy-style array of JAX devices in which each axis carries a name, such as `'data'` or `'model'`. A `NamedSharding` pairs a mesh with a `PartitionSpec` that describes, for each input dimension of an array, which mesh axes it is partitioned across. For example `PartitionSpec('data', 'model')` says that an array's first dimension is sharded across the `data` axis of the mesh and its second across the `model` axis. `None` in a `PartitionSpec` denotes replication along that dimension[^22][^30].

`jax.jit` accepts sharding annotations on inputs and outputs and, in concert with the Shardy partitioner introduced as the default in 2025, automatically determines how to partition the rest of the computation, inserting collectives such as all-gather and reduce-scatter where necessary. This automatic sharding system effectively replaced the earlier GSPMD partitioner[^14][^22]. For workloads where the user wants per-shard control, `jax.shard_map` allows writing the device-local computation directly[^21].
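The sketch below shows these pieces together, under the assumption that all available devices are placed on a `'data'` axis and the `'model'` axis has size 1 (so it also runs on a single CPU). The input is sharded by row, the weight matrix is replicated, and `jax.jit` is left to work out how to partition the matrix multiply.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
# 2-D mesh with named axes: every device on 'data', 'model' has size 1
mesh = Mesh(np.array(jax.devices()).reshape(n, 1), axis_names=("data", "model"))

# Rows split across 'data', columns across 'model'
x = jax.device_put(jnp.ones((8 * n, 4)), NamedSharding(mesh, P("data", "model")))
# None in every position: the weight matrix is fully replicated
w = jax.device_put(jnp.ones((4, 2)), NamedSharding(mesh, P(None, None)))

@jax.jit
def forward(x, w):
    # No explicit collectives: the partitioner decides how to split the matmul
    return jnp.tanh(x @ w)

y = forward(x, w)
print(y.shape)     # (8 * n, 2)
print(y.sharding)  # chosen by the compiler from the input shardings
```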

JAX is closely integrated with Google's Pathways system, which enables orchestrating large-scale training across many TPU pods from a single Python controller. The Pathways architecture supports two-way pod-level data parallelism: a single Python client constructs a sharded dataflow program that launches JAX/XLA work on remote servers, each comprising a TPU pod. PaLM (540B parameters) was the first large-scale demonstration of this system, training over 6,144 TPU v4 chips spread across two pods and achieving 57.8% hardware FLOPs utilization, the highest reported for large language models at the time[^3].

## Pallas: custom kernel authoring

Pallas is JAX's framework for writing custom kernels that run on TPUs and GPUs. It exposes an array-oriented Python programming model similar in spirit to OpenAI's [Triton](/wiki/triton) for NVIDIA GPUs, while offering cross-platform portability that Triton itself does not provide[^31][^32].

Pallas adds three principal extensions to JAX. Reference types (`Ref`s) give users fine-grained control over memory access patterns and layout, enabling in-place operations rather than JAX's normal immutable-array model. Pallas primitives include new operations such as `pallas.load`, `pallas.store`, `pallas.program_id`, and `pallas.num_programs` that are not present in standard JAX. And `pallas_call` is a higher-order function that executes kernels across a grid, analogous to `pmap` or `jit`[^31][^32].
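A minimal Pallas sketch adapted from the element-wise addition pattern in the Pallas documentation. It assumes `jax.experimental.pallas` is available; `interpret=True` runs the kernel in an interpreter so it can be tried on a CPU, and would normally be removed when targeting a TPU or GPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs are mutable views of memory: read both inputs, write the output in place
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add_vectors(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode for CPU; drop on TPU/GPU
    )(x, y)

x = jnp.arange(8.0)
print(add_vectors(x, x))  # element-wise sum: 0, 2, 4, ..., 14
```

Without a `grid` argument the kernel runs once over the whole array; real kernels typically add a grid and block specifications to tile larger inputs.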

The same Pallas kernel definition can be lowered to different compilation targets. On GPUs Pallas historically lowered to Triton IR, which then compiled to NVIDIA and other GPU architectures; in newer JAX releases the recommended GPU backend is Mosaic GPU, with the legacy Triton backend retained only on a best-effort basis for Ampere and later[^31][^32]. On TPUs Pallas lowers to Mosaic, Google's internal TPU compiler, which handles operator fusion, tiling, and software pipelining to overlap data transfers with computation[^32][^33].

Pallas has become a foundation for a growing kernel ecosystem. Production libraries such as ejKernel package highly optimized Pallas (TPU) and Triton (GPU) implementations of common deep-learning operations behind a unified JAX interface[^31]. Pallas is also reachable from outside JAX: PyTorch/XLA exposes Pallas as a way to write custom kernels for the TPU backend[^32].

## Ecosystem libraries

JAX intentionally provides only low-level numerical primitives. [Neural network](/wiki/neural_network) abstractions, optimizers, training loops, and domain-specific utilities are provided by a broad set of libraries.

### Flax

Flax is the primary neural network library for JAX, developed by Google in close collaboration with the JAX team and released in 2020. It provides modules for defining layers, managing parameters, and organizing model code. Flax has gone through several API iterations[^12]:

- **Flax Linen** was the standard API through 2023, using a functional style in which modules are defined as classes but parameter initialization and forward passes are separate, explicit calls.
- **Flax NNX**, introduced in 2024, is a simplified API that supports Python reference semantics, allowing users to express models using regular Python objects with reference sharing and mutability. This makes models easier to create, inspect, debug, and analyze[^12].

Flax is used in hundreds of projects both in the open-source community (including [Hugging Face](/wiki/hugging_face) model implementations) and at Google, including PaLM, Imagen, Scenic, and Big Vision[^12].

### Haiku

Haiku was developed by DeepMind as a JAX-based neural network library with a design inspired by Sonnet, DeepMind's TensorFlow library. Haiku uses `hk.transform` to manage parameter state implicitly within an `apply` function and was used extensively within DeepMind. As of July 2023 Google DeepMind recommended that new projects adopt Flax instead of Haiku, citing Flax's larger development team, broader adoption outside Alphabet, and a superset of features[^4].

### Optax

Optax is a gradient processing and optimization library for JAX, developed by Google DeepMind. It provides composable building blocks for constructing optimizers: gradient transformations (scaling, clipping, momentum) can be chained together to build standard optimizers like Adam, SGD, AdaGrad, or custom variants. This composable design mirrors JAX's own philosophy[^4][^34].

### Orbax

Orbax provides checkpointing and persistence utilities for JAX, designed to scale from a single device to large-scale distributed training. It supports asynchronous checkpointing (overlapping I/O with computation), multi-tier checkpointing, and a variety of storage backends. Orbax aims to unify what were previously fragmented checkpointing implementations across the JAX ecosystem[^35].

### Grain

Grain is a data loading library for JAX that provides deterministic data pipelines essential for reproducible large-scale training runs. It is designed to work efficiently with JAX's distributed training primitives[^35].

### Equinox

Equinox, introduced by Patrick Kidger and Cristian Garcia at the Differentiable Programming workshop at NeurIPS 2021, takes a different design philosophy from Flax. It demonstrates that a PyTorch-like class-based approach to neural networks can be admitted in JAX without sacrificing functional programming, by representing parameterized functions as callable PyTrees and using "filtered" transformations to separate parameters from non-array attributes. Equinox is widely used in scientific computing and forms the foundation of related libraries such as Optimistix (nonlinear optimization) and Diffrax (differential equations)[^36][^37].

### Penzai and Treescope

Penzai is a JAX library released by Google DeepMind in April 2024 for writing models as legible, functional PyTree data structures, with tools for visualizing, modifying, and analyzing them. It is designed for research involving model interpretation, ablating components, probing internal activations, and performing model surgery on pretrained networks[^13][^38].

Treescope, originally part of Penzai and later spun off as a standalone package, is a drop-in replacement for the ordinary IPython/Colab pretty-printer designed to visualize deeply-nested JAX PyTrees and arbitrary-dimensional NDArrays. Treescope supports rendering Equinox, Flax NNX, and PyTorch models, making it usable beyond the strict JAX ecosystem[^13][^38].

### DeepMind's wider ecosystem

A December 2020 DeepMind blog post laid out the lab's broader JAX ecosystem. Beyond Haiku and Optax it includes RLax, a library of reinforcement learning building blocks; Chex, a collection of testing and assertion utilities used by library authors to verify JAX code; and Jraph, a lightweight library for [graph neural networks](/wiki/graph_neural_network) with utilities for graph data structures and a "zoo" of reference models[^4][^39]. The libraries follow an "incremental buy-in" philosophy, where each can be adopted independently of the others[^4].

### LLM-specific training frameworks

Several higher-level frameworks target large language model training on top of JAX. MaxText, maintained under the `AI-Hypercomputer` GitHub organization, is a high-performance, scalable open-source LLM library written in pure Python and JAX, targeting Google Cloud TPUs and GPUs and providing reference implementations of Gemma, Llama, DeepSeek, Qwen, and Mistral with support for pre-training, supervised fine-tuning, and reinforcement learning post-training[^40]. T5X is an earlier modular framework for high-performance training, evaluation, and inference of sequence models, while Pax is a JAX-based framework Google has used for very large-scale model training[^40][^41]. In November 2025 Google DeepMind released Simply, a minimal and scalable research codebase in JAX designed for both human and AI agent iteration on frontier LLM research, supporting Gemma, Qwen, and DeepSeek families with multi-host distributed training[^42].

### Scientific and reinforcement learning

JAX's automatic differentiation through XLA-compiled code has made it especially attractive for scientific computing and reinforcement learning. JAX MD, introduced by Schoenholz, Cubuk, and colleagues at NeurIPS 2020, is an end-to-end differentiable molecular dynamics package that can be JIT-compiled to CPU, GPU, or TPU; it allows entire simulation trajectories to be differentiated for meta-optimization[^43]. Brax is a fully differentiable rigid-body physics engine written in JAX that simulates environments at millions of physics steps per second on TPU and includes baseline reinforcement learning algorithms such as PPO, SAC, ARS, and evolutionary strategies[^44]. MuJoCo XLA (MJX) is a JAX reimplementation of the MuJoCo physics engine that runs on NVIDIA and AMD GPUs, Apple Silicon, and Google Cloud TPUs[^45]. JAX-Privacy, developed by Google DeepMind, is a library for differentially private machine learning that has been used to train VaultGemma, a notable differentially private LLM[^46].

### The JAX AI Stack

Google maintains the JAX AI Stack, a curated collection of these libraries with pinned versions verified to work together through integration tests. The stack includes Flax for model authoring, Optax for optimization, Grain for data loading, and Orbax for checkpointing[^35].

| Library | Purpose | Maintainer |
|---|---|---|
| Flax | Neural network modules and model definition | Google |
| Optax | Gradient processing and optimization | Google DeepMind |
| Orbax | Checkpointing and persistence | Google |
| Grain | Deterministic data loading | Google |
| Haiku | Neural network library (legacy; use Flax for new projects) | Google DeepMind |
| Pallas | Custom kernel authoring for TPUs and GPUs | Google |
| RLax | Reinforcement learning building blocks | Google DeepMind |
| Jraph | Graph neural networks | Google DeepMind |
| Chex | Testing utilities | Google DeepMind |
| Equinox | Class-based NN library via callable PyTrees | Patrick Kidger and contributors |
| Penzai | Model inspection, editing, and visualization | Google DeepMind |
| Treescope | Pretty-printing for JAX PyTrees and NDArrays | Google DeepMind |
| JAX-Privacy | Differentially private ML | Google DeepMind |
| Brax | Differentiable physics for RL | Google |
| MJX | JAX reimplementation of MuJoCo physics | Google DeepMind |
| MaxText | LLM training reference implementation | Google (AI-Hypercomputer) |
| Simply | Minimal LLM research codebase | Google DeepMind |

## Training large-scale models

JAX's combination of XLA compilation, functional purity, and multi-device parallelism has made it the framework of choice for training many of the world's largest AI models.

### Google's flagship models

Google DeepMind uses JAX to train its most advanced models. PaLM, a 540-billion-parameter [large language model](/wiki/large_language_model) announced in April 2022, was trained using JAX on 6,144 Google TPU v4 chips spread across two pods, achieving 57.8% hardware FLOPs utilization through the Pathways system[^3]. The Gemini family of multimodal models, which succeeded PaLM and powers Google's AI products, is also trained with JAX. Gemini 1.0 Ultra, the largest model in the original Gemini family, was trained on TPUv4 SuperPods of 4,096 chips each, with training distributed across multiple SuperPods in different data centers[^47]. Google DeepMind chief scientist Jeff Dean stated in February 2025 that "training our most capable Gemini models relies heavily on our JAX software stack + Google's TPU hardware platforms"[^57]. In December 2025 Dean confirmed that [Gemini 3](/wiki/gemini_3) was trained entirely using JAX on Google's TPUs[^5]. Other production JAX-trained Google systems include the Gemma family of open-weight models, Imagen text-to-image models, and the Veo family of video generation models[^5][^12].

### Adoption beyond Google

Several major AI companies have adopted JAX for their model training infrastructure:

- [Anthropic](/wiki/anthropic) uses JAX for training its Claude family of models. In October 2025 Anthropic announced an expanded multi-year deal with Google Cloud worth tens of billions of dollars, providing access to up to one million TPU chips and well over a gigawatt of capacity coming online in 2026, with Google's seventh-generation Ironwood TPUs joining Anthropic's existing mix of AWS Trainium and NVIDIA GPUs[^15][^48].
- [xAI](/wiki/xai) has built a custom training stack on JAX and Rust, with a JAX-based modeling and training layer providing composable parallelism primitives, and a Rust control plane orchestrating cluster operations. Grok 1.5 introduced a 128k-token context window using this custom JAX/Rust/Kubernetes training framework[^49].
- [Cohere](/wiki/cohere) uses JAX in its model development pipeline[^5].
- Apple has used JAX for certain [machine learning](/wiki/machine_learning) research projects, and JAX's Apple Silicon support (via the jax-metal plug-in) leverages the OpenXLA compiler and Metal Performance Shaders to accelerate training on M-series chips, with reported speedups of up to 28x over CPU on common networks[^5][^50].

In the research community JAX is popular for work requiring custom differentiation (such as physics simulations and scientific computing), meta-learning, and any setting where the functional programming model aligns well with the mathematical structure of the problem[^4][^36][^43].
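As an illustration of what "custom differentiation" means in practice, the sketch below uses JAX's `jax.custom_jvp` decorator: the function's value is computed normally, but its derivative rule is supplied by hand so that the gradient stays finite where naive automatic differentiation produces `nan`. The function names and test inputs are illustrative choices, not code from any of the projects cited above.

```python
import jax
import jax.numpy as jnp


def naive_log1pexp(x):
    # log(1 + e^x) written directly; autodiff multiplies 0 by inf for large x.
    return jnp.log1p(jnp.exp(x))


@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    # Hand-written derivative: d/dx log(1 + e^x) = 1 - 1 / (1 + e^x),
    # which stays finite even when e^x overflows to inf in float32.
    (x,) = primals
    (x_dot,) = tangents
    ans = log1pexp(x)
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot


naive_grad = jax.grad(naive_log1pexp)(100.0)  # nan
custom_grad = jax.grad(log1pexp)(100.0)       # 1.0
```

`jax.custom_vjp` is the reverse-mode counterpart, for cases where a hand-written backward pass is preferred over a forward-mode rule.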

## How does JAX differ from PyTorch?

JAX and [PyTorch](/wiki/pytorch) represent fundamentally different philosophies for building ML frameworks, though both are capable of similar tasks. The core distinction is that JAX is functional and compilation-first (pure functions transformed by `grad`, `jit`, and `vmap`, then handed to XLA), whereas PyTorch is imperative and eager-first (mutable tensors and tape-based autograd, with optional compilation through `torch.compile`).

| Aspect | JAX | PyTorch |
|---|---|---|
| Programming paradigm | Functional (pure functions, immutable arrays) | Imperative (mutation, object-oriented) |
| Execution model | Trace-and-compile (via XLA) | Eager by default; optional compilation (`torch.compile`) |
| Automatic differentiation | Function transformation (`grad`) | Tape-based (autograd) |
| Vectorization | Built-in (`vmap`) | Manual batching or `torch.vmap` (added later) |
| Multi-device parallelism | Built-in (`shard_map`, `jit` with sharding) | `DistributedDataParallel`, FSDP |
| Random number generation | Explicit PRNG keys | Global state (`torch.manual_seed`) |
| TPU support | Native, first-class | Via PyTorch/XLA (experimental) |
| GPU support | Via XLA (CUDA, ROCm) | Native [CUDA](/wiki/cuda), ROCm |
| Neural network library | Separate (Flax NNX, Equinox, Haiku) | Built-in (`torch.nn`) |
| Debugging | Harder (JIT compilation obscures errors; `jax.debug.print` required for traced code) | Easier (eager execution, standard debuggers) |
| Community size | Smaller but growing (~35.8k GitHub stars in 2026) | Much larger (~80k+ GitHub stars) |
| Research paper adoption | ~5-10% of papers with code | ~75% of NeurIPS 2024 papers used PyTorch[^51] |
| Production deployment | Less mature tooling; export via StableHLO/IREE | TorchServe, ExecuTorch, [ONNX](/wiki/onnx) export |

PyTorch 2.x's `torch.compile` has narrowed the eager-versus-compiled gap on GPUs, while JAX continues to lead on TPUs and on workloads that benefit from XLA's whole-program optimization and `vmap`-style vectorization[^51][^52].
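The functional, compilation-first style described in this section fits in a few lines. The sketch below uses a hypothetical linear-model loss and composes `grad`, `vmap`, and `jit` to compute one gradient per example in a batch. Each transformation takes a function and returns a new function, and that is why they can be nested freely.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error of a linear model on ONE example; a pure function of its inputs.
    return (jnp.dot(w, x) - y) ** 2


# grad: gradient of the loss with respect to w (the first argument)
# vmap: map it over a batch of (x, y) pairs while sharing w (in_axes=None)
# jit:  trace the composed function once and compile it with XLA
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
ys = jnp.array([0.0, 0.0])
grads = per_example_grads(w, xs, ys)  # shape (2, 2): one gradient row per example
```

Because `loss` is pure (its output depends only on its arguments), JAX can trace it once, vectorize the per-example gradient, and hand the whole composed computation to XLA as a single program.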

### When JAX is a strong choice

- Training on Google TPUs. JAX's native XLA integration provides the most direct TPU support of any framework[^5].
- Large-scale distributed training with complex sharding strategies. JAX's composable parallelism primitives are more flexible than PyTorch's distributed abstractions for non-standard sharding[^22][^49].
- Research requiring custom or higher-order differentiation. The functional `grad` transformation makes it straightforward to compute Hessians, Jacobians, and other higher-order derivatives[^1][^18].
- Scientific computing. JAX's NumPy-compatible API and automatic differentiation make it natural for physics simulations, optimization, and other scientific applications[^43][^44].
- Functional programming preference. Developers who prefer functional patterns may find JAX's design more natural and less error-prone[^17].
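The higher-order differentiation point can be sketched briefly. Because `jax.grad` returns an ordinary function, derivatives nest, and `jax.hessian` (forward-over-reverse differentiation) builds a full Hessian. The test function `f` is an arbitrary example chosen so that the results are easy to check by hand.

```python
import jax
import jax.numpy as jnp


def f(x):
    # f(x) = sum_i x_i^3, so grad f = 3 x^2 and the Hessian is diag(6 x).
    return jnp.sum(x ** 3)


x = jnp.array([1.0, 2.0])
g = jax.grad(f)(x)                    # [3., 12.]
H = jax.hessian(f)(x)                 # [[6., 0.], [0., 12.]]
H_alt = jax.jacfwd(jax.jacrev(f))(x)  # forward-over-reverse, the same matrix

# Third derivative of a scalar function by nesting grad three times.
third = jax.grad(jax.grad(jax.grad(lambda t: t ** 3)))(2.0)  # 6.0
```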

### When PyTorch is typically the better choice

- Rapid prototyping and debugging. Eager execution makes it easy to inspect intermediate values and use standard Python debuggers[^27][^52].
- Leveraging pre-trained models. The Hugging Face ecosystem hosts over 500,000 pre-trained model checkpoints, the large majority published primarily for PyTorch, and most reference implementations are in PyTorch[^51].
- Production deployment. PyTorch's deployment tooling (TorchServe, ExecuTorch, ONNX export) is more mature[^52].
- Teams with mixed experience levels. PyTorch's imperative style has a lower learning curve for developers not familiar with functional programming[^27][^52].

## Limitations and criticisms

JAX's strengths are accompanied by well-documented drawbacks.

**Learning curve.** The functional programming model, the requirement for pure functions, the explicit PRNG key API, and the discipline of immutable arrays all impose a substantial learning curve for developers coming from PyTorch or TensorFlow eager execution. Practitioners commonly report several months of adjustment before becoming productive[^27][^53].
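Two of those adjustments, explicit PRNG keys and immutable arrays, look like this in practice. This is a minimal sketch that assumes a recent JAX release for `jax.random.key`; older code uses `jax.random.PRNGKey` instead.

```python
import jax
import jax.numpy as jnp

# Randomness is an explicit input: a key is created from a seed and split,
# rather than drawn from hidden global state.
key = jax.random.key(0)
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (3,))
resample = jax.random.normal(subkey, (3,))  # same key -> identical values

# Arrays are immutable: in-place assignment raises a TypeError, and the
# functional `.at[...]` API returns an updated copy instead.
x = jnp.zeros(3)
try:
    x[1] = 5.0
    mutation_allowed = True
except TypeError:
    mutation_allowed = False
y = x.at[1].set(5.0)  # x is unchanged; y is [0., 5., 0.]
```

Reusing a key reproduces the same numbers, so code that needs fresh randomness must thread new subkeys through explicitly. This is the main habit PyTorch users have to unlearn.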

**Debugging traced code.** When a function is transformed with `jax.jit`, the Python code is executed with abstract tracers in place of concrete arrays, so a Python `print` statement prints the tracer rather than a runtime value. JAX provides `jax.debug.print` and `jax.debug.breakpoint`, the `jax_debug_nans` flag for catching NaNs, and `jax_disable_jit` for falling back to eager execution, but debugging is still more involved than in PyTorch eager mode. A common source of confusion is the `TracerBoolConversionError`, raised when Python control flow such as an `if` statement tries to convert a traced value to a concrete boolean[^54].
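The sketch below contrasts these behaviours using two hypothetical functions. The first fails under `jax.jit` because the Python `if` needs a concrete boolean. The second uses `jax.debug.print` to report runtime values and `jnp.where` in place of Python control flow. `jax.lax.cond` is the structured alternative when the branches are more than a simple elementwise select.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_with_python_if(x):
    print("python print sees:", x)  # runs at trace time and shows a tracer
    if x > 0:  # bool() of a tracer raises TracerBoolConversionError
        return x
    return jnp.zeros_like(x)


@jax.jit
def relu_traced(x):
    jax.debug.print("runtime value: {}", x)  # staged into the compiled program
    return jnp.where(x > 0, x, 0.0)  # data-dependent choice without Python control flow


try:
    relu_with_python_if(jnp.float32(2.0))
    error_name = None
except jax.errors.TracerBoolConversionError as err:
    error_name = type(err).__name__

out = relu_traced(jnp.float32(-3.0))
```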

**Smaller ecosystem.** Although the JAX ecosystem is growing, it remains substantially smaller than PyTorch's. The Hugging Face Transformers library, the largest single repository of pre-trained models, has historically been PyTorch-first, and although Flax checkpoints exist for some models, the breadth and pace of new model support lags PyTorch[^51].

**Production deployment tooling.** The StableHLO export path enables deployment via IREE and other compilers, but JAX has no serving or on-device tooling as mature as PyTorch's TorchServe, ExecuTorch, or established ONNX-based deployment paths[^28][^52].
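The first step of that export path can be sketched with the `jax.export` module, assuming a JAX release that includes it (roughly 0.4.30 and later). A jitted function is exported for a fixed input shape, its StableHLO text can be inspected, and the serialized bytes can be reloaded and called. Handing the artifact to IREE or another compiler is not shown, and the function `f` is a placeholder.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


def f(x):
    # Placeholder computation to export.
    return jnp.tanh(x) * 2.0


# Export for a fixed input shape/dtype; the result wraps a StableHLO module
# plus metadata and can be serialized to bytes.
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((3,), jnp.float32))
stablehlo_text = exported.mlir_module()
blob = exported.serialize()

# Later, possibly in another process: deserialize and call the artifact.
restored = export.deserialize(blob)
x = jnp.array([0.0, 0.5, 1.0], dtype=jnp.float32)
y = restored.call(x)
```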

**Library fragmentation.** Historically, JAX's neural network libraries were split among Flax, Haiku, Equinox, Stax, and others. Google DeepMind's 2023 recommendation that new projects adopt Flax rather than Haiku has helped, but several actively maintained libraries with different design philosophies still coexist, which can confuse newcomers[^4][^36].

## Apple Silicon and Mac support

In 2024 Apple shipped a Metal plug-in that uses PJRT (the device-portability layer introduced under OpenXLA) to run JAX on Mac GPUs. The plug-in lowers JAX programs to StableHLO and converts them to Metal Performance Shaders Graph executables. On an M2 Max MacBook Pro, training common networks in JAX ran up to 28x faster, and 10x faster on average, than on a CPU baseline. The plug-in is officially experimental and does not support every JAX feature[^50]. Community projects such as `jax-mps` and `applejax` provide alternative MLX-based or MPS-based backends, with reported speedups of around 3-4x for ResNet18 training on CIFAR-10 on an M4 MacBook Air[^50].

## Is JAX open source?

Yes. JAX is free and open source, released under the Apache License 2.0 and developed in the open at github.com/jax-ml/jax, where the repository recorded roughly 35,800 stars and 3,600 forks by mid-2026[^2][^11]. The package is distributed on PyPI as `jax` (with the platform-specific `jaxlib` providing the compiled XLA runtime) and averaged on the order of 17.8 million PyPI downloads per month, making it one of the most-downloaded scientific Python libraries[^11]. Although JAX originated and is primarily maintained at Google and Google DeepMind, it accepts external contributions and is used by independent maintainers such as Patrick Kidger (Equinox) and many open-source projects[^2][^36].

## Recent developments (2025-2026)

JAX continued to mature through 2025 and 2026 as the canonical framework for large-scale model training at Google and several major external labs.

**Version progress.** JAX shipped 0.7.0 in July 2025 (making the Shardy partitioner the default), 0.8.0 in late 2025 (reimplementing `pmap` on top of `jit` and `shard_map`), 0.9.0 in January 2026 (introducing `jax.thread_guard` and new export sharding serialization), and 0.10.0 in April 2026 (removing the legacy C++ `pmap` infrastructure and adding `ResizeMethod.CUBIC_PYTORCH` and richer LAPACK batch parallelism). JAX 0.10.1 (May 20, 2026) added new linear algebra constructors (Hadamard, circulant, DFT, Leslie, companion, Fiedler, Helmert) and reorganized the random number generation API[^14].
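The `pmap` rewrite rests on JAX's sharding-based parallelism, in which arrays carry a sharding and `jax.jit` partitions the computation to match. The following is a minimal sketch of that style. It assumes only the long-standing `jax.sharding` classes and whatever devices are visible; a single CPU device is enough to run it. `shard_map` is not shown, the partitioner is left at its default, and `column_sums` is a hypothetical example function.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh over all visible devices.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = len(devices)

# Shard the leading (batch) axis of x across the "data" mesh axis.
x = jnp.arange(4 * n * 2, dtype=jnp.float32).reshape(4 * n, 2)
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))


@jax.jit
def column_sums(a):
    # Written as single-device code; the compiler partitions it according to
    # the input sharding and inserts any needed cross-device reduction.
    return jnp.sum(a, axis=0)


result = column_sums(x_sharded)
```

The design choice here is that parallelism is expressed as data layout rather than as a separate per-device function. `shard_map` remains available for cases where explicit per-device code is clearer.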

**Hardware support.** The Cloud TPU lineup that JAX targets natively expanded to include TPU v5e and v5p (2023), Trillium / TPU v6e (2024), and Ironwood / TPU7x (2025) with 4,614 FP8 TFLOPS per chip, 192 GB of HBM3E, 7.37 TB/s memory bandwidth, and 9.6 Tb/s ICI bandwidth. JAX is supported as a first-class framework on every TPU generation[^55].

**Google I/O 2025.** JAX was the subject of a Google I/O 2025 technical session, "JAX in action", which demonstrated its use across Google's AI product stack[^56].

**Gemini 3 confirmation.** In late 2025 Google DeepMind chief scientist Jeff Dean publicly confirmed that Gemini 3 was trained entirely with JAX on TPUs. This reinforced JAX's position as Google's primary frontier training framework and followed his February 2025 statement that Google's most capable Gemini models rely heavily on its JAX software stack and TPU hardware[^5][^57].

**Anthropic expansion.** In October 2025 Anthropic announced that it would expand to up to one million TPU chips, with well over a gigawatt of capacity coming online in 2026. The announcement signaled a substantially larger role for TPU-based, and therefore JAX-based, training at one of the largest AI labs outside Google, although Anthropic continues to train on AWS Trainium and NVIDIA GPUs as well[^15][^48].

**Pallas ecosystem.** The Pallas kernel ecosystem matured significantly. Mosaic GPU became the recommended GPU backend for Pallas, while Mosaic remained the TPU backend. Production kernel libraries such as ejKernel packaged optimized Pallas (TPU) and Triton (GPU) implementations of attention, normalization, and other common operations behind a unified JAX API[^31][^32].
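A Pallas kernel is an ordinary Python function that reads and writes array references and is launched through `pl.pallas_call`. The sketch below follows the vector-addition pattern from the Pallas documentation. Setting `interpret=True` is an assumption made so the example runs on a CPU; real kernels are compiled through Mosaic on TPU or a GPU backend such as Mosaic GPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Kernel body: read the input blocks through refs and write the output block.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # run via the interpreter so the sketch also works on CPU
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
z = add(x, y)
```

With no `grid` argument the kernel runs once over the full arrays. Production kernels such as attention typically add a grid and block specifications to tile the work across the accelerator.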

**Adoption trajectory.** JAX's overall adoption remained smaller than PyTorch's, but its influence in the large-scale training space continued to grow. Multiple frontier AI labs (Google DeepMind, Anthropic, xAI) rely on JAX for their most compute-intensive training runs, and the framework's focus on composability and performance at scale positioned it well for the continuing trend toward ever-larger models trained on ever-larger clusters[^5][^49]. PyTorch still dominated NeurIPS-style research-paper adoption (roughly three-quarters of papers with code), but JAX retained a disproportionately strong presence in scientific computing, reinforcement learning research, and frontier LLM pretraining[^51].

## See also

- [XLA](/wiki/xla)
- [Tensor Processing Unit](/wiki/tpu)
- [TensorFlow](/wiki/tensorflow)
- [PyTorch](/wiki/pytorch)
- [Automatic differentiation](/wiki/automatic_differentiation)
- [Google DeepMind](/wiki/google_deepmind)
- [Google Brain](/wiki/google_brain)
- [PaLM](/wiki/palm)
- [Gemini](/wiki/gemini)
- [Gemma](/wiki/gemma)
- [Anthropic](/wiki/anthropic)
- [xAI](/wiki/xai)
- [Distributed training](/wiki/distributed_training)
- [NumPy](/wiki/numpy)
- [Triton](/wiki/triton)
- [ONNX](/wiki/onnx)

## References

[^1]: Bradbury, J., Frostig, R., Hawkins, P., Johnson, M.J., Leary, C., Maclaurin, D., Necula, G., Paszke, A., VanderPlas, J., Wanderman-Milne, S., and Zhang, Q., "JAX: composable transformations of Python+NumPy programs", GitHub, 2018. https://github.com/jax-ml/jax. Accessed 2026-05-24.

[^2]: jax-ml, "JAX repository README", GitHub, 2026-05. https://github.com/jax-ml/jax. Accessed 2026-05-24.

[^3]: Chowdhery, A. et al., "PaLM: Scaling Language Modeling with Pathways", arXiv:2204.02311, 2022-04-04. https://arxiv.org/abs/2204.02311. Accessed 2026-05-24.

[^4]: Google DeepMind, "Using JAX to accelerate our research", 2020-12-04. https://deepmind.google/blog/using-jax-to-accelerate-our-research/. Accessed 2026-05-24.

[^5]: IntoAI, "Google Trained Gemini 3 Entirely Using JAX on Its TPUs: Here Is Why It Matters", 2025. https://www.intoai.pub/p/google-jax-ai-stack. Accessed 2026-05-24.

[^6]: Maclaurin, D., Duvenaud, D., and Adams, R.P., "Autograd: Effortless gradients in numpy", ICML 2015 AutoML Workshop, 2015. https://indico.ijclab.in2p3.fr/event/2914/contributions/6483/subcontributions/180/attachments/6060/7185/automl-short.pdf. Accessed 2026-05-24.

[^7]: HIPS, "Autograd: Efficiently computes derivatives of NumPy code", GitHub, 2015. https://github.com/HIPS/autograd. Accessed 2026-05-24.

[^8]: OpenXLA Project, "XLA:GPU Architecture Overview", 2025. https://openxla.org/xla/gpu_architecture. Accessed 2026-05-24.

[^9]: Frostig, R., Johnson, M.J., and Leary, C., "Compiling machine learning programs via high-level tracing", SysML 2018, 2018-02. https://cs.stanford.edu/~rfrostig/pubs/jax-mlsys2018.pdf. Accessed 2026-05-24.

[^10]: jax-ml, "moving our GitHub repository to `jax-ml`", GitHub discussion #23319, 2024-09. https://github.com/jax-ml/jax/discussions/23319. Accessed 2026-05-24.

[^11]: PyPI, "jax package", 2026-05. https://pypi.org/project/jax/. Accessed 2026-05-24.

[^12]: Google, "Flax: a neural network library for JAX", GitHub, 2026. https://github.com/google/flax. Accessed 2026-05-24.

[^13]: Google DeepMind, "Penzai: a JAX research toolkit for building, editing, and visualizing neural networks", GitHub, 2024-04. https://github.com/google-deepmind/penzai. Accessed 2026-05-24.

[^14]: jax-ml, "JAX Change log", JAX documentation, 2026-05. https://docs.jax.dev/en/latest/changelog.html. Accessed 2026-05-24.

[^15]: Google Cloud, "Anthropic to Expand Use of Google Cloud TPUs and Services", press release, 2025-10-23. https://www.googlecloudpresscorner.com/2025-10-23-Anthropic-to-Expand-Use-of-Google-Cloud-TPUs-and-Services. Accessed 2026-05-24.

[^16]: jax-ml, "Automatic differentiation", JAX documentation, 2026. https://docs.jax.dev/en/latest/automatic-differentiation.html. Accessed 2026-05-24.

[^17]: jax-ml, "JAX: The Sharp Bits", JAX documentation, 2026. https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html. Accessed 2026-05-24.

[^18]: Lin, M., "Automatic Functional Differentiation in JAX", arXiv:2311.18727, 2023-11. https://arxiv.org/abs/2311.18727. Accessed 2026-05-24.

[^19]: OpenXLA Project, "StableHLO Specification", 2025. https://openxla.org/stablehlo/spec. Accessed 2026-05-24.

[^20]: jax-ml, "Key concepts", JAX documentation, 2026. https://docs.jax.dev/en/latest/key-concepts.html. Accessed 2026-05-24.

[^21]: jax-ml, "shmap (shard_map) for simple per-device code", JAX Enhancement Proposal 14273, 2026. https://docs.jax.dev/en/latest/jep/14273-shard-map.html. Accessed 2026-05-24.

[^22]: jax-ml, "Distributed arrays and automatic parallelization", JAX documentation, 2026. https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html. Accessed 2026-05-24.

[^23]: jax-ml, "JAX is changing the default jax.pmap implementation", GitHub discussion #32412, 2025. https://github.com/jax-ml/jax/discussions/32412. Accessed 2026-05-24.

[^24]: ApxML, "Combining JAX pmap with jit, grad, vmap", 2025. https://apxml.com/courses/getting-started-with-jax/chapter-5-parallelization-across-devices-pmap/combining-pmap-transformations. Accessed 2026-05-24.

[^25]: jax-ml, "Glossary of terms (pure function)", JAX documentation, 2026. https://docs.jax.dev/en/latest/glossary.html. Accessed 2026-05-24.

[^26]: jax-ml, "Pytrees", JAX documentation, 2026. https://docs.jax.dev/en/latest/pytrees.html. Accessed 2026-05-24.

[^27]: McKinney, A., "On Learning JAX: A Framework for High Performance Machine Learning", 2023-05-22. https://afmck.in/posts/2023-05-22-jax-post/. Accessed 2026-05-24.

[^28]: OpenXLA Project, "StableHLO portability and compatibility", 2025. https://openxla.org/stablehlo. Accessed 2026-05-24.

[^29]: jax-ml, "Exporting and serializing staged-out computations", JAX documentation, 2026. https://docs.jax.dev/en/latest/export/export.html. Accessed 2026-05-24.

[^30]: jax-ml, "jax.sharding module", JAX documentation, 2026. https://docs.jax.dev/en/latest/jax.sharding.html. Accessed 2026-05-24.

[^31]: jax-ml, "Pallas: a JAX kernel language", JAX documentation, 2026. https://docs.jax.dev/en/latest/pallas/index.html. Accessed 2026-05-24.

[^32]: jax-ml, "Pallas Design", JAX documentation, 2026. https://docs.jax.dev/en/latest/pallas/design/design.html. Accessed 2026-05-24.

[^33]: jax-ml, "Writing TPU kernels with Pallas", JAX documentation, 2026. https://docs.jax.dev/en/latest/pallas/tpu/details.html. Accessed 2026-05-24.

[^34]: Google DeepMind, "Optax: a gradient processing and optimization library for JAX", GitHub, 2026. https://github.com/google-deepmind/optax. Accessed 2026-05-24.

[^35]: Google Cloud, "Building production AI on Cloud TPUs with JAX (JAX AI Stack)", 2025. https://docs.cloud.google.com/tpu/docs/jax-ai-stack. Accessed 2026-05-24.

[^36]: Kidger, P. and Garcia, C., "Equinox: neural networks in JAX via callable PyTrees and filtered transformations", arXiv:2111.00254 (Differentiable Programming workshop, NeurIPS 2021), 2021-10. https://arxiv.org/abs/2111.00254. Accessed 2026-05-24.

[^37]: Rader, J., Lyons, T., and Kidger, P., "Optimistix: modular optimisation in JAX and Equinox", arXiv:2402.09983, 2024. https://arxiv.org/pdf/2402.09983. Accessed 2026-05-24.

[^38]: Google DeepMind, "Treescope: an interactive HTML pretty-printer for machine learning research in IPython notebooks", GitHub, 2025. https://github.com/google-deepmind/treescope. Accessed 2026-05-24.

[^39]: Synced, "DeepMind Augments, Salutes the JAX Library Ecosystem", 2020-12-07. https://syncedreview.com/2020/12/07/deepmind-augments-salutes-the-jax-library-ecosystem/. Accessed 2026-05-24.

[^40]: AI-Hypercomputer, "MaxText: A simple, performant and scalable JAX LLM", GitHub, 2026. https://github.com/AI-Hypercomputer/maxtext. Accessed 2026-05-24.

[^41]: Pope, R. et al., "Scalable Training of Language Models using JAX pjit and TPUv4", arXiv:2204.06514, 2022. https://arxiv.org/pdf/2204.06514. Accessed 2026-05-24.

[^42]: Google DeepMind, "Simply: minimal and scalable research codebase in JAX", GitHub, 2025-11. https://github.com/google-deepmind/simply. Accessed 2026-05-24.

[^43]: Schoenholz, S.S. and Cubuk, E.D., "JAX, M.D.: A Framework for Differentiable Physics", NeurIPS 2020. https://papers.nips.cc/paper/2020/file/83d3d4b6c9579515e1679aca8cbc8033-Paper.pdf. Accessed 2026-05-24.

[^44]: Google, "Brax: massively parallel rigidbody physics simulation on accelerator hardware", GitHub, 2026. https://github.com/google/brax. Accessed 2026-05-24.

[^45]: DeepMind, "MuJoCo XLA (MJX)", MuJoCo documentation, 2025. https://mujoco.readthedocs.io/en/stable/mjx.html. Accessed 2026-05-24.

[^46]: McKenna, R. et al., "JAX-Privacy: A library for differentially private machine learning", arXiv:2602.17861, 2026-02-19. https://arxiv.org/abs/2602.17861. Accessed 2026-05-24.

[^47]: Wolfe, C.R., "Google Gemini: Fact or Fiction?", Deep (Learning) Focus, 2023. https://cameronrwolfe.substack.com/p/google-gemini-fact-or-fiction. Accessed 2026-05-24.

[^48]: Maginative, "Anthropic Secures 1M Google TPUs While Keeping Amazon as Primary Training Partner", 2025-10. https://www.maginative.com/article/anthropic-secures-1m-google-tpus-while-keeping-amazon-as-primary-training-partner/. Accessed 2026-05-24.

[^49]: Pandit, R., "Dissecting the xAI Training Stack: Why Grok Chose JAX + Rust", 2025. https://rajatpandit.com/ai-infrastructure/dissecting-xai-training-stack/. Accessed 2026-05-24.

[^50]: Apple Developer, "Accelerated JAX on Mac with Metal", 2024. https://developer.apple.com/metal/jax/. Accessed 2026-05-24.

[^51]: HeyTensor, "ML Framework Comparison 2026: PyTorch vs TensorFlow vs JAX", 2026. https://heytensor.com/research/ml-framework-comparison-2026.html. Accessed 2026-05-24.

[^52]: BlackthornVision, "Choosing Your AI Stack: PyTorch, TensorFlow, or JAX?", 2025. https://blackthorn-vision.com/blog/pytorch-vs-tensorflow/. Accessed 2026-05-24.

[^53]: Lechner, M., "Why We Started with JAX but Moved to PyTorch", 2025. https://mlechner.substack.com/p/why-we-started-with-jax-but-moved. Accessed 2026-05-24.

[^54]: jax-ml, "Introduction to debugging", JAX documentation, 2026. https://docs.jax.dev/en/latest/debugging.html. Accessed 2026-05-24.

[^55]: Google Cloud, "TPU7x (Ironwood)", Cloud TPU documentation, 2025. https://docs.cloud.google.com/tpu/docs/tpu7x. Accessed 2026-05-24.

[^56]: Google I/O, "JAX in action", Google I/O 2025 technical session, 2025-05. https://io.google/2025/explore/technical-session-1/. Accessed 2026-05-24.

[^57]: Dean, J. (@JeffDean), "Training our most capable Gemini models relies heavily on our JAX software stack + Google's TPU hardware platforms", X (formerly Twitter), 2025-02-04. https://x.com/JeffDean/status/1886852442815652188. Accessed 2026-06-20.

