XLA (Accelerated Linear Algebra) is an open-source machine learning compiler that takes computational graphs from frameworks such as TensorFlow, JAX, and PyTorch and transforms them into highly optimized machine code for a wide range of hardware backends, including CPUs, GPUs (NVIDIA, AMD, Intel, Apple), Google TPUs (Tensor Processing Units), and AWS Trainium and Inferentia accelerators. Originally developed inside Google around 2017 as a domain-specific compiler for TensorFlow, XLA has grown into the foundational compiler layer of the deep-learning era and now sits at the heart of the OpenXLA Project, a multi-vendor consortium that took stewardship of the compiler from TensorFlow in March 2023.
XLA's central premise is that the high-level operations expressed by ML frameworks (matrix multiplications, convolutions, element-wise functions, reductions, scans, gathers, scatters) can be represented in a single, hardware-independent intermediate representation called HLO (High-Level Optimizer IR), aggressively optimized through fusion, layout assignment, and scheduling, and finally lowered to vendor-specific instructions through pluggable backends. This compile-once-target-many design has made XLA the engine behind some of the largest training runs in history, including DeepMind's AlphaFold, Google's Gemini family, and most foundation-model training on TPU pods. It also powers JAX's signature functional transformations (jit, vmap, pmap, grad), provides PyTorch with a TPU backend through PyTorch/XLA, and underpins the new generation of compiler-based ML systems built on MLIR and StableHLO.
Before XLA, ML frameworks executed graphs by dispatching one operation at a time to hand-tuned kernels in libraries such as cuDNN, cuBLAS, MKL-DNN, and Eigen. This op-at-a-time model is simple but leaves substantial performance on the table. Each kernel reads its inputs from high-bandwidth memory, performs a small amount of arithmetic, and writes its outputs back, only for the next kernel to read those same outputs again. On modern accelerators, where memory bandwidth and not arithmetic is usually the bottleneck, this round trip through global memory is the single largest source of inefficiency. Compilers can fuse adjacent operations into a single kernel that keeps intermediate values in registers or shared memory, eliminating the bandwidth tax. They can also choose memory layouts, vectorize loops, and schedule communication and computation in ways no hand-written kernel can anticipate.
Google's TPU project added a second motivation. Custom accelerators evolve quickly and are too specialized to support with hand-written kernels for every operation in TensorFlow's growing op zoo. A compiler that lowers a small, well-specified IR to a TPU instruction set can absorb new ops automatically and stay in lockstep with hardware revisions. The same logic applies to GPUs and CPUs: a compiler can target many backends with one set of optimization passes and a thin per-backend code generator.
XLA was first announced as an experimental TensorFlow component in March 2017, when Google described it as the "secret compiler sauce" that helps TensorFlow optimize compositions of primitive ops automatically. The original announcement reported up to 50 percent speedups on internal NVIDIA GPU benchmarks and previewed tfcompile, an ahead-of-time compilation tool that emitted compact binaries for mobile and embedded deployment. XLA stayed inside the TensorFlow source tree for six years before being extracted into the OpenXLA Project in 2023.
The heart of XLA is its high-level intermediate representation, HLO, sometimes spelled out as High-Level Optimizer IR or, in the newer terminology, High-Level Operations. HLO is a static-single-assignment (SSA) graph language whose nodes are pure functional operations on multidimensional arrays. Each HLO instruction has a well-defined shape (rank, dimensions, and element type) and a precise mathematical semantics specified in the OpenXLA documentation. HLO is intentionally small: a few dozen primitive operations cover the entire surface area of the language, and complex framework-level ops are decomposed into compositions of those primitives.
The primitive HLO operation set spans several broad categories. Element-wise unary and binary operations such as add, multiply, exp, log, sin, and the comparison operators apply a scalar function pointwise across an array. Reductions collapse one or more dimensions of an array under an associative binary operator (sum, max, product, logical-or). Broadcasts duplicate values across new dimensions to align operand shapes. Slice, dynamic-slice, and concatenate manipulate ranges of indices. Reshape, transpose, and reverse rearrange data without changing values. Convolution, dot, and gather perform structured contractions or table lookups. Scatter writes values into computed positions. The while, conditional, and call operations provide structured control flow. A handful of communication primitives (all-reduce, all-gather, reduce-scatter, collective-permute) support distributed execution across multiple devices.
Because HLO is hardware-independent, framework code can be lowered to HLO once and targeted at any backend. Because HLO is functional and SSA, dataflow analysis, fusion, and rewriting passes are straightforward to express. Because HLO carries shape information statically, the compiler can choose layouts, allocate buffers, and schedule kernels without any runtime introspection. The price of this static design is rigidity, discussed in the limitations section below.
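As a concrete illustration, the lowered IR for a small JAX function can be inspected directly. The sketch below assumes a recent JAX release, where jitted functions expose a `.lower()` method and the lowered module prints as StableHLO text; it shows how a framework-level op such as softmax decomposes into primitives such as exponential, reduce, broadcast, and divide.

```python
import jax
import jax.numpy as jnp

def softmax(x):
    # A single framework-level op that lowers to several HLO primitives.
    return jax.nn.softmax(x, axis=-1)

x = jnp.ones((8, 128), dtype=jnp.float32)

# Print the lowered (pre-optimization) module; recent JAX versions emit StableHLO here.
print(jax.jit(softmax).lower(x).as_text())
```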
XLA's compilation pipeline takes an HLO module, applies a long sequence of optimization passes, and emits machine code for the target device. The pipeline is divided into hardware-independent and hardware-specific phases.
The hardware-independent front of the pipeline performs traditional compiler optimizations adapted to ML workloads. Common subexpression elimination removes duplicated computations that frameworks often produce when they emit gradients or repeated indexing expressions. Algebraic simplification rewrites expressions according to mathematical identities (collapsing transpose-of-transpose, simplifying x*1 and x+0, fusing chains of reshapes). Constant folding evaluates expressions whose inputs are known statically. Dead code elimination removes instructions whose results are not used. Buffer analysis determines which intermediate tensors can share storage and computes the peak working-set size of the program.
Fusion is the single most consequential XLA optimization. The fusion pass clusters adjacent HLO operations into compound fusion instructions that the backend will lower to a single kernel. Element-wise fusion is the easiest case: a pointwise chain such as relu(bias + matmul(x, w)) can be turned into a single kernel that streams the matmul result through bias addition and ReLU without ever materializing the intermediate. Reduction fusion folds reductions and their producers into a single pass over the input. Loop fusion (sometimes called horizontal fusion) merges independent loops that iterate over the same dimensions. Output fusion absorbs producers of intermediates whose only consumer is a fusion. Because GPU and TPU workloads are usually memory-bound, fusion typically delivers the bulk of XLA's speedup over op-at-a-time execution.
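The relu-of-bias-of-matmul example can be made concrete in JAX. In this minimal sketch (assuming a recent JAX release), the optimized HLO returned after compilation typically contains `fusion` instructions where the bias addition and ReLU have been folded into the matmul's output path rather than materialized separately.

```python
import jax
import jax.numpy as jnp

def dense_relu(x, w, b):
    # A pointwise chain (bias add + ReLU) that the fusion pass can attach to the matmul.
    return jax.nn.relu(x @ w + b)

x = jnp.ones((32, 512))
w = jnp.ones((512, 512))
b = jnp.ones((512,))

# Inspect the post-optimization HLO; fused regions appear as `fusion` instructions.
print(jax.jit(dense_relu).lower(x, w, b).compile().as_text())
```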
Layout assignment chooses a physical memory layout for every operand and result in the program. On TPUs, where matrix multiplications run on systolic arrays with a fixed input shape, the choice of layout (which logical dimension maps to which physical axis, how data is tiled and padded) determines whether a matmul runs at peak throughput or at a fraction of it. On GPUs, layout choices govern coalesced memory access and the use of tensor-core instructions. The layout assignment pass propagates layouts forward and backward through the program, inserts transpose operations only where necessary, and resolves conflicts using a cost model.
Scheduling orders the execution of operations within a device and across devices. The schedule must respect data dependencies, fit within memory budgets, overlap communication with computation, and (on multi-stream backends) keep multiple execution units busy. Scheduling choices interact tightly with buffer allocation: an aggressive schedule may need extra memory to hold intermediates simultaneously, while a memory-conscious schedule may serialize computations that could otherwise overlap.
The autotuning pass explores a search space of low-level parameters (tile sizes for matmul, block sizes for reductions, vector widths) and benchmarks alternatives on the target hardware to pick the fastest configuration. Because autotuning is expensive, XLA supports persisted autotuning: the results of a tuning run are written to a cache that can be reused on subsequent compilations of the same fusion.
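A hedged sketch of how persisted autotuning might be wired up from Python: the flag names below follow the OpenXLA persisted-autotuning documentation for the GPU backend, but they are an assumption to verify against the XLA version in use.

```python
import os

# First run: dump autotuning results to disk. Later runs: load them instead of re-tuning.
# Flag names are taken from the OpenXLA persisted-autotuning docs and may vary by release.
if not os.path.exists("/tmp/autotune.pb"):
    os.environ["XLA_FLAGS"] = "--xla_gpu_dump_autotune_results_to=/tmp/autotune.pb"
else:
    os.environ["XLA_FLAGS"] = "--xla_gpu_load_autotune_results_from=/tmp/autotune.pb"

import jax  # XLA_FLAGS must be set before the framework initializes its XLA backend
```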
Finally, the backend lowers the optimized HLO to executable machine code. On CPUs and GPUs, XLA emits LLVM IR and hands it off to the LLVM compiler infrastructure for register allocation, instruction selection, and final assembly. On NVIDIA GPUs, the backend can also emit Triton IR for selected fusions, then invoke OpenAI's Triton compiler to produce PTX for high-performance attention and matmul kernels. On TPUs, XLA uses Google's proprietary TPU compiler. On AWS Trainium and Inferentia, the AWS Neuron compiler consumes HLO via PJRT. The output of compilation is an XLA Executable: a self-contained, device-specific binary plus a small runtime that owns the input and output buffers.
XLA supports two compilation modes, JIT and AOT, that trade off startup cost for runtime flexibility.
In just-in-time (JIT) mode, the framework hands HLO to XLA at runtime, the compiler optimizes and emits machine code on demand, and the resulting executable is cached for subsequent calls with the same input shapes and types. JIT is the default for both TensorFlow's @tf.function(jit_compile=True) and JAX's jax.jit. The first call pays the full compilation cost, often hundreds of milliseconds to tens of seconds, while subsequent calls run the cached executable directly. JIT compiles for the actual hardware on which the program runs, taking advantage of CPU instruction set extensions, GPU compute capabilities, or TPU generations without developer intervention.
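The cost profile is easy to observe in JAX. A minimal sketch: the first call to a jitted function pays for tracing and compilation, the second call runs the cached executable, and a new input shape misses the cache and recompiles.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((256, 256))

t0 = time.perf_counter()
step(x).block_until_ready()   # first call: trace + compile + execute
t1 = time.perf_counter()
step(x).block_until_ready()   # second call: cached executable only
t2 = time.perf_counter()
print(f"compile+run: {t1 - t0:.3f}s, cached run: {t2 - t1:.4f}s")

step(jnp.ones((512, 512))).block_until_ready()   # new shape: triggers recompilation
```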
In ahead-of-time (AOT) mode, the developer compiles a fully specified HLO program to a standalone binary at build time. TensorFlow's tfcompile tool was the first AOT front end and remains the standard way to deploy XLA-compiled inference graphs on mobile, embedded, or server platforms where the JIT runtime would be too heavy. AOT eliminates startup latency, removes the XLA compiler from the production binary, and allows the resulting code to be inspected and signed like any other binary artifact.
Most modern XLA workloads use JIT, because input shapes usually stabilize after warmup and JIT compilation is amortized across millions of training steps or inference requests.
XLA is the runtime engine for JAX and an important optional backend for TensorFlow and PyTorch. Each framework integrates XLA differently.
JAX was designed from the ground up around XLA. Every JAX operation is internally implemented as a thin wrapper around an HLO primitive, and every JAX function is amenable to tracing: when a function is invoked with JAX arrays, the framework records the sequence of operations as it runs and assembles them into an HLO module. The signature JAX transformations, jit, vmap, pmap, and grad, all operate at the level of these traces. jax.jit compiles a traced function to a single XLA executable. jax.vmap automatically vectorizes a function across a new batch dimension by rewriting the trace, eliminating the need for manual batching code. jax.pmap parallelizes a function across multiple devices using single-program-multiple-data (SPMD) semantics, with collective operations (psum, pmean, all-gather) compiled into XLA's communication primitives. jax.grad uses functional reverse-mode automatic differentiation to produce a new function whose Jacobian-vector product or vector-Jacobian product is itself JIT-compilable. Because all four transformations compose freely, JAX users routinely nest them (grad of jit of vmap of pmap of some inner function), and XLA absorbs the full composition into one optimized executable.
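A minimal sketch of that composability (hypothetical loss function, assuming a recent JAX release): vmap batches a per-example loss, grad differentiates the batched mean, and jit compiles the whole composition into a single XLA executable.

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    w, b = params
    pred = jnp.tanh(x @ w + b)
    return jnp.mean(pred ** 2)

per_example = jax.vmap(loss, in_axes=(None, 0))                  # batch over examples
batched_loss = lambda params, xs: jnp.mean(per_example(params, xs))
grad_fn = jax.jit(jax.grad(batched_loss))                        # differentiate, then compile

params = (jnp.ones((4, 4)), jnp.zeros((4,)))
xs = jnp.ones((8, 4))
grads = grad_fn(params, xs)   # one compiled executable for the whole pipeline
```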
TensorFlow historically had a more cautious relationship with XLA. Standard TensorFlow programs run in eager mode (since TensorFlow 2.0) or as op-by-op graphs and dispatch to ordinary CUDA or oneDNN kernels. XLA enters only when a function is wrapped with @tf.function(jit_compile=True) or annotated with tf.xla.experimental.compile. The TensorFlow-to-XLA bridge converts the relevant subgraph to HLO and hands it to XLA. This opt-in design reflects TensorFlow's heterogeneous user base: some users need kernel-level control, while others want whole-program optimization.
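The opt-in path looks like the following sketch, using the public tf.function API; the wrapped subgraph is converted to HLO by the bridge and compiled as a single unit.

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def dense_relu(x, w, b):
    # Compiled by XLA as one fused executable rather than dispatched op by op.
    return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.ones((32, 512))
w = tf.ones((512, 512))
b = tf.ones((512,))
y = dense_relu(x, w, b)   # first call triggers XLA compilation, later calls reuse it
```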
PyTorch integrates XLA through the PyTorch/XLA package, which adds an XLA device type alongside CUDA and CPU. Tensors placed on the XLA device record their operations into a lazy graph that is materialized and compiled when the user calls a synchronization barrier such as xm.mark_step(). PyTorch/XLA originated to give PyTorch users access to TPUs (Hugging Face's TPU training pipelines, Cloud TPU VMs, and Colab TPU runtimes all rely on it) and now also supports GPUs via PJRT. Through its FSDP (Fully Sharded Data Parallel) integration, GPT-2-class transformers as large as 128 billion parameters have been trained on TPUs without leaving the PyTorch programming model. PyTorch's other compiler stack, torch.compile with TorchInductor, targets Triton and C++/OpenMP rather than XLA: torch.compile dominates eager-mode acceleration on NVIDIA hardware, while PyTorch/XLA dominates TPU-backed training.
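The lazy-tensor flow looks roughly like the sketch below (assuming the torch_xla package is installed; newer releases also expose equivalent top-level helpers).

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()            # an XLA device: TPU, or GPU/CPU via PJRT

x = torch.randn(32, 512, device=device)
w = torch.randn(512, 512, device=device)

y = torch.relu(x @ w)               # recorded lazily into a graph, not yet executed

xm.mark_step()                      # barrier: the accumulated graph is compiled and run
```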
XLA targets an unusually broad spectrum of hardware. The following table summarizes the major backends and how they integrate.
| Backend | Vendor | Runtime interface | Code generator | Notes |
|---|---|---|---|---|
| TPU | Google | Native PJRT plugin | Internal TPU compiler | First-class target since 2017; powers Gemini, Gemma, AlphaFold, PaLM. |
| NVIDIA GPU (CUDA) | NVIDIA | PJRT plugin | LLVM NVPTX, optionally Triton | Most widely used GPU backend; supports Ampere, Hopper, Blackwell tensor cores. |
| AMD GPU (ROCm) | AMD | PJRT plugin | LLVM AMDGPU, Triton GEMM rewriter | Active development as of 2024 with Triton-based GEMM autotuner work in progress. |
| Intel GPU | Intel | OpenXLA PJRT plugin via Intel Extension for TensorFlow | LLVM SPIR-V | Used for oneAPI-based execution on Intel Max GPUs and integrated GPUs. |
| Apple GPU | Apple | PJRT plugin (Apple silicon) | Metal | Enables JAX on Apple silicon; emerging support in 2024. |
| AWS Trainium and Inferentia | Amazon Web Services | JAX Neuron plugin | AWS Neuron compiler | Lets Trainium and Inferentia be used as native JAX devices via PJRT. |
| CPU (x86-64, ARM64) | Various | Built-in | LLVM | Used for development, fallback, and lightweight inference. |
| Cerebras WSE | Cerebras | OpenXLA partner | Cerebras compiler | Wafer-scale engine support announced as part of OpenXLA partnership. |
| Graphcore IPU | Graphcore | OpenXLA partner | Graphcore Poplar | Founding OpenXLA partner; IPU codegen path. |
The key innovation that enables this breadth is PJRT, the Pluggable Runtime interface, which standardizes how frameworks discover devices, allocate buffers, transfer tensors, compile programs, and launch executables. PJRT exposes both a C and a C++ API, and a hardware vendor can ship a PJRT plugin (a dynamically loaded library) that the framework discovers at runtime without recompilation. PJRT was open-sourced in 2023 alongside the OpenXLA announcement and has since been adopted by NVIDIA (CUDA), Apple (Metal), Intel (Max GPUs), and AWS (Trainium and Inferentia). Through PJRT, JAX can execute on Apple silicon without a single change to JAX itself, the new device showing up as just another platform.
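From a framework's point of view, whichever PJRT plugin is installed simply appears as another platform. A small sketch in JAX:

```python
import jax

# The installed PJRT plugin (CUDA, TPU, Metal, Neuron, ...) determines what shows up here.
print(jax.default_backend())        # e.g. "cpu", "gpu", or "tpu"
for d in jax.devices():
    print(d.platform, d.device_kind, d.id)
```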
For most of XLA's life, the compiler's source tree lived inside the TensorFlow repository, and its development was driven by Google. By 2022 it had become clear that XLA's user base extended well beyond TensorFlow, that several hardware vendors wanted a stronger seat at the table, and that the compiler needed to be modularized so pieces could be reused independently. In October 2022, Google announced its intention to spin XLA out of TensorFlow and form the OpenXLA Project, a multi-vendor consortium with neutral governance and a shared roadmap.
OpenXLA officially launched on March 8, 2023, with founding partners Alibaba, Amazon Web Services, AMD, Anyscale, Apple, Arm, Cerebras, Google, Graphcore, Hugging Face, Intel, Meta, NVIDIA, and SiFive. The first deliverable was the migration of XLA, StableHLO, and IREE into independent GitHub repositories under the openxla organization, with public design discussions and a documented contribution process. The launch announcement reported real-world performance wins including a 72 percent speedup on GPT-2 and an 88 percent speedup on Swin Transformer on NVIDIA hardware.
The consortium has since added new components. Shardy (announced in 2024) is an MLIR-based tensor partitioning system that incorporates the lessons of GSPMD and PartIR. Tokamax is a benchmark and profiling suite. XProf is a cross-framework profiling tool. Hardware companies (Apple, NVIDIA, AMD, Intel, AWS) maintain their own backends in collaboration with the rest of the community.
A critical piece of the OpenXLA puzzle is StableHLO, the portable, versioned operator set that sits between ML frameworks and the XLA compiler. StableHLO was extracted from XLA's internal MHLO dialect and packaged as a standalone, backward-compatible operator set with a written specification, semantic versioning, and a long-term compatibility policy. Frameworks that emit StableHLO can be confident that their programs will continue to compile correctly even as the XLA compiler evolves its internal representations. Hardware backends that consume StableHLO can be confident that the contract with frameworks will not silently break.
StableHLO is implemented as an MLIR dialect. MLIR (Multi-Level Intermediate Representation) is a compiler infrastructure originally developed at Google and now part of the LLVM project, designed for representing and transforming code across many levels of abstraction. MLIR provides reusable infrastructure for IR construction, pattern rewriting, dataflow analysis, and code generation, all parameterized by application-specific dialects such as StableHLO, the LLVM dialect, the GPU dialect, the Triton dialect, the Linalg dialect, and many others. By building on MLIR, the OpenXLA stack composes naturally with the rest of the LLVM ecosystem and with sister projects such as IREE, Triton, and PyTorch's lowering pipeline.
In practice, a JAX program is traced into an HLO module, that module is exported as StableHLO, the StableHLO is round-tripped through MLIR for any cross-framework analysis, and finally lowered to the target backend. The same StableHLO can be consumed by the standard XLA compiler, by IREE, or by a third-party compiler that conforms to the StableHLO contract. This interchange has made StableHLO the de facto neutral interchange format for ML compilation.
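A sketch of that export step, assuming a recent JAX release where the jax.export module is available (method names may differ slightly across versions): the exported artifact carries versioned StableHLO that any conforming consumer can compile.

```python
import jax
import jax.numpy as jnp
from jax import export   # available in recent JAX releases

def f(x):
    return jnp.sin(x) * 2.0

x = jnp.ones((4, 4))

exported = export.export(jax.jit(f))(x)   # trace, lower, and package as StableHLO
print(exported.mlir_module())             # the StableHLO module as MLIR text
blob = exported.serialize()               # portable bytes for out-of-process consumers
```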
Large model training, especially on TPU pods, requires partitioning weights, activations, and gradients across many devices. XLA addresses this through SPMD (Single Program Multiple Data) compilation: a single program is compiled once and executed in parallel on every device, with the compiler inserting collective communication operations (all-reduce, all-gather, reduce-scatter, collective-permute) automatically based on user-supplied sharding annotations.
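In JAX, those sharding annotations are supplied through jax.sharding. The sketch below assumes a host with eight addressable devices arranged as a 2x4 mesh; XLA's SPMD partitioner inserts the necessary collectives automatically.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

devices = np.array(jax.devices()).reshape(2, 4)     # assumes 8 devices are available
mesh = Mesh(devices, axis_names=("data", "model"))

x = jnp.ones((256, 1024))
w = jnp.ones((1024, 4096))

# Shard activations along the data axis and weights along the model axis.
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
w = jax.device_put(w, NamedSharding(mesh, P(None, "model")))

@jax.jit
def forward(x, w):
    return jnp.tanh(x @ w)

y = forward(x, w)   # compiled once, executed in SPMD fashion on every device
```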
Google's first generation of SPMD partitioning was GSPMD, described in a 2021 paper, GSPMD: General and Scalable Parallelization for ML Computation Graphs. GSPMD took user annotations on a small number of tensors and propagated sharding decisions across the rest of the program automatically, inserting communication where shardings disagreed and choosing communication patterns to minimize cost. GSPMD proved scalable to thousands of TPU chips and was used to train models such as PaLM, LaMDA, and many of Google's foundation models.
Shardy, announced in 2024 and developed jointly by the GSPMD and PartIR teams, is the next-generation MLIR-based partitioning system. Shardy uses an axis-based sharding representation that is more expressive than GSPMD's, supports incremental partitioning (the user can specify exactly how the program should be sharded at any point, rather than relying entirely on propagation), and includes novel handling of reshape operations that historically generated extra communication. Short-term, Shardy delegates the actual partitioning work to the existing GSPMD partitioner; long-term, the OpenXLA team plans a new MLIR-native SPMD partitioner. Shardy is integrated with both JAX (through jax.sharding and the Shardy backend flag) and PyTorch/XLA (through the SPMD API).
The combination of XLA's whole-program optimization, PJRT's hardware abstraction, StableHLO's portability layer, and Shardy's partitioning has made it possible to write a single JAX program and run it unchanged on a single GPU, on a multi-node GPU cluster, or on a TPU pod with thousands of chips, with the compiler making most of the parallelization decisions.
XLA is not the only ML compiler, and the design space contains several distinct approaches. The following table contrasts XLA with the most prominent alternatives.
| Compiler | Origin | Frontends | IR | Targeting | Design philosophy |
|---|---|---|---|---|---|
| XLA | Google (2017), now OpenXLA | TensorFlow, JAX, PyTorch (via PyTorch/XLA) | HLO and StableHLO (MLIR) | TPU; NVIDIA, AMD, Intel, and Apple GPUs; CPU; AWS Trainium and Inferentia | Whole-graph optimization, fusion, and SPMD partitioning. |
| TVM | Apache project (UW SAMPL, 2017) | TensorFlow, PyTorch, ONNX, others | Relay, then TIR | NVIDIA, AMD, and Intel GPUs; CPU; ARM; microcontrollers; FPGA | Schedule-language approach with autotuning (AutoTVM, Ansor). |
| IREE | Google, now part of OpenXLA | TensorFlow, JAX, PyTorch via StableHLO | Linalg, Vulkan, CUDA, SPIR-V (MLIR) | Edge, mobile, embedded, server | Minimal runtime aligned with Vulkan; AOT and JIT. |
| Triton | OpenAI (2019), now Linux Foundation | Python DSL embedded in PyTorch | Triton IR (MLIR) | NVIDIA GPU primarily, AMD via plugins | Block-level kernel language; users write fused kernels in Python. |
| TorchInductor | PyTorch (2022) | PyTorch via TorchDynamo | FX graph, then Triton or C++/OpenMP | NVIDIA GPU, CPU, with Apple Metal in development | PyTorch-native compiler with define-by-run capture and symbolic shapes. |
XLA and TVM both originated in 2017 and target similar workloads, but they differ in philosophy. XLA operates at a higher level of abstraction and bakes most decisions into the compiler, while TVM exposes a schedule language that lets users (or autotuners such as AutoTVM and Ansor) specify how kernels should be tiled, fused, and scheduled. TVM's flexibility makes it popular for unusual hardware (microcontrollers, FPGAs, edge accelerators) and for research; XLA's productionization makes it dominant in cloud-scale training.
IREE shares much of XLA's stack (it consumes StableHLO, builds on MLIR, and is part of OpenXLA) but targets a different deployment scenario. IREE focuses on small and medium systems, including mobile and embedded devices, and aligns its runtime with the Vulkan compute API. The two compilers can be viewed as complementary: XLA for cloud-scale training and large-scale inference, IREE for edge inference and resource-constrained execution.
Triton is not a whole-graph compiler but a domain-specific language for writing fused GPU kernels at a higher level than CUDA. Triton kernels integrate naturally into both PyTorch and JAX and have become the standard way to author novel attention variants (FlashAttention, paged attention, sliding-window attention). XLA actually emits Triton IR for selected fusions on NVIDIA GPUs, which means Triton and XLA cooperate rather than compete in many real workloads.
TorchInductor, released as part of PyTorch 2.0 in 2022, is the PyTorch-native compiler accessed through torch.compile(). Inductor leverages TorchDynamo for graph capture, AOTAutograd for the backward pass, and emits Triton kernels for GPUs and C++/OpenMP for CPUs. Published benchmarks reported a 2.27x inference and 1.41x training geometric-mean speedup on NVIDIA A100 across 180 real-world models, outperforming six other compilers including older XLA versions on the same workloads. The two ecosystems coexist within PyTorch: torch.compile dominates eager-mode acceleration on GPUs, while PyTorch/XLA dominates TPU-backed training.
Finally, modern critics of all of these compilers, including XLA's original architects, have argued that the rigid, statically-typed, fixed-operator-set design that worked well in 2017 (when ResNet-50 was the canonical workload) struggles with modern generative AI, where models routinely require custom datatypes, custom kernels, dynamic shapes, and aggressive specialization. The next wave of ML compiler work, including projects like Modular's MAX and the OpenXLA team's own MLIR-based redesign, is partly a response to this critique.
XLA's design choices come with trade-offs that practitioners encounter regularly.
Static shapes are the most prominent. XLA compiles programs against specific input shapes, and when the shape of an input changes, the compiler must re-trace and re-compile. In a transformer inference server where sequence lengths vary across requests, naive use of XLA can trigger a recompilation per request, with each compilation taking seconds or minutes and effectively destroying throughput. The standard workaround is to pad inputs to a small set of bucket sizes so that only a handful of shapes are ever seen, but padding wastes computation on the padding tokens. Bounded dynamic shapes (in which a dimension is variable but bounded above) are supported in some configurations but remain experimental, and the broader problem of fully dynamic shapes is the subject of active research. Compared to TorchInductor's symbolic-shape support, XLA is less forgiving.
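The bucketing workaround can be sketched in a few lines (hypothetical helper, assuming a recent JAX release): every sequence is padded up to one of a few fixed lengths so that the jitted function only ever compiles for those shapes.

```python
import jax
import jax.numpy as jnp

BUCKETS = (128, 256, 512, 1024)   # the only shapes the compiler will ever see

def pad_to_bucket(tokens):
    length = tokens.shape[0]
    bucket = next(b for b in BUCKETS if b >= length)
    padded = jnp.zeros((bucket,), dtype=tokens.dtype).at[:length].set(tokens)
    return padded, length

@jax.jit
def score(padded_tokens):
    # At most len(BUCKETS) compilations, at the cost of wasted work on padding positions.
    return jnp.cumsum(padded_tokens)

padded, true_len = pad_to_bucket(jnp.arange(300))
result = score(padded)[:true_len]   # discard outputs for the padding positions
```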
Compilation latency is the second issue. JIT compilation is a one-time cost, but for a complex model with many code paths (training loops with grad accumulation, inference with multiple sampling strategies, evaluation with conditionals), the cumulative compilation time can dominate development cycles. The compilation cache (persisted to disk) and persisted autotuning help, but cache invalidation across XLA versions and hardware drivers can still trigger unwanted recompiles.
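A sketch of enabling JAX's persistent compilation cache; the option names below follow recent JAX documentation but should be treated as an assumption to verify against the installed version.

```python
import jax

# Cache compiled executables on disk so repeated runs skip recompilation.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# Only persist compilations that took at least this long (avoids caching trivial programs).
jax.config.update("jax_persistent_cache_min_compile_time_secs", 1.0)
```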
Debugging an HLO program is harder than debugging a Python program. When a JIT-compiled function produces wrong results, the developer has to inspect the HLO module, typically by dumping it with XLA's --xla_dump_to flag (which writes the module before and after each pass) or by examining the traced program with JAX's jax.make_jaxpr. JAX has invested heavily in debugging tools (jax.debug.print, jax.debug.breakpoint, the jax.experimental.checkify error handling library) but the impedance mismatch between Python's interactive debugging style and XLA's compiled execution remains real.
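For example, jax.debug.print works inside compiled code, where a plain Python print would fire only once at trace time with abstract tracer values rather than concrete numbers.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    norm = jnp.linalg.norm(x)
    # Runs at execution time on device values, unlike Python's print.
    jax.debug.print("norm = {n}", n=norm)
    return x / norm

normalize(jnp.arange(4.0))
```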
Custom kernels historically required dropping out of XLA entirely or wrapping a hand-written C++ function in a CustomCall HLO instruction. Recent Triton integration eases this on NVIDIA GPUs (XLA can call user-written Triton kernels directly) and the Pallas project for JAX provides a Python DSL for authoring custom XLA kernels, though the experience is still less ergonomic than writing a Triton kernel and dispatching it from PyTorch. The 2024 and 2025 OpenXLA roadmaps emphasize dynamic shapes, faster compilation, better debugging, and tighter integration with Triton and Shardy.
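A minimal Pallas sketch (an element-wise add, assuming a recent JAX release with jax.experimental.pallas available); realistic kernels would additionally specify grids and block specs to control how work is tiled across the device.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # The kernel operates on references to blocks of device memory.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

print(add(jnp.arange(8.0), jnp.arange(8.0)))
```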
XLA underpins a long list of high-profile ML systems. Google's foundation models, including LaMDA, PaLM, the Gemini family, and the Gemma open-weights series, are trained and served on TPUs through XLA. DeepMind's AlphaFold, which predicts protein structure to near-experimental accuracy, runs on XLA-compiled JAX. Hugging Face's TPU training pipelines, Anthropic's Trainium-based stacks (via the JAX Neuron plugin and PJRT), and the academic projects that consume Cloud TPU credits all rely on XLA. Beyond cloud training, XLA's AOT pipeline ships compiled inference graphs to Android devices through TensorFlow Lite, and PJRT plugins are bringing JAX-based inference to Apple silicon, Intel client GPUs, and edge accelerators.