Pallas (JAX kernel language)
Last reviewed
May 21, 2026
Sources
No citations yet
Review status
Needs citations
Revision
v1 · 3,607 words
Pallas is an experimental extension to JAX that lets users write custom hardware kernels in Python and lower them to both Tensor Processing Units and NVIDIA GPUs from a single source.[^1] It exposes a lower level of abstraction than ordinary JAX programs (Refs into on-chip memory, explicit grids, block specifications) while preserving JAX's tracing, jax.numpy API, and composable transformations such as jit, vmap, and grad.[^1][^2] On GPUs Pallas lowers to either the Triton intermediate representation or to Mosaic GPU; on TPUs it lowers to Mosaic, Google's MLIR-based TPU code generator.[^2][^3] The project is developed inside the JAX repository (jax-ml/jax) and is documented as an experimental, frequently changing component.[^1][^4]
Modern deep learning compilers such as XLA do a good job of fusing and scheduling common array operations, but production teams routinely hit cases where the compiler cannot match the throughput of a hand-written kernel. The JAX project's Pallas design note frames the problem this way: while XLA "does a good job compiling user programs," some users "inevitably hit XLA's limitations" and need an "escape hatch" that allows experts to write hand-tuned kernels that outperform XLA at that point in time.[^2] XLA already exposes a CustomCall mechanism, but, as the same design note observes, it requires writing C++ and, on GPU, learning the CUDA programming model, both of which are mismatched with the array-oriented Python style most JAX users prefer.[^2]
On the GPU side, OpenAI's Triton showed that a Python-embedded DSL with an array-programming model could produce highly competitive kernels without requiring developers to manage individual threads. On the TPU side, however, there was no equivalent option: writing custom TPU kernels required deep familiarity with Google's internal compiler stack. The JAX team's stated goal for Pallas is to give JAX a Triton-class kernel front end with first-class TPU support, by reusing JAX's tracing machinery and jax.numpy API as the source language and routing kernels through hardware-specific backends.[^2] Sharad Vikram, an engineer on the JAX team at Google DeepMind, is credited as the creator and lead of Pallas and gave the first public talk titled "Pallas: A JAX Kernel Language" in September 2023.[^5]
| Attribute | Value |
|---|---|
| Project name | Pallas (jax.experimental.pallas) |
| Parent project | JAX (jax-ml/jax GitHub repository) |
| Lead developer | Sharad Vikram, Google DeepMind[^5] |
| First public availability | JAX main branch by August 2023; included in JAX releases shortly thereafter[^4][^6] |
| GPU backends | Triton IR and Mosaic GPU[^1][^3] |
| TPU backend | Mosaic (TPU code generator)[^2][^3] |
| Supported TPU hardware | TPU v4 and later, including Trillium (TPU v6e) and TPU Ironwood[^7][^8] |
| Supported GPU hardware | Ampere and newer for Triton backend; Hopper and Blackwell for Mosaic GPU[^7][^3] |
| License | Apache License 2.0[^4] |
| Status | Experimental; changing frequently per official documentation[^1] |
Pallas was developed inside the google/jax repository through 2023 and merged in increments. A community discussion in August 2023 noted that Pallas "does not seem to be in 0.4.14" but was already present on JAX's main branch, with maintainers replying that Pallas had become "part of JAX," though stable installation required additional dependencies such as JAX-Triton.[^6] Sharad Vikram's public talk "Pallas: A JAX Kernel Language" followed in September 2023.[^5]
The JAX core repository moved from github.com/google/jax to github.com/jax-ml/jax in September 2024, taking Pallas with it. The JAX team explained the move by noting that the jax-ml organization had already housed related projects such as ml_dtypes and jax-triton, and that consolidating the canonical repository there better reflected ownership by the JAX core team rather than Google as a whole.[^9]
A separate Pallas changelog is maintained alongside the main JAX changelog. Notable milestones recorded there include the addition of TPU multi-buffering for pipelines, scalar prefetch and block-sparse computation, GPU atomic primitives such as atomic_add and atomic_max, the introduction of Mosaic GPU as the default lowering path on GPU (with Triton retained as an opt-in backend), and a refactor of pl.kernel to use out_type/scratch_types in JAX 0.10.0.[^4]
Outside of JAX itself, Pallas was integrated into PyTorch/XLA. The Google Cloud blog post announcing PyTorch/XLA 2.4 (published 2024-07-31) describes Pallas as "a custom kernel language originally developed for JAX" that "supports both TPUs and GPUs" and is "similar to the Triton library, but since it also works on both TPUs and GPUs, it's easier to port your model from one machine learning accelerator to the other."[^10] That release shipped PyTorch-side wrappers around Pallas implementations of Flash Attention, Paged Attention, and the MegaBlocks block-sparse Mixture-of-Experts kernel, with full autograd integration.[^10]
A Pallas kernel is an ordinary Python function that operates on Refs, mutable references to on-chip scratchpad memory rather than the immutable JAX arrays that ordinary jax.numpy code manipulates.[^1][^2] A trivial vector-addition kernel from the official quickstart looks like:[^11]
def add_vectors_kernel(x_ref, y_ref, o_ref):
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y
The kernel reads from input refs, computes with normal JAX operations, and writes to an output ref. It does not return anything because results are produced through side effects on the output ref.[^11]
To call a kernel inside a regular JAX program, users invoke pl.pallas_call, a higher-order operation analogous to jax.jit or jax.pmap that lifts the kernel function into the surrounding JAX trace:[^11][^12]
@jax.jit
def add_vectors(x, y):
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)
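For readers who want to run the two snippets above end to end, the following is a minimal self-contained sketch. It adds the imports the fragments rely on and passes `interpret=True`, an assumption made here so that the kernel can be evaluated without a GPU or TPU; the input values are purely illustrative.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Read both input Refs, add the values, and write through the output Ref.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add_vectors(x, y):
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # assumption: interpret mode, so no accelerator is required
    )(x, y)


# Illustrative inputs (not from the quickstart).
x = jnp.arange(8, dtype=jnp.float32)
result = add_vectors(x, x)
```

On real hardware the `interpret=True` argument would normally be dropped so that the kernel is compiled for the target backend.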
pallas_call accepts a grid, a tuple specifying an iteration space. A grid of (4, 5) launches twenty program instances, and inside the kernel pl.program_id(axis=...) returns the current grid coordinate. On GPUs program instances run in parallel; on TPUs they execute in lexicographic order by default, with optional dimension_semantics annotations to mark dimensions as parallel on multi-core TPUs.[^11][^7]
Because real kernels usually operate on tiles of large arrays rather than whole arrays at once, Pallas introduces the BlockSpec, an abstraction that combines a block_shape with an index_map from program ids to block coordinates. A two-dimensional matrix multiplication kernel can specify block shapes for X and Y and let the runtime "carve up" inputs and outputs automatically:[^11][^13]
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0))
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
The Pallas design note describes BlockSpec as the abstraction that "tells Pallas how to map a program id to a slice/block of an input or output," with the grid plus BlockSpecs together describing how to tile inputs while Pallas schedules the tiles across hardware and "handles the DMA transfers automatically." Inputs entering a kernel body are already staged into the on-chip scratchpad (VMEM on TPU, shared memory on GPU) rather than living in HBM/DRAM.[^2][^13]
Pallas exposes the memory hierarchy of accelerators explicitly. On TPUs the kernel body operates on VMEM (16 MB and up on recent generations) or SMEM, while communication with high-bandwidth memory is "handled by the compiler and overlapped with compute."[^7] When two consecutive grid indices use the same input slice, the HBM transfer for the second iteration is skipped because the data is already available in VMEM.[^7] On Mosaic GPU, Pallas requires explicit pipelining: the documentation notes that "pipelining in Pallas is programmed explicitly," which is a significant difference from Triton where pipelining is an automatic compiler optimization.[^3]
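The transfer-skipping behavior described above can be approximated with a small plain-Python model. This is not Pallas code: `count_block_fetches` is a hypothetical helper that walks a grid in lexicographic order and counts how often an input's block index changes between consecutive iterations, which is the situation in which a new HBM transfer would be needed under this simplified model.

```python
import itertools


def count_block_fetches(grid, index_map):
    """Count block loads, assuming a load is skipped when the block index repeats."""
    fetches = 0
    previous = None
    # itertools.product iterates in lexicographic order (last axis fastest).
    for program_ids in itertools.product(*(range(n) for n in grid)):
        block = index_map(*program_ids)
        if block != previous:
            fetches += 1
        previous = block
    return fetches


# Index maps taken from the two-dimensional matmul BlockSpecs in the text.
x_fetches = count_block_fetches((2, 2), lambda i, j: (i, 0))
y_fetches = count_block_fetches((2, 2), lambda i, j: (0, j))
```

Under this model the `x` block is loaded twice (it stays resident while `j` varies) and the `y` block four times, which suggests why the order of grid axes can matter for memory traffic.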
Vector register constraints on TPU also leak into the kernel programmer's view: vector operations use 8x128 tiles for 32-bit values, the last two axes of an array are treated differently from earlier axes, and block shapes generally require their last two dimensions to be divisible by 8 and 128.[^7]
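The divisibility rule can be made concrete with a little arithmetic. The helpers below are hypothetical and deliberately simplified: they only handle shapes of rank two or more and ignore any exceptions the real compiler allows, so they are a sketch of the 8x128 rule for 32-bit values rather than a reproduction of Pallas's actual checks.

```python
def round_up(value, multiple):
    # Smallest multiple of `multiple` that is >= value.
    return -(-value // multiple) * multiple


def is_tile_aligned(block_shape, sublanes=8, lanes=128):
    """True if the last two dims are divisible by (sublanes, lanes)."""
    if len(block_shape) < 2:
        raise ValueError("this sketch only handles rank >= 2")
    rows, cols = block_shape[-2:]
    return rows % sublanes == 0 and cols % lanes == 0


def pad_to_tile(shape, sublanes=8, lanes=128):
    """Round the last two dims up to the next (sublanes, lanes) multiple."""
    if len(shape) < 2:
        raise ValueError("this sketch only handles rank >= 2")
    *leading, rows, cols = shape
    return (*leading, round_up(rows, sublanes), round_up(cols, lanes))


aligned = is_tile_aligned((256, 512))
unaligned = is_tile_aligned((100, 200))
padded = pad_to_tile((4, 100, 200))  # leading axes are left untouched
```

Here a `(4, 100, 200)` array would be padded to `(4, 104, 256)`, reflecting that only the last two axes are subject to the tiling constraint in this model.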
Because pallas_call is just another JAX higher-order primitive, it composes with the rest of the JAX ecosystem. The design note states that a pallas_call is "augmented to have an extra grid dimension" when vmapped and that the BlockSpecs are transformed to index along the new batch dimension automatically.[^2] jax.grad is also supported by decomposing into the standard jvp, partial_eval, and transpose rules, and jax.jit integration is seamless because Pallas kernels are traced and staged out like any other JAX computation.[^2] An emulation mode lowers Pallas to StableHLO so that kernels can be debugged on CPU before being targeted at real accelerators.[^2]
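A small sketch of the vmap behavior follows, reusing the vector-addition kernel from earlier in the article. It assumes `interpret=True` so that it can run without an accelerator, and the batch shape is illustrative; support for transformations may vary between JAX versions given Pallas's experimental status.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x, y):
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # assumption: interpret mode for portability
    )(x, y)


# vmap maps the unbatched kernel call over a new leading batch axis.
batched_add = jax.vmap(add_vectors)

xb = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
out = batched_add(xb, xb)
```

The kernel body is written for a single vector; the batching is added by the transformation rather than by the kernel author.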
On TPUs Pallas lowers to Mosaic, an MLIR-based code generator that consumes mostly standard MLIR dialects (chiefly vector and arith) and emits the low-level "LLO" representation that JAX's TPU stack feeds to the underlying hardware. The JAX documentation describes the compilation path explicitly: "After users express their Pallas kernels, we lower them to different representations depending on the target backend. On GPUs, we lower Pallas to Triton IR, and on TPU we lower Pallas to Mosaic."[^2] Pallas's TPU support targets v4 and newer, and recent TPU generations including Trillium (TPU v6e) and TPU Ironwood are part of Google's codesigned AI stack that the JAX/Pallas/XLA toolchain runs on.[^7][^8]
The original Pallas GPU backend lowered to Triton's intermediate representation. The design note observes that lowering Pallas to Triton was "straightforward because Pallas was designed with Triton as a target language in mind," and that the main translation differences are the absence of BlockSpecs in Triton and Triton's use of pointers rather than indices for memory access.[^2] The Triton backend supports NVIDIA GPUs from Ampere (compute capability 8.0) and above, though the documentation notes it is now maintained on a best-effort basis.[^7][^3]
In parallel with the TPU effort, the Pallas team built Mosaic GPU, a lower-level GPU backend that targets Hopper-class hardware (NVIDIA H100) and the newer NVIDIA Blackwell generation directly.[^3] The official reference describes Mosaic GPU as "very similar to a programming model popularized by Triton" but "more low level, which usually means you will have to put in more work, but it also puts you more in control." Pallas continues to support Triton as an alternative, with the per-call backend argument letting users pick whichever lowering best fits their needs.[^3] As of JAX 0.9.0 the default GPU lowering path moved to Mosaic GPU, with Triton kept available behind an explicit flag.[^4] Mosaic GPU exposes hardware-specific primitives such as wgmma on Hopper, tcgen05_mma for Blackwell tensor memory, and TMA-based asynchronous load/store, along with Blackwell-specific features such as cluster launch control and barrier orderings for TensorCore operations.[^3]
The canonical TPU demonstration of Pallas is Splash Attention, the name the JAX team uses for "Sparse Flash Attention," a general-purpose attention kernel where users can specify an attention mask using NumPy and the kernel skips fully masked tiles automatically.[^14][^15] The implementation lives in the JAX repository at jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py and is released under the Apache 2.0 license with a "Copyright 2023 The JAX Authors" header. Its module docstring describes the file as the "Implementation of Sparse Flash Attention, a.k.a. 'Splash' attention."[^15]
Beyond Splash Attention, the same directory contains a more traditional Pallas TPU flash_attention.py that implements block-wise attention with custom VJP rules for the backward pass, configurable BlockSizes, optional segment-id masking, and causal masking.[^16] Together these kernels make Pallas the default attention path on TPUs in Google's open MaxText training stack, where Splash Attention serves as the default attention kernel for training Transformer models such as DeepSeek, Gemma, and LLaMA derivatives.[^17]
The kernel exploits Pallas features that map well to TPU hardware. A two-dimensional grid (num_q_blocks, num_kv_blocks) gives Mosaic full visibility into the iteration pattern so that fully causally-masked tiles are skipped, and double-buffered DMA lets the matrix-multiply unit work on the current tile while the DMA engine fetches the next, hiding HBM latency behind compute.[^17][^18] On GPUs there is a separate Pallas implementation under jax/experimental/pallas/ops/gpu/attention.py, and an external jax-flash-attn2 package wraps both Pallas and Triton-based attention kernels for JAX models.[^19][^20]
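The tile-skipping idea can be illustrated independently of the real kernel. The NumPy sketch below is not the Splash Attention implementation or its API: `active_tiles` is a hypothetical helper that, given a dense boolean mask and block sizes, marks which (q block, kv block) tiles contain at least one unmasked entry and therefore need to be computed.

```python
import numpy as np


def active_tiles(mask, block_q, block_kv):
    """Return a (num_q_blocks, num_kv_blocks) bool array of tiles with any unmasked entry."""
    q_len, kv_len = mask.shape
    # Split each axis into (number of blocks, block size).
    tiles = mask.reshape(q_len // block_q, block_q, kv_len // block_kv, block_kv)
    return tiles.any(axis=(1, 3))


# Causal mask: query position r may attend to key positions c <= r.
causal = np.tril(np.ones((8, 8), dtype=bool))
active = active_tiles(causal, block_q=2, block_kv=2)
num_active = int(active.sum())
```

For an 8x8 causal mask with 2x2 blocks, 10 of the 16 tiles are active under this model; the remaining 6 lie entirely above the diagonal and are the kind of fully masked tile a sparse kernel can skip.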
Pallas is intentionally a thin programming-language layer that delegates code generation to two separate backends, both of which are useful systems in their own right.
Mosaic is Google's MLIR-based code generator for TPU. Pallas is not Mosaic, but on TPU Pallas is the front end users typically interact with, and Mosaic is the compiler that turns the lowered MLIR into TPU instructions. The OpenXLA Dev Lab 2024 post described both as "novel kernel programming languages, Pallas and Mosaic, which empower developers to write highly optimized code for specialized hardware."[^21] On Mosaic's TPU side, the Pallas TPU details page documents that the resulting code uses VMEM/SMEM for working data and lets the compiler manage HBM transfers in the background.[^7]
Triton, by contrast, is OpenAI's GPU-focused compiler. Pallas's GPU backend originally consisted of a translator from JAX primitives to Triton IR; the design note explains the relationship by saying Triton "demonstrates that an array programming language can be practical for writing GPU kernels and JAX is just that," and that JAX additionally provides "a flexible front-end for compilers and program transformations."[^2] The newer Mosaic GPU backend is described in the Pallas GPU reference as similar in spirit to Triton but more explicit, particularly with regard to pipelining and synchronization, in exchange for tighter control over the generated code.[^3]
The relationship between Pallas and these two backends is therefore asymmetric: Pallas is a user-facing language with two backends, while Mosaic (TPU and GPU) and Triton are compiler infrastructure that other tools, including Tokamax, can also target.[^22]
The MaxText reference training stack documents Pallas as the way it ships fused operators for production use, naming two concrete examples: a Pallas kernel for Mixture-of-Experts block-sparse matrix multiplication, and the Splash Attention kernel that is enabled by default for fused attention.[^17] These kernels are used in MaxText's open-source training recipes for DeepSeek, Gemma, and LLaMA models.[^17] Google's developer and cloud blogs identify the JAX AI Stack (JAX, XLA, Pallas, MaxText) as the production training pipeline for Google's flagship models, and Google publicly stated that Gemini 3 was trained "entirely using JAX" on TPUs.[^23][^24] TPU Ironwood is described by Google Cloud as the codesigned hardware target of the JAX/XLA stack that backs Gemini and related models.[^8]
Pallas has been ported to other ML frameworks. PyTorch/XLA 2.4 (released 2024-07-31) added Python-side bindings that allow PyTorch users to call Pallas kernels (Flash Attention, Paged Attention, MegaBlocks) on Cloud TPU through PyTorch's autograd machinery.[^10] OpenXLA's Tokamax project, described as "a GPU and TPU kernel library," builds on top of Pallas to offer autotuned implementations of common attention and matmul kernels with multiple backends per kernel (XLA, Mosaic-TPU, Mosaic-GPU, Triton), letting users either pick a backend explicitly or accept the library's autotuned default.[^22] The vLLM inference engine added a unified TPU backend that exercises the JAX/Pallas path on TPU, and Keras documents a "Define a Custom TPU/GPU Kernel" guide that wraps Pallas as its mechanism for user kernels.[^25]
The official documentation begins with a prominent warning that "Pallas is experimental and is changing frequently. Expect to encounter errors and unimplemented cases," for example when lowering high-level JAX concepts that require emulation or unfinished features.[^1] Specific limitations called out in the TPU details page include block shapes that must have rank at least 1, with their last two dimensions divisible by 8 and 128 respectively; no support for jnp.int4 or integer reductions; loop primitives that are "fully unrolled during compilation" (so trip counts must be small); and operations such as sin/cos being relatively expensive on TPU.[^7]
On the GPU side, the Mosaic GPU backend only runs on Hopper and newer GPUs; the Triton backend supports older Ampere hardware but is maintained "on a best-effort basis."[^3][^7] Practical bugs reported in the JAX issue tracker include cases such as Pallas FlashAttention failing with bfloat16 activations on TPU v3 (the implementation predates official v3 support) and forward GPU FlashAttention only working when sequence lengths are multiples of the block size.[^26][^27] The documentation also notes that pallas_call does not yet expose every JAX transformation: checkify and custom_partitioning integration are listed as future work in the design note.[^2]
A separate ergonomic limitation is that the explicit memory model adds non-trivial cost to learning: as the design note puts it, Pallas "operates at a lower level of abstraction" and "requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator," in contrast to ordinary JAX programs that treat the device as a single virtual array machine.[^1][^2]
| System | Languages | GPU support | TPU support | Notes |
|---|---|---|---|---|
| Pallas | Python (JAX traces) | Yes, via Triton IR or Mosaic GPU[^3] | Yes, via Mosaic (TPU)[^2] | Multi-target; integrates with jit, vmap, grad[^2] |
| Triton (compiler) | Python (AST-parsed DSL) | Yes (NVIDIA, AMD experimental) | No | Originated at OpenAI; used as torch.compile's GPU backend[^28] |
| XLA CustomCall | C++ plus CUDA | Yes | Yes | Requires lower-level expertise; predates Pallas[^2] |
| Tokamax | Python on top of Pallas, Triton, Mosaic-GPU, Mosaic-TPU, XLA | Yes | Yes | Multi-backend library with autotuning[^22] |
The recurring theme in published comparisons is that Pallas's distinguishing feature is its TPU support and its native fit into the JAX ecosystem, while Triton remains the more mature GPU-focused alternative. The Pallas design note frames the choice as complementary rather than competitive: Pallas borrows Triton's array-programming model on GPU while adding a TPU path that the Triton compiler does not target.[^2][^28]
Pallas resolves a long-standing asymmetry in the deep learning kernel ecosystem. Before Pallas, custom-kernel development on NVIDIA GPUs had matured through CUDA, Triton, and downstream frameworks such as torch.compile, but TPUs lacked a comparable user-accessible kernel language outside Google's internal toolchain. The JAX team's design note frames the absence of such tools as a meaningful disadvantage for TPUs in production settings, since models requiring custom operators that were either unsupported by the framework or implemented suboptimally would underperform on TPU relative to GPU.[^2] By giving the same Python source program a path to both Mosaic (TPU) and Triton/Mosaic GPU (GPU), Pallas makes it possible to develop a single attention or Mixture-of-Experts kernel that runs on either family of accelerator.
The Splash Attention/MaxText combination demonstrates this in practice: a single Pallas-based attention kernel serves as the default fused attention used in production-grade training of large models on Google's TPU fleet, while equivalent Pallas GPU kernels target H100/Blackwell hardware.[^17][^3] The integration of Pallas into PyTorch/XLA further extends its reach beyond JAX users.[^10]
The name "Pallas" is reused in unrelated software projects, including robotics frameworks and Solidity tools. This article concerns only the JAX kernel language documented at docs.jax.dev/en/latest/pallas and developed in the jax-ml/jax GitHub repository.[^1][^4]