XLA (Accelerated Linear Algebra)
Last reviewed
Sources
36 citations
Review status
Source-backed
Revision
v4 · 5,744 words
XLA (Accelerated Linear Algebra) is Google's open-source machine learning compiler. It takes computational graphs from frameworks such as TensorFlow, JAX, and PyTorch and transforms them into optimized machine code for TPUs, GPUs, and CPUs. XLA first shipped inside TensorFlow in March 2017, when Google reported up to 50 percent speedups on internal NVIDIA GPU benchmarks. It is now the centerpiece of the OpenXLA Project, a multi-vendor consortium that took over stewardship of the compiler from TensorFlow on March 8, 2023.[1][7]
XLA's central premise is that the high-level operations expressed by ML frameworks (matrix multiplications, convolutions, element-wise functions, reductions, scans, gathers, scatters) can be represented in a single, hardware-independent intermediate representation called HLO (High-Level Optimizer IR), aggressively optimized through fusion, layout assignment, and scheduling, and finally lowered to vendor-specific instructions through pluggable backends. This compile-once-target-many design has made XLA the engine behind some of the largest training runs in history, including DeepMind's AlphaFold, Google's Gemini family, and most foundation-model training on TPU pods. It also powers JAX's signature functional transformations (jit, vmap, pmap, grad), provides PyTorch with a TPU backend through PyTorch/XLA, and underpins the new generation of compiler-based ML systems built on MLIR and StableHLO.
Before XLA, ML frameworks executed graphs by dispatching one operation at a time to hand-tuned kernels in libraries such as cuDNN, cuBLAS, MKL-DNN, and Eigen. This op-at-a-time model is simple but leaves substantial performance on the table. Each kernel reads its inputs from high-bandwidth memory, performs a small amount of arithmetic, and writes its outputs back, only for the next kernel to read those same outputs again. On modern accelerators, where memory bandwidth and not arithmetic is usually the bottleneck, this round trip through global memory is the single largest source of inefficiency. Compilers can fuse adjacent operations into a single kernel that keeps intermediate values in registers or shared memory, eliminating the bandwidth tax. They can also choose memory layouts, vectorize loops, and schedule communication and computation in ways no hand-written kernel can anticipate.
Google's TPU project added a second motivation. Custom accelerators evolve quickly and are too specialized to support with hand-written kernels for every operation in TensorFlow's growing op zoo. A compiler that lowers a small, well-specified IR to a TPU instruction set can absorb new ops automatically and stay in lockstep with hardware revisions. The same logic applies to GPUs and CPUs: a compiler can target many backends with one set of optimization passes and a thin per-backend code generator.
XLA was first announced as an experimental TensorFlow component on March 6, 2017, when Google described it as the "secret compiler sauce that helps TensorFlow optimize compositions of primitive ops automatically."[1] The original announcement reported up to 50 percent speedups on internal NVIDIA GPU benchmarks and previewed tfcompile, an ahead-of-time compilation tool that emitted compact binaries for mobile and embedded deployment.[1] XLA stayed inside the TensorFlow source tree for six years before being extracted into the OpenXLA Project in 2023.[7]
The heart of XLA is its high-level intermediate representation, HLO, sometimes spelled out as High-Level Optimizer IR or, in the newer terminology, High-Level Operations. HLO is a static-single-assignment (SSA) graph language whose nodes are pure functional operations on multidimensional arrays. Each HLO instruction has a well-defined shape (rank, dimensions, and element type) and a precise mathematical semantics specified in the OpenXLA documentation.[5] HLO is intentionally small: a few dozen primitive operations cover the entire surface area of the language, and complex framework-level ops are decomposed into compositions of those primitives.
The primitive HLO operation set spans several broad categories. Element-wise unary and binary operations such as add, multiply, exp, log, sin, and the comparison operators apply a scalar function pointwise across an array. Reductions collapse one or more dimensions of an array under an associative binary operator (sum, max, product, logical-or). Broadcasts duplicate values across new dimensions to align operand shapes. Slice, dynamic-slice, and concatenate manipulate ranges of indices. Reshape, transpose, and reverse rearrange data without changing values. Convolution, dot, and gather perform structured contractions or table lookups. Scatter writes values into computed positions. While, conditional, and call provide structured control flow. A handful of communication primitives (all-reduce, all-gather, reduce-scatter, collective-permute) support distributed execution across multiple devices.[5]
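Because every HLO instruction carries a static shape, result shapes can be derived before anything runs. The following is a minimal Python sketch of that idea for three operation kinds. The function names (`infer_elementwise`, `infer_dot`, `infer_reduce`) are hypothetical and do not mirror XLA's actual shape-inference code, and element types are left out for brevity.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]

def infer_elementwise(lhs: Shape, rhs: Shape) -> Shape:
    """In this sketch, element-wise binary ops require identical operand
    shapes; any broadcast is assumed to be a separate, explicit operation."""
    if lhs != rhs:
        raise ValueError(f"shape mismatch: {lhs} vs {rhs}")
    return lhs

def infer_dot(lhs: Shape, rhs: Shape) -> Shape:
    """Plain 2-D matrix product: [m, k] x [k, n] -> [m, n]."""
    if len(lhs) != 2 or len(rhs) != 2 or lhs[1] != rhs[0]:
        raise ValueError(f"cannot contract {lhs} with {rhs}")
    return (lhs[0], rhs[1])

def infer_reduce(operand: Shape, dims: Sequence[int]) -> Shape:
    """A reduction removes the listed dimensions from the operand shape."""
    removed = set(dims)
    return tuple(d for i, d in enumerate(operand) if i not in removed)

# Shapes of a small program, derived without touching any data.
hidden = infer_dot((32, 512), (512, 1024))    # (32, 1024)
biased = infer_elementwise(hidden, (32, 1024))
row_sums = infer_reduce(biased, [1])          # (32,)
print(hidden, biased, row_sums)
```

Since all three results are known ahead of execution, a compiler in this style can size buffers and reject ill-formed programs at compile time.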
Because HLO is hardware-independent, framework code can be lowered to HLO once and targeted at any backend. Because HLO is functional and SSA, dataflow analysis, fusion, and rewriting passes are straightforward to express. Because HLO carries shape information statically, the compiler can choose layouts, allocate buffers, and schedule kernels without any runtime introspection. The price of this static design is rigidity, discussed in the limitations section below.
XLA's compilation pipeline takes an HLO module, applies a long sequence of optimization passes, and emits machine code for the target device. The pipeline is divided into hardware-independent and hardware-specific phases.[3]
The hardware-independent front of the pipeline performs traditional compiler optimizations adapted to ML workloads. Common subexpression elimination removes duplicated computations that frameworks often produce when they emit gradients or repeated indexing expressions. Algebraic simplification rewrites expressions according to mathematical identities (collapsing transpose-of-transpose, simplifying x*1 and x+0, fusing chains of reshapes). Constant folding evaluates expressions whose inputs are known statically. Dead code elimination removes instructions whose results are not used. Buffer analysis determines which intermediate tensors can share storage and computes the peak working-set size of the program.
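The rewrites described above can be illustrated with a toy expression simplifier. This is a minimal sketch, assuming expressions are nested Python tuples such as `("add", a, b)` and that `transpose` means the default 2-D transpose; the representation and the `simplify` function are hypothetical and are not XLA's pass infrastructure.

```python
from typing import Union

Expr = Union[str, int, float, tuple]

def simplify(e: Expr) -> Expr:
    """Bottom-up rewrite: a few algebraic identities plus constant folding."""
    if not isinstance(e, tuple):
        return e  # a variable name or a numeric constant
    op, *args = e
    args = [simplify(a) for a in args]
    # transpose(transpose(x)) -> x
    if op == "transpose" and isinstance(args[0], tuple) and args[0][0] == "transpose":
        return args[0][1]
    # Constant folding: both operands are known statically.
    if op in ("add", "mul") and all(isinstance(a, (int, float)) for a in args):
        return args[0] + args[1] if op == "add" else args[0] * args[1]
    # x * 1 -> x and 1 * x -> x
    if op == "mul":
        if args[0] == 1:
            return args[1]
        if args[1] == 1:
            return args[0]
    # x + 0 -> x and 0 + x -> x
    if op == "add":
        if args[0] == 0:
            return args[1]
        if args[1] == 0:
            return args[0]
    return (op, *args)

print(simplify(("add", ("mul", "x", 1), 0)))        # x
print(simplify(("transpose", ("transpose", "w"))))  # w
print(simplify(("mul", ("add", 2, 3), "x")))        # ('mul', 5, 'x')
```

Running the rewrites bottom-up means an inner simplification can expose an outer one, which is why `x*1 + 0` collapses all the way to `x` in a single traversal.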
Fusion is the single most consequential XLA optimization. The fusion pass clusters adjacent HLO operations into compound fusion instructions that the backend will lower to a single kernel. Element-wise fusion is the easiest case: a pointwise chain such as relu(bias + matmul(x, w)) can be turned into a single kernel that streams the matmul result through bias addition and ReLU without ever materializing the intermediate. Reduction fusion folds reductions and their producers into a single pass over the input. Loop fusion (sometimes called horizontal fusion) merges independent loops that iterate over the same dimensions. Output fusion absorbs producers of intermediates whose only consumer is a fusion. Because GPU and TPU workloads are usually memory-bound, fusion typically delivers the bulk of XLA's speedup over op-at-a-time execution.[19]
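To make the bandwidth argument concrete, here is a minimal Python sketch (not XLA code) that counts element reads and writes for an element-wise chain executed as three separate kernels versus one fused kernel. The function names and the simple traffic model, one unit per element read or written to main memory, are assumptions for illustration.

```python
def unfused(x, bias):
    """Three separate 'kernels'; each materializes a full intermediate list."""
    n = len(x)
    traffic = 0
    t1 = [xi + bi for xi, bi in zip(x, bias)]  # reads x and bias, writes t1
    traffic += 3 * n
    t2 = [max(v, 0.0) for v in t1]             # reads t1, writes t2
    traffic += 2 * n
    out = [2.0 * v for v in t2]                # reads t2, writes out
    traffic += 2 * n
    return out, traffic

def fused(x, bias):
    """One 'kernel'; intermediates live only in local temporaries."""
    out = [2.0 * max(xi + bi, 0.0) for xi, bi in zip(x, bias)]
    traffic = 3 * len(x)                       # reads x and bias, writes out
    return out, traffic

x = [1.0, -2.0, 3.0, -4.0]
bias = [0.5, 0.5, 0.5, 0.5]
out_a, traffic_a = unfused(x, bias)
out_b, traffic_b = fused(x, bias)
print(out_a == out_b, traffic_a, traffic_b)    # True 28 12
```

Under this model the fused version moves 3n elements instead of 7n for the same result, which is the kind of saving that matters when a workload is memory-bound.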
Layout assignment chooses a physical memory layout for every operand and result in the program. On TPUs, where matrix multiplications run on systolic arrays with a fixed input shape, the choice of layout (which logical dimension maps to which physical axis, how data is tiled and padded) determines whether a matmul runs at peak throughput or at a fraction of it. On GPUs, layout choices govern coalesced memory access and the use of tensor-core instructions. The layout assignment pass propagates layouts forward and backward through the program, inserts transpose operations only where necessary, and resolves conflicts using a cost model.
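A layout can be thought of as an ordering of logical dimensions from fastest-varying to slowest-varying in memory, which XLA expresses as a minor-to-major list. The sketch below computes element strides from such a list for a dense array. It deliberately ignores tiling and padding, and the helper names are hypothetical.

```python
def strides_for_layout(dims, minor_to_major):
    """Element strides per logical dimension for a dense, untiled layout.

    minor_to_major lists dimension numbers from fastest- to slowest-varying.
    """
    strides = [0] * len(dims)
    step = 1
    for d in minor_to_major:
        strides[d] = step
        step *= dims[d]
    return strides

def linear_index(index, strides):
    """Offset (in elements) of a logical index under the given strides."""
    return sum(i * s for i, s in zip(index, strides))

dims = (2, 3)
row_major = strides_for_layout(dims, [1, 0])   # [3, 1]
col_major = strides_for_layout(dims, [0, 1])   # [1, 2]
print(row_major, col_major)
print(linear_index((1, 0), row_major), linear_index((1, 0), col_major))  # 3 1
```

The same logical element lands at a different offset under each layout, so a kernel that walks memory contiguously along one dimension is only efficient if the layout makes that dimension the minor one.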
Scheduling orders the execution of operations within a device and across devices. The schedule must respect data dependencies, fit within memory budgets, overlap communication with computation, and (on multi-stream backends) keep multiple execution units busy. Scheduling choices interact tightly with buffer allocation: an aggressive schedule may need extra memory to hold intermediates simultaneously, while a memory-conscious schedule may serialize computations that could otherwise overlap.
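The interaction between schedule and memory can be sketched with a toy liveness model. In the example below, which is an illustration rather than XLA's scheduler, each operation produces one buffer that is freed after its last consumer runs; the op names and sizes are made up.

```python
def peak_memory(order, sizes, inputs):
    """Peak total size of live buffers for a given execution order.

    sizes[op] is the size of op's output; inputs[op] lists the ops it reads.
    A buffer is freed as soon as its last consumer has executed.
    """
    last_use = {}
    for step, op in enumerate(order):
        for src in inputs[op]:
            last_use[src] = step
    live = peak = 0
    for step, op in enumerate(order):
        live += sizes[op]              # allocate the output
        peak = max(peak, live)
        for src in set(inputs[op]):    # free inputs that are now dead
            if last_use[src] == step:
                live -= sizes[src]
    return peak

# Two independent branches (a1 -> a2, b1 -> b2) joined by c.
sizes = {"a1": 100, "a2": 10, "b1": 100, "b2": 10, "c": 10}
inputs = {"a1": [], "a2": ["a1"], "b1": [], "b2": ["b1"], "c": ["a2", "b2"]}

interleaved = ["a1", "b1", "a2", "b2", "c"]  # branches could overlap
serialized = ["a1", "a2", "b1", "b2", "c"]   # one branch at a time
print(peak_memory(interleaved, sizes, inputs))  # 210
print(peak_memory(serialized, sizes, inputs))   # 120
```

The interleaved order exposes more parallelism but holds both large intermediates at once, while the serialized order trades that overlap for a lower peak.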
The autotuning pass explores a search space of low-level parameters (tile sizes for matmul, block sizes for reductions, vector widths) and benchmarks alternatives on the target hardware to pick the fastest configuration. Because autotuning is expensive, XLA supports persisted autotuning: the results of a tuning run are written to a cache that can be reused on subsequent compilations of the same fusion.
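The search-then-cache pattern can be sketched in a few lines of Python. This is an illustration under stated assumptions: `autotune`, the cache key, and the `measure` callback are hypothetical, and a deterministic stand-in cost function replaces real on-device benchmarking.

```python
import json

def autotune(fusion_key, candidates, measure, cache):
    """Return the fastest candidate, consulting and updating `cache`."""
    if fusion_key in cache:
        return cache[fusion_key]        # reuse a persisted result
    best = min(candidates, key=measure) # "benchmark" every candidate once
    cache[fusion_key] = best
    return best

measured = []

def fake_measure(tile_size):
    """Stand-in for timing a kernel; pretends 64 is the ideal tile size."""
    measured.append(tile_size)
    return abs(tile_size - 64)

cache = {}
first = autotune("fusion.1", [16, 32, 64, 128], fake_measure, cache)

# Persist the cache as JSON and reload it, as a later compilation might.
restored = json.loads(json.dumps(cache))
second = autotune("fusion.1", [16, 32, 64, 128], fake_measure, restored)

print(first, second, len(measured))     # 64 64 4
```

The second call performs no measurements at all, which is the point of persisting results: the expensive search is paid once per distinct fusion.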
Finally, the backend lowers the optimized HLO to executable machine code. On CPUs and GPUs, XLA emits LLVM IR and hands it off to the LLVM compiler infrastructure for register allocation, instruction selection, and final assembly.[4] On NVIDIA GPUs, the backend can also emit Triton IR for selected fusions, then invoke OpenAI's Triton compiler to produce PTX for high-performance attention and matmul kernels.[4] On TPUs, XLA uses Google's proprietary TPU compiler. On AWS Trainium and Inferentia, the AWS Neuron compiler consumes HLO via PJRT.[26] The output of compilation is an XLA Executable: a self-contained, device-specific binary plus a small runtime that owns the input and output buffers.
XLA supports two compilation modes, JIT and AOT, that trade off startup cost for runtime flexibility.
In just-in-time (JIT) mode, the framework hands HLO to XLA at runtime, the compiler optimizes and emits machine code on demand, and the resulting executable is cached for subsequent calls with the same input shapes and types. JIT is the mode used by both TensorFlow's @tf.function(jit_compile=True) and JAX's jax.jit. The first call pays the full compilation cost, often hundreds of milliseconds to tens of seconds, while subsequent calls run the cached executable directly. JIT compiles for the actual hardware on which the program runs, taking advantage of CPU instruction set extensions, GPU compute capabilities, or TPU generations without developer intervention.
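The caching behavior of JIT mode can be mimicked with a small wrapper that keys a cache on the element type and length of each argument. This is a toy model, assuming non-empty Python lists stand in for arrays; `ToyJit` is a hypothetical name, and no real compilation happens.

```python
class ToyJit:
    """Caches one 'executable' per distinct argument signature."""

    def __init__(self, fn):
        self.fn = fn
        self.cache = {}
        self.compilations = 0

    def __call__(self, *args):
        # Signature: (element type, length) of each non-empty list argument.
        key = tuple((type(a[0]).__name__, len(a)) for a in args)
        if key not in self.cache:
            self.compilations += 1      # stand-in for an expensive compile
            self.cache[key] = self.fn   # a real compiler would store machine code
        return self.cache[key](*args)

add = ToyJit(lambda x, y: [a + b for a, b in zip(x, y)])

add([1.0, 2.0], [3.0, 4.0])             # new signature: compiles
add([5.0, 6.0], [7.0, 8.0])             # same shapes and types: cache hit
add([1.0, 2.0, 3.0], [1.0, 1.0, 1.0])   # new length: compiles again
add([1, 2], [3, 4])                     # new element type: compiles again
print(add.compilations)                 # 3
```

Only the values changed between the first two calls, so the cached entry was reused; any change in shape or element type produced a new entry.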
In ahead-of-time (AOT) mode, the developer compiles a fully specified HLO program to a standalone binary at build time. TensorFlow's tfcompile tool was the first AOT front end and remains the standard way to deploy XLA-compiled inference graphs on mobile, embedded, or server platforms where the JIT runtime would be too heavy.[1] AOT eliminates startup latency, removes the XLA compiler from the production binary, and allows the resulting code to be inspected and signed like any other binary artifact.
Most modern XLA workloads use JIT, because input shapes usually stabilize after warmup and JIT compilation is amortized across millions of training steps or inference requests.
XLA is the runtime engine for JAX and an important optional backend for TensorFlow and PyTorch. Each framework integrates XLA differently.
JAX was designed from the ground up around XLA. The official documentation describes JAX as bringing together "a modified version of autograd and OpenXLA's XLA," using XLA "to compile and run your NumPy code on accelerators, like GPUs and TPUs."[13] Every JAX operation is internally implemented as a thin wrapper around an HLO primitive, and every JAX function is amenable to tracing: when a function is invoked with JAX arrays, the framework records the sequence of operations as it runs and assembles them into an HLO module.[13] The signature JAX transformations, jit, vmap, pmap, and grad, all operate at the level of these traces. jax.jit compiles a traced function to a single XLA executable. jax.vmap automatically vectorizes a function across a new batch dimension by rewriting the trace, eliminating the need for manual batching code. jax.pmap parallelizes a function across multiple devices using single-program-multiple-data (SPMD) semantics, with collective operations (psum, pmean, all-gather) compiled into XLA's communication primitives.[14] jax.grad uses reverse-mode automatic differentiation, built from vector-Jacobian products, to produce a new function that computes the gradient and is itself JIT-compilable. Because all four transformations compose freely, JAX users routinely write code that is grad-of-jit-of-vmap-of-pmap-of-something-else, and XLA absorbs the full composition into one optimized executable.
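The tracing idea can be shown with a toy recorder written in plain Python. This is a sketch of the general technique only: `Tracer` and `trace` here are hypothetical stand-ins, far simpler than JAX's real tracers, and the recorded tuples are not HLO.

```python
class Tracer:
    """Stands in for an array; records operations instead of computing them."""

    def __init__(self, name, tape):
        self.name = name
        self.tape = tape

    def _emit(self, op, other):
        rhs = other.name if isinstance(other, Tracer) else repr(other)
        out = Tracer(f"v{len(self.tape)}", self.tape)
        self.tape.append((out.name, op, self.name, rhs))
        return out

    def __add__(self, other):
        return self._emit("add", other)

    def __mul__(self, other):
        return self._emit("multiply", other)

def trace(fn, *arg_names):
    """Run fn on tracers and return (recorded ops, name of the result)."""
    tape = []
    result = fn(*(Tracer(n, tape) for n in arg_names))
    return tape, result.name

def f(x, w):
    return x * w + x

ops, out = trace(f, "x", "w")
for line in ops:
    print(line)
# ('v0', 'multiply', 'x', 'w')
# ('v1', 'add', 'v0', 'x')
print(out)  # v1
```

The ordinary Python function `f` never sees real numbers; running it once on tracers yields a straight-line list of operations that a compiler could then optimize as a whole.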
TensorFlow historically had a more cautious relationship with XLA. Standard TensorFlow programs run in eager mode (since TensorFlow 2.0) or as op-by-op graphs and dispatch to ordinary CUDA or oneDNN kernels. XLA enters only when a function is wrapped with @tf.function(jit_compile=True) or annotated with tf.xla.experimental.compile. The TensorFlow-to-XLA bridge converts the relevant subgraph to HLO and hands it to XLA, reflecting TensorFlow's heterogeneous user base where some users need kernel-level control while others want whole-program optimization. The payoff can be dramatic: Google reported that applying XLA to TensorFlow text-generation models at Hugging Face delivered speedups of roughly 100x.[7]
PyTorch integrates XLA through the PyTorch/XLA package, which adds an XLA device type alongside CUDA and CPU.[16] Tensors placed on the XLA device record their operations into a lazy graph that is materialized and compiled when the user calls a synchronization barrier such as xm.mark_step(). PyTorch/XLA originated to give PyTorch users access to TPUs (Hugging Face's TPU training pipelines, Cloud TPU VMs, and Colab TPU runtimes all rely on it) and now also supports GPUs via PJRT.[17] Through its FSDP (Fully Sharded Data Parallel) integration, GPT-2-class transformers as large as 128 billion parameters have been trained on TPUs without leaving the PyTorch programming model.[20] PyTorch's other compiler stack, torch.compile with TorchInductor, targets Triton and C++/OpenMP rather than XLA: torch.compile dominates eager-mode acceleration on NVIDIA hardware, while PyTorch/XLA dominates TPU-backed training.
XLA targets an unusually broad spectrum of hardware. The following table summarizes the major backends and how they integrate.
| Backend | Vendor | Runtime interface | Code generator | Notes |
|---|---|---|---|---|
| TPU | Google | Native PJRT plugin | Internal TPU compiler | First-class target since 2017; powers Gemini, Gemma, AlphaFold, PaLM. |
| NVIDIA GPU (CUDA) | NVIDIA | PJRT plugin | LLVM NVPTX, optionally Triton | Most widely used GPU backend; supports Ampere, Hopper, Blackwell tensor cores. |
| AMD GPU (ROCm) | AMD | PJRT plugin | LLVM AMDGPU, Triton GEMM rewriter | Active development as of 2024 with Triton-based GEMM autotuner work in progress. |
| Intel GPU | Intel | OpenXLA PJRT plugin via Intel Extension for TensorFlow | LLVM SPIR-V | Used for oneAPI-based execution on Intel Max GPUs and integrated GPUs. |
| Apple GPU | Apple | PJRT plugin (Apple silicon) | Metal | Enables JAX on Apple silicon; emerging support in 2024. |
| AWS Trainium and Inferentia | Amazon Web Services | JAX Neuron plugin | AWS Neuron compiler | Lets Trainium and Inferentia be used as native JAX devices via PJRT. |
| CPU (x86-64, ARM64) | Various | Built-in | LLVM | Used for development, fallback, and lightweight inference. |
| Cerebras WSE | Cerebras | OpenXLA partner | Cerebras compiler | Wafer-scale engine support announced as part of OpenXLA partnership. |
| Graphcore IPU | Graphcore | OpenXLA partner | Graphcore Poplar | Founding OpenXLA partner; IPU codegen path. |
The key innovation that enables this breadth is PJRT, the Pluggable Runtime interface, which standardizes how frameworks discover devices, allocate buffers, transfer tensors, compile programs, and launch executables.[8] PJRT exposes both a C and a C++ API, and a hardware vendor can ship a PJRT plugin (a dynamically loaded library) that the framework discovers at runtime without recompilation.[9] PJRT was open-sourced in 2023 alongside the OpenXLA announcement and has since been adopted by NVIDIA (CUDA), Apple (Metal), Intel (Max GPUs), and AWS (Trainium and Inferentia).[8] Through PJRT, JAX can execute on Apple silicon without a single change to JAX itself, the new device showing up as just another platform.
For most of XLA's life, the compiler's source tree lived inside the TensorFlow repository, and its development was driven by Google. By 2022 it had become clear that XLA's user base extended well beyond TensorFlow, that several hardware vendors wanted a stronger seat at the table, and that the compiler needed to be modularized so pieces could be reused independently. In October 2022, Google announced its intention to spin XLA out of TensorFlow and form the OpenXLA Project, a multi-vendor consortium with neutral governance and a shared roadmap.
OpenXLA officially launched on March 8, 2023, with founding partners Alibaba, Amazon Web Services, AMD, Anyscale, Apple, Arm, Cerebras, Google, Graphcore, Hugging Face, Intel, Meta, NVIDIA, and SiFive.[7] The launch post framed the goal plainly: OpenXLA lets developers "compile and optimize models from all leading ML frameworks for efficient training and serving on a wide variety of hardware," promising "significant improvements in training time, throughput, serving latency, and, ultimately, time-to-market and compute costs."[7] The first deliverable was the migration of XLA, StableHLO, and IREE into independent GitHub repositories under the openxla organization, with public design discussions and a documented contribution process.[6] The launch announcement reported real-world performance wins including a 72 percent speedup on GPT-2 and an 88 percent speedup on Swin Transformer on NVIDIA hardware.[7]
The consortium has since added new components. Shardy (announced in 2024) is an MLIR-based tensor partitioning system that incorporates the lessons of GSPMD and PartIR.[11] Tokamax is a library of high-performance custom GPU and TPU kernels built on JAX and Pallas.[33] XProf is a cross-framework profiling tool. Hardware companies (Apple, NVIDIA, AMD, Intel, AWS) maintain their own backends in collaboration with the rest of the community.
XLA is open source. Since the March 2023 OpenXLA launch, XLA's full source tree lives in the public openxla/xla GitHub repository under neutral, multi-vendor governance, separate from the TensorFlow codebase it originated in.[6][7] The project is licensed as open source and developed in the open, with public design discussions, a documented contribution process, and backends maintained collaboratively by the hardware vendors that join the consortium.[7] Sister projects in the same organization, including StableHLO (the portable operator set) and IREE (the edge and mobile runtime), are similarly open source.
A critical piece of the OpenXLA puzzle is StableHLO, the portable, versioned operator set that sits between ML frameworks and the XLA compiler. StableHLO was extracted from XLA's internal MHLO dialect and packaged as a standalone, backward-compatible operator set with a written specification, semantic versioning, and a long-term compatibility policy.[10] Frameworks that emit StableHLO can be confident that their programs will continue to compile correctly even as the XLA compiler evolves its internal representations. Hardware backends that consume StableHLO can be confident that the contract with frameworks will not silently break.
StableHLO is implemented as an MLIR dialect. MLIR (Multi-Level Intermediate Representation) is a compiler infrastructure originally developed at Google and now part of the LLVM project, designed for representing and transforming code across many levels of abstraction.[23] MLIR provides reusable infrastructure for IR construction, pattern rewriting, dataflow analysis, and code generation, all parameterized by application-specific dialects such as StableHLO, the LLVM dialect, the GPU dialect, the Triton dialect, the Linalg dialect, and many others.[24] By building on MLIR, the OpenXLA stack composes naturally with the rest of the LLVM ecosystem and with sister projects such as IREE, Triton, and PyTorch's lowering pipeline.
In practice, a JAX program is traced and lowered to a StableHLO module, which MLIR-based tools can inspect for any cross-framework analysis, and that module is then handed to the compiler for the target backend. The same StableHLO can be consumed by the standard XLA compiler (which converts it to HLO internally), by IREE, or by a third-party compiler that conforms to the StableHLO contract. This has made StableHLO the de facto neutral interchange format for ML compilation.
StableHLO reached its v1.0 milestone in May 2024, marking the point at which the operator set was declared production ready.[27] At v1.0 the project had a written specification for all 98 of its operations, with verifiers and type inference, and full support for dynamic shapes and quantization (including hybrid dot_general quantization) in the reference interpreter.[27] The headline change at v1.0 was the extension of its compatibility promise to 5 years of backward compatibility and 2 years of forward compatibility: a portable artifact serialized by one build of libStablehlo retains its semantics when deserialized by another build, provided the two builds are drawn from openxla/stablehlo commits less than 5 years apart (backward) or less than 2 years apart (forward).[28] These guarantees are intended to support long-lived on-device and server deployments that must tolerate version skew across annual update cycles. A recent line of research has also begun using StableHLO as a hardware-independent input for cross-architecture performance modeling of distributed ML workloads, treating it as a common front end to XLA, IREE, and custom toolchains.[31]
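The two compatibility windows amount to a date comparison between the commit that produced an artifact and the commit that consumes it. The following sketch encodes the 5-year and 2-year figures quoted above using calendar years; the function name and the exact boundary handling are assumptions, and the authoritative rules are those in the StableHLO compatibility documentation.

```python
from datetime import date

BACKWARD_YEARS = 5  # a newer consumer reading an older artifact
FORWARD_YEARS = 2   # an older consumer reading a newer artifact

def _years_later(d: date, years: int) -> date:
    try:
        return d.replace(year=d.year + years)
    except ValueError:  # Feb 29 mapped into a non-leap year
        return d.replace(year=d.year + years, day=28)

def is_compatible(producer_commit: date, consumer_commit: date) -> bool:
    """True if the consumer falls inside the producer's compatibility window."""
    if consumer_commit >= producer_commit:
        # Backward compatibility: consumer is newer than the artifact.
        return consumer_commit < _years_later(producer_commit, BACKWARD_YEARS)
    # Forward compatibility: consumer is older than the artifact.
    return producer_commit < _years_later(consumer_commit, FORWARD_YEARS)

artifact = date(2024, 5, 1)
print(is_compatible(artifact, date(2028, 1, 1)))          # True
print(is_compatible(artifact, date(2030, 1, 1)))          # False
print(is_compatible(date(2025, 6, 1), date(2024, 5, 1)))  # True
print(is_compatible(date(2027, 1, 1), date(2024, 5, 1)))  # False
```

The asymmetry reflects the two deployment cases: an updated runtime must keep loading old artifacts for a long time, while an old runtime only needs to tolerate artifacts from slightly newer producers.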
Large model training, especially on TPU pods, requires partitioning weights, activations, and gradients across many devices. XLA addresses this through SPMD (Single Program Multiple Data) compilation: a single program is compiled once and executed in parallel on every device, with the compiler inserting collective communication operations (all-reduce, all-gather, reduce-scatter, collective-permute) automatically based on user-supplied sharding annotations.
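As an illustration of where a collective comes from, the sketch below simulates a matrix-vector product whose contraction dimension is split across devices. Each simulated device computes a partial result from its shard, and an all-reduce sums the partials so that every device ends up with the full answer. This is plain Python standing in for devices; the function names are hypothetical and nothing here is XLA's partitioner.

```python
def local_matvec(w, x):
    """Dense matrix-vector product on one device's shard."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def all_reduce_sum(per_device):
    """Every device receives the element-wise sum of all devices' values."""
    total = [sum(vals) for vals in zip(*per_device)]
    return [list(total) for _ in per_device]

def sharded_matvec(w, x, num_devices):
    """Shard the contraction dimension, compute partials, then all-reduce."""
    k = len(x)
    if k % num_devices != 0:
        raise ValueError("contraction dimension must divide evenly")
    step = k // num_devices
    partials = []
    for d in range(num_devices):
        lo, hi = d * step, (d + 1) * step
        w_shard = [row[lo:hi] for row in w]   # columns lo..hi of W
        partials.append(local_matvec(w_shard, x[lo:hi]))
    return all_reduce_sum(partials)

w = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
x = [1, 1, 2, 2]
print(local_matvec(w, x))        # [17, 41]
print(sharded_matvec(w, x, 2))   # [[17, 41], [17, 41]]
```

Sharding along the contraction dimension is the case that needs an all-reduce; sharding along a non-contracted dimension would instead leave each device with a distinct slice of the output and no summation step.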
Google's first generation of SPMD partitioning was GSPMD, described in a 2021 paper, GSPMD: General and Scalable Parallelization for ML Computation Graphs.[12] GSPMD took user annotations on a small number of tensors and propagated sharding decisions across the rest of the program automatically, inserting communication where shardings disagreed and choosing communication patterns to minimize cost.[12] GSPMD proved scalable to thousands of TPU chips and was used to train models such as PaLM, LaMDA, and many of Google's foundation models.[12] A complementary line of work from Google DeepMind, PartIR, took an incremental, schedule-driven approach in which the user composes sharding strategies that rewrite the program IR step by step, with a simulator to validate each strategy before code generation; PartIR was published at ASPLOS 2025.[32]
Shardy, announced in 2024 and developed jointly by the GSPMD and PartIR teams, is the next-generation MLIR-based partitioning system.[11] Shardy uses an axis-based sharding representation that is more expressive than GSPMD's, supports incremental partitioning (the user can specify exactly how the program should be sharded at any point, rather than relying entirely on propagation), and includes novel handling of reshape operations that historically generated extra communication.[11] Short-term, Shardy delegates the actual partitioning work to the existing GSPMD partitioner; long-term, the OpenXLA team plans a new MLIR-native SPMD partitioner. Shardy is integrated with both JAX (through jax.sharding and the Shardy backend flag) and PyTorch/XLA (through the SPMD API).[21] In JAX, Shardy moved from an opt-in flag to the default partitioner in release 0.7.0 (July 2025), and the JAX project has stated that GSPMD will be fully removed once the migration completes, scheduled for March 2026, after which Shardy will be the only partitioner; a six-month backward-compatibility window lets models exported under GSPMD continue to load during the transition.[29][30]
The combination of XLA's whole-program optimization, PJRT's hardware abstraction, StableHLO's portability layer, and Shardy's partitioning has made it possible to write a single JAX program and run it unchanged on a single GPU, on a multi-node GPU cluster, or on a TPU pod with thousands of chips, with the compiler making most of the parallelization decisions.
XLA is not the only ML compiler, and the design space contains several distinct approaches. The following table contrasts XLA with the most prominent alternatives.
| Compiler | Origin | Frontends | IR | Targeting | Design philosophy |
|---|---|---|---|---|---|
| XLA | Google (2017), now OpenXLA | TensorFlow, JAX, PyTorch (via PyTorch/XLA) | HLO and StableHLO (MLIR) | TPU, NVIDIA and AMD and Intel and Apple GPU, CPU, AWS Trainium and Inferentia | Whole-graph optimization, fusion, and SPMD partitioning. |
| TVM | Apache project (UW SAMPL, 2017) | TensorFlow, PyTorch, ONNX, others | Relay, then TIR | NVIDIA and AMD and Intel GPU, CPU, ARM, microcontrollers, FPGA | Schedule-language approach with autotuning (AutoTVM, Ansor). |
| IREE | Google, now part of OpenXLA | TensorFlow, JAX, PyTorch via StableHLO | Linalg, Vulkan, CUDA, SPIR-V (MLIR) | Edge, mobile, embedded, server | Minimal runtime aligned with Vulkan; AOT and JIT. |
| Triton | OpenAI (2019), now Linux Foundation | Python DSL embedded in PyTorch | Triton IR (MLIR) | NVIDIA GPU primarily, AMD via plugins | Block-level kernel language; users write fused kernels in Python. |
| TorchInductor | PyTorch (2022) | PyTorch via TorchDynamo | FX graph, then Triton or C++/OpenMP | NVIDIA GPU, CPU, with Apple Metal in development | PyTorch-native compiler with define-by-run capture and symbolic shapes. |
XLA and TVM both originated in 2017 and target similar workloads, but they differ in philosophy. XLA operates at a higher level of abstraction and bakes most decisions into the compiler, while TVM exposes a schedule language that lets users (or autotuners such as AutoTVM and Ansor) specify how kernels should be tiled, fused, and scheduled. TVM's flexibility makes it popular for unusual hardware (microcontrollers, FPGAs, edge accelerators) and for research; XLA's productionization makes it dominant in cloud-scale training.
IREE shares much of XLA's stack (it consumes StableHLO, builds on MLIR, and is part of OpenXLA) but targets a different deployment scenario. IREE focuses on small and medium systems, including mobile and embedded devices, and aligns its runtime with the Vulkan compute API. The two compilers can be viewed as complementary: XLA for cloud-scale training and large-scale inference, IREE for edge inference and resource-constrained execution.
Triton is not a whole-graph compiler but a domain-specific language for writing fused GPU kernels at a higher level than CUDA. Triton kernels integrate naturally into both PyTorch and JAX and have become the standard way to author novel attention variants (FlashAttention, paged attention, sliding-window attention). XLA actually emits Triton IR for selected fusions on NVIDIA GPUs, which means Triton and XLA cooperate rather than compete in many real workloads.[4]
TorchInductor, announced with PyTorch 2.0 in December 2022 and shipped in the stable 2.0 release in March 2023, is the PyTorch-native compiler accessed through torch.compile(). Inductor relies on TorchDynamo for graph capture and AOTAutograd for the backward pass, and it emits Triton kernels for GPUs and C++/OpenMP for CPUs. Published benchmarks reported a 2.27x inference and 1.41x training geometric-mean speedup on NVIDIA A100 across 180 real-world models, outperforming six other compilers including older XLA versions on the same workloads. The two ecosystems coexist within PyTorch: torch.compile dominates eager-mode acceleration on GPUs, while PyTorch/XLA dominates TPU-backed training.
Finally, modern critics of all of these compilers, including XLA's original architects, have argued that the rigid, statically-typed, fixed-operator-set design that worked well in 2017 (when ResNet-50 was the canonical workload) struggles with modern generative AI, where models routinely require custom datatypes, custom kernels, dynamic shapes, and aggressive specialization.[25] The next wave of ML compiler work, including projects like Modular's MAX and the OpenXLA team's own MLIR-based redesign, is partly a response to this critique.[25]
XLA's design choices come with trade-offs that practitioners encounter regularly.
Static shapes are the most prominent. XLA compiles programs against specific input shapes, and when the shape of an input changes, the compiler must re-trace and re-compile. In a transformer inference server where sequence lengths vary across requests, naive use of XLA can trigger a recompilation per request, with each compilation taking seconds or minutes and effectively destroying throughput. The standard workaround is to pad inputs to a small set of bucket sizes so that only a handful of shapes are ever seen, but padding wastes computation on the padding tokens. Bounded dynamic shapes (in which a dimension is variable but bounded above) are supported in some configurations but remain experimental, and the broader problem of fully dynamic shapes is the subject of active research. Compared to TorchInductor's symbolic-shape support, XLA is less forgiving.
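The bucketing workaround can be sketched as follows. The bucket sizes, the `pad_to_bucket` helper, and the padding token are hypothetical choices for illustration; a real server would pick buckets from its observed length distribution.

```python
import bisect

BUCKETS = [32, 64, 128, 256]  # hypothetical bucket sizes, sorted ascending

def pad_to_bucket(tokens, pad_id=0):
    """Pad a token list up to the smallest bucket that can hold it."""
    i = bisect.bisect_left(BUCKETS, len(tokens))
    if i == len(BUCKETS):
        raise ValueError("sequence longer than the largest bucket")
    return tokens + [pad_id] * (BUCKETS[i] - len(tokens))

request_lengths = [5, 17, 32, 33, 100, 200]
padded = [pad_to_bucket([1] * n) for n in request_lengths]

distinct_raw = len(set(request_lengths))         # shapes without bucketing
distinct_padded = len({len(p) for p in padded})  # shapes with bucketing
wasted = sum(len(p) for p in padded) - sum(request_lengths)
print(distinct_raw, distinct_padded, wasted)     # 6 4 157
```

Six distinct request lengths collapse to four compiled shapes, at the cost of 157 padding tokens that are computed and discarded. With many more distinct lengths the number of compiled shapes stays capped at the number of buckets, while the padding overhead depends on how well the buckets match the traffic.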
Compilation latency is the second issue. JIT compilation is a one-time cost, but for a complex model with many code paths (training loops with grad accumulation, inference with multiple sampling strategies, evaluation with conditionals), the cumulative compilation time can dominate development cycles. The compilation cache (persisted to disk) and persisted autotuning help, but cache invalidation across XLA versions and hardware drivers can still trigger unwanted recompiles.
Debugging an HLO program is harder than debugging a Python program. When a JIT-compiled function produces wrong results, the developer has to inspect the HLO module, often by dumping it with XLA's --xla_dump_to flag, replaying compiler passes with the hlo-opt tool, or printing the traced program with jax.make_jaxpr. JAX has invested heavily in debugging tools (jax.debug.print, jax.debug.breakpoint, the jax.experimental.checkify error handling library), but the impedance mismatch between Python's interactive debugging style and XLA's compiled execution remains real.
Custom kernels historically required dropping out of XLA entirely or wrapping a hand-written C++ function in a CustomCall HLO instruction. Recent Triton integration eases this on NVIDIA GPUs (XLA can call user-written Triton kernels directly) and the Pallas project for JAX provides a Python DSL for authoring custom XLA kernels, though the experience is still less ergonomic than writing a Triton kernel and dispatching it from PyTorch.[18] The 2024 and 2025 OpenXLA roadmaps emphasize dynamic shapes, faster compilation, better debugging, and tighter integration with Triton and Shardy.
XLA and the broader OpenXLA stack continued to evolve rapidly through 2025 and into 2026. Several threads are worth highlighting.
Shardy as the JAX default. As described above, the GSPMD-to-Shardy migration moved from optional to default. Shardy became the default partitioner in JAX 0.7.0 (July 2025), and the JAX team scheduled the complete removal of GSPMD for March 2026, at which point Shardy is to be the only partitioner.[29][30] In the same period JAX overhauled its parallelism primitives: release 0.8.0 (October 2025) reimplemented jax.pmap on top of jax.jit and jax.shard_map rather than the legacy pmap path, consolidating SPMD execution around the newer sharding APIs.[30]
Tokamax kernel library. OpenXLA published Tokamax, a library of high-performance custom accelerator kernels for both NVIDIA GPUs and Google TPUs, built on JAX and the Pallas kernel-authoring DSL.[33] Tokamax ships state-of-the-art implementations such as FlashAttention-style dot-product attention, gated linear units, layer normalization, ragged-dot operations for mixture-of-experts routing, and memory-efficient linear cross-entropy, and it exposes a single API backed by multiple implementations (standard XLA lowering, Mosaic-GPU, Triton, or Pallas/Mosaic-TPU) with automatic selection driven by cached autotuning results.[33] Tokamax complements Pallas: where Pallas is a tool for writing kernels, Tokamax is a curated library of kernels that can be dropped into a model for immediate speedups.[33]
TPU v7 (Ironwood) co-design. Google's seventh-generation TPU, code-named Ironwood (TPU7x), reached general availability and was co-designed with the XLA compiler and the JAX stack.[34] Ironwood is the first TPU generation with native 8-bit floating-point (FP8) support in its matrix-multiply units, and each chip carries 192 GB of HBM3e memory; the XLA-and-JAX software layer is responsible for fusing operations and partitioning work across the large Ironwood pods.[34] These hardware additions are exposed to ML code through HLO and StableHLO so that model authors can target FP8 matmuls without writing accelerator-specific code.[34]
PyTorch on TPU: PJRT, eager mode, and TorchTPU. PJRT became PyTorch/XLA's officially supported runtime, replacing the older XRT runtime and bringing broader device support.[17] PyTorch/XLA also added an experimental eager mode together with a torch_xla.compile API, narrowing the gap with native PyTorch's define-by-run style while preserving the option of whole-graph compilation.[35] Building on community feedback gathered in late 2025, Google announced TorchTPU in April 2026, a native PyTorch backend for TPUs that leans into PyTorch eager execution and torch.compile rather than the lazy-tensor model of PyTorch/XLA. TorchTPU supports Distributed Data Parallel, Fully Sharded Data Parallel v2, and PyTorch DTensor out of the box, and it uses XLA as its primary backend compiler instead of routing through TorchInductor; Google positioned it as the eventual successor to PyTorch/XLA.[36]
Hardware and toolchain support. The XLA repository remains under heavy active development on GitHub, with ongoing 2025 work to support newer toolchains and accelerators, including LLVM 21, CUDA 13 (which renamed the Thor GPU streaming-multiprocessor target), and the NVIDIA Blackwell GPU family (Spark, Thor, and GB300 parts).[6] On NVIDIA GPUs, XLA's Triton-based code generation benefits from Triton's own Blackwell support, where Triton reports up to roughly 1.5x speedups on FP16 attention relative to the Hopper generation.[6]
XLA underpins a long list of high-profile ML systems. Google's foundation models, including LaMDA, PaLM, the Gemini family, and the Gemma open-weights series, are trained and served on TPUs through XLA. DeepMind's AlphaFold, which predicts protein structure to near-experimental accuracy, runs on XLA-compiled JAX. Hugging Face's TPU training pipelines, Anthropic's Trainium-based stacks (via the JAX Neuron plugin and PJRT), and the academic projects that consume Cloud TPU credits all rely on XLA.[26] Beyond cloud training, XLA's AOT pipeline ships compiled inference graphs to Android devices through TensorFlow Lite, and PJRT plugins are bringing JAX-based inference to Apple silicon, Intel client GPUs, and edge accelerators.