# Pallas (JAX kernel language)

> Source: https://aiwiki.ai/wiki/jax_pallas
> Updated: 2026-06-07
> Categories: AI Infrastructure, Developer Tools, Google
> From AI Wiki (https://aiwiki.ai), a free encyclopedia of artificial intelligence. Quote with attribution.

**Pallas** is an experimental extension to [JAX](/wiki/jax) that lets users write custom hardware kernels in Python and lower them to both [Tensor Processing Units](/wiki/tpu) and NVIDIA [GPUs](/wiki/gpu) from a single source.[^1] It exposes a lower level of abstraction than ordinary JAX programs (Refs into on-chip memory, explicit grids, block specifications) while preserving JAX's tracing, `jax.numpy` API, and composable transformations such as `jit`, `vmap`, and `grad`.[^1][^2] On GPUs Pallas lowers to either the [Triton](/wiki/triton) intermediate representation or to Mosaic GPU; on TPUs it lowers to Mosaic, Google's MLIR-based TPU code generator.[^2][^3] The project is developed inside the [JAX](/wiki/jax) repository (`jax-ml/jax`) and is documented as an experimental, frequently changing component.[^1][^4]

## Background and motivation

Modern deep learning compilers such as [XLA](/wiki/xla) do a good job of fusing and scheduling common array operations, but production teams routinely hit cases where the compiler cannot match the throughput of a hand-written kernel. The JAX project's Pallas design note frames the problem this way: while XLA "does a good job compiling user programs," some users "inevitably hit XLA's limitations" and need an "escape hatch" that allows experts to write hand-tuned kernels that outperform XLA at that point in time.[^2] XLA already exposes a `CustomCall` mechanism, but, as the same design note observes, it requires writing C++ and, on GPU, learning the CUDA programming model, both of which are mismatched with the array-oriented Python style most JAX users prefer.[^2]

On the GPU side, OpenAI's [Triton](/wiki/triton) showed that a Python embedded DSL with an array programming model could produce highly competitive kernels without requiring developers to manage individual threads. On the TPU side, however, there was no equivalent option: writing custom TPU kernels required deep familiarity with Google's internal compiler stack. The JAX team's stated goal for Pallas is to give JAX a Triton-class kernel front end with first-class TPU support, by reusing JAX's tracing machinery and `jax.numpy` API as the source language and routing kernels through hardware-specific backends.[^2] Sharad Vikram, an engineer on the JAX team at Google DeepMind, is credited as the creator and lead of Pallas and gave the first public talk titled "Pallas: A JAX Kernel Language" in September 2023.[^5]

| Attribute | Value |
| --- | --- |
| Project name | Pallas (`jax.experimental.pallas`) |
| Parent project | [JAX](/wiki/jax) (`jax-ml/jax` GitHub repository) |
| Lead developer | Sharad Vikram, Google DeepMind[^5] |
| First public availability | JAX `main` branch by August 2023; included in JAX releases shortly thereafter[^4][^6] |
| GPU backends | Triton IR and Mosaic GPU[^1][^3] |
| TPU backend | Mosaic (TPU code generator)[^2][^3] |
| Supported TPU hardware | TPU v4 and later, including [Trillium (TPU v6e)](/wiki/google_trillium) and [TPU Ironwood](/wiki/tpu_ironwood)[^7][^8] |
| Supported GPU hardware | Ampere and newer for Triton backend; Hopper and Blackwell for Mosaic GPU[^7][^3] |
| License | Apache License 2.0[^4] |
| Status | Experimental; changing frequently per official documentation[^1] |

## History and release timeline

Pallas was developed inside the `google/jax` repository through 2023 and merged in increments. A community discussion in August 2023 noted that Pallas "does not seem to be in 0.4.14" but was already present on JAX's main branch, with maintainers replying that Pallas had become "part of JAX," though stable installation required additional dependencies such as JAX-Triton.[^6] Sharad Vikram's public talk "Pallas: A JAX Kernel Language" followed in September 2023.[^5]

The JAX core repository moved from `github.com/google/jax` to `github.com/jax-ml/jax` in September 2024, taking Pallas with it. The JAX team explained the move by noting that the `jax-ml` organization already housed related projects such as `ml_dtypes` and `jax-triton`, and that consolidating the canonical repository there better reflected ownership by the JAX core team rather than Google as a whole.[^9]

A separate Pallas changelog is maintained alongside the main JAX changelog. Notable milestones recorded there include the addition of TPU multi-buffering for pipelines, scalar prefetch and block-sparse computation, GPU atomic primitives such as `atomic_add` and `atomic_max`, the introduction of `Mosaic GPU` as the default lowering path on GPU (with Triton retained as an opt-in backend), and a refactor of `pl.kernel` to use `out_type`/`scratch_types` in JAX 0.10.0.[^4]

Outside of JAX itself, Pallas was integrated into PyTorch/XLA. The Google Cloud blog post announcing PyTorch/XLA 2.4 (published 2024-07-31) describes Pallas as "a custom kernel language originally developed for JAX" that "supports both TPUs and GPUs" and is "similar to the Triton library, but since it also works on both TPUs and GPUs, it's easier to port your model from one machine learning accelerator to the other."[^10] That release shipped PyTorch-side wrappers around Pallas implementations of Flash Attention, Paged Attention, and the MegaBlocks block-sparse Mixture-of-Experts kernel, with full autograd integration.[^10]

## How Pallas works

### Programming model

A Pallas kernel is an ordinary Python function that operates on **Refs**, mutable references to on-chip scratchpad memory rather than the immutable JAX arrays that ordinary `jax.numpy` code manipulates.[^1][^2] A trivial vector-addition kernel from the official quickstart looks like:[^11]

```python
def add_vectors_kernel(x_ref, y_ref, o_ref):
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y
```

The kernel reads from input refs, computes with normal JAX operations, and writes to an output ref. It does not return anything because results are produced through side effects on the output ref.[^11]

To call a kernel inside a regular JAX program, users invoke `pl.pallas_call`, a higher-order operation analogous to `jax.jit` or `jax.pmap` that lifts the kernel function into the surrounding JAX trace:[^11][^12]

```python
@jax.jit
def add_vectors(x, y):
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)
```
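As a minimal sketch, the two snippets above can be combined into one script that runs without an accelerator. It assumes the `interpret=True` argument of `pallas_call`, which emulates the kernel instead of compiling it for a TPU or GPU; the array length of 8 is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Read both input refs in full, add, and write through the output ref.
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y


@jax.jit
def add_vectors(x, y):
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # assumption: emulate the kernel so the sketch runs on CPU
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
result = add_vectors(x, x)  # elementwise x + x
```

Dropping `interpret=True` would send the same kernel through the hardware backend of whatever accelerator JAX is running on.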

### Grids and BlockSpecs

`pallas_call` accepts a **grid**, a tuple specifying an iteration space. A grid of `(4, 5)` launches twenty program instances, and inside the kernel `pl.program_id(axis=...)` returns the current grid coordinate. On GPUs program instances run in parallel; on TPUs they execute in lexicographic order by default, with optional `dimension_semantics` annotations to mark dimensions as parallel on multi-core TPUs.[^11][^7]
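The iteration space can be pictured with a small plain-Python model. This is not Pallas code: `run_grid` is a hypothetical helper that only mimics the sequential, lexicographic order described for TPUs, and it implies nothing about ordering on GPUs, where instances run in parallel.

```python
from itertools import product


def run_grid(grid, kernel_body):
    """Call kernel_body once per grid point, in lexicographic order."""
    visited = []
    for program_ids in product(*(range(n) for n in grid)):
        kernel_body(program_ids)  # stands in for one program instance
        visited.append(program_ids)
    return visited


# A (4, 5) grid yields 4 * 5 = 20 program instances.
order = run_grid((4, 5), lambda program_ids: None)
```

In this model the last axis varies fastest, so the visit order starts `(0, 0), (0, 1), ...` and ends at `(3, 4)`.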

Because real kernels usually operate on tiles of large arrays rather than whole arrays at once, Pallas introduces the **BlockSpec**, an abstraction that combines a `block_shape` with an `index_map` from program ids to block coordinates. A two-dimensional matrix multiplication kernel can specify block shapes for X and Y and let the runtime "carve up" inputs and outputs automatically:[^11][^13]

```python
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0))
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
```

The Pallas design note describes BlockSpec as the abstraction that "tells Pallas how to map a program id to a slice/block of an input or output," with the grid plus BlockSpecs together describing how to tile inputs while Pallas schedules the tiles across hardware and "handles the DMA transfers automatically." Inputs entering a kernel body are already staged into the on-chip scratchpad (VMEM on TPU, shared memory on GPU) rather than living in HBM/DRAM.[^2][^13]

### Memory hierarchy and pipelining

Pallas exposes the memory hierarchy of accelerators explicitly. On TPUs the kernel body operates on VMEM (16 MB and up on recent generations) or SMEM, while communication with high-bandwidth memory is "handled by the compiler and overlapped with compute."[^7] When two consecutive grid indices use the same input slice, the HBM transfer for the second iteration is skipped because the data is already available in VMEM.[^7] On Mosaic GPU, Pallas requires explicit pipelining: the documentation notes that "pipelining in Pallas is programmed explicitly," which is a significant difference from Triton where pipelining is an automatic compiler optimization.[^3]
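The transfer-skipping rule can be illustrated with a plain-Python model. It is only a sketch of the behavior described above, not the compiler's pipeline: `count_block_fetches` is a hypothetical helper, and it assumes the grid is walked in lexicographic order with a copy issued only when the block index changes between consecutive iterations.

```python
from itertools import product


def count_block_fetches(grid, index_map):
    """Count HBM-to-VMEM copies when a repeated block index is reused."""
    fetches = 0
    previous = None
    for program_ids in product(*(range(n) for n in grid)):
        block_index = index_map(*program_ids)
        if block_index != previous:  # same block as last step: no new copy
            fetches += 1
        previous = block_index
    return fetches


# Index maps in the style of a blocked matmul on a (2, 2) grid.
x_fetches = count_block_fetches((2, 2), lambda i, j: (i, 0))  # block changes with i only
y_fetches = count_block_fetches((2, 2), lambda i, j: (0, j))  # block changes every step
```

Under these assumptions the first operand is copied twice and the second four times, which is why the order of grid axes can matter for memory traffic.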

Vector register constraints on TPU also leak into the kernel programmer's view: vector operations use 8x128 tiles for 32-bit values, the last two axes of an array are treated differently from earlier axes, and block shapes generally require their last two dimensions to be divisible by 8 and 128, respectively.[^7]
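The divisibility rule can be turned into a quick pre-flight check. The helpers below are hypothetical and deliberately simplified: they assume 32-bit values and blocks of rank 2 or higher, and they ignore the additional cases covered by the actual rules in the Pallas TPU documentation.

```python
import math

SUBLANES, LANES = 8, 128  # 8x128 tile for 32-bit values, as described above


def is_tile_aligned(block_shape):
    """True when the last two block dims are multiples of 8 and 128."""
    rows, cols = block_shape[-2:]  # sketch assumes rank >= 2
    return rows % SUBLANES == 0 and cols % LANES == 0


def tiles_needed(block_shape):
    """Number of 8x128 tiles needed to cover a block of 32-bit values."""
    *leading, rows, cols = block_shape
    per_slice = math.ceil(rows / SUBLANES) * math.ceil(cols / LANES)
    return math.prod(leading) * per_slice


aligned = is_tile_aligned((256, 512))   # 256 % 8 == 0 and 512 % 128 == 0
misaligned = is_tile_aligned((4, 128))  # 4 is not a multiple of 8
tiles = tiles_needed((256, 512))        # 32 row tiles * 4 column tiles
```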

### Integration with JAX transformations

Because `pallas_call` is just another JAX higher-order primitive, it composes with the rest of the JAX ecosystem. The design note states that a `pallas_call` is "augmented to have an extra grid dimension" when vmapped and that the BlockSpecs are transformed to index along the new batch dimension automatically.[^2] `jax.grad` is also supported by decomposing into the standard `jvp`, `partial_eval`, and `transpose` rules, and `jax.jit` integration is seamless because Pallas kernels are traced and staged out like any other JAX computation.[^2] An emulation mode lowers Pallas to StableHLO so that kernels can be debugged on CPU before being targeted at real accelerators.[^2]
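A minimal sketch of the `vmap` composition described above, assuming `interpret=True` is used so that the kernel runs in emulation on CPU. The batch size of 4 and vector length of 8 are arbitrary; the kernel itself is written for a single vector and knows nothing about batching.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x, y):
    # Written for unbatched 1-D inputs.
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # assumption: emulation mode so the sketch runs on CPU
    )(x, y)


xs = jnp.arange(32, dtype=jnp.float32).reshape(4, 8)
ys = jnp.ones((4, 8), dtype=jnp.float32)

# vmap maps the call over the leading axis; jit stages the result out as usual.
batched = jax.jit(jax.vmap(add_vectors))(xs, ys)
```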

## Backends

### Mosaic for TPU

On TPUs Pallas lowers to **Mosaic**, an MLIR-based code generator that consumes mostly standard MLIR dialects (chiefly `vector` and `arith`) and emits the low-level "LLO" representation that JAX's TPU stack feeds to the underlying hardware. The JAX documentation describes the compilation path explicitly: "After users express their Pallas kernels, we lower them to different representations depending on the target backend. On GPUs, we lower Pallas to Triton IR, and on TPU we lower Pallas to Mosaic."[^2] Pallas's TPU support targets v4 and newer, and recent TPU generations including [Trillium (TPU v6e)](/wiki/google_trillium) and [TPU Ironwood](/wiki/tpu_ironwood) are part of Google's codesigned AI stack that the JAX/Pallas/XLA toolchain runs on.[^7][^8]

### Triton IR for GPU

The original Pallas GPU backend lowered to Triton's intermediate representation. The design note observes that lowering Pallas to Triton was "straightforward because Pallas was designed with Triton as a target language in mind," and that the main translation differences are the absence of BlockSpecs in Triton and Triton's use of pointers rather than indices for memory access.[^2] The Triton backend supports NVIDIA GPUs from Ampere (compute capability 8.0) and above, though the documentation notes it is now maintained on a best-effort basis.[^7][^3]

### Mosaic GPU

In parallel with the TPU effort, the Pallas team built **Mosaic GPU**, a lower-level GPU backend that targets Hopper-class hardware ([NVIDIA H100](/wiki/nvidia_h100)) and the newer [NVIDIA Blackwell](/wiki/nvidia_blackwell) generation directly.[^3] The official reference describes Mosaic GPU as "very similar to a programming model popularized by Triton" but "more low level, which usually means you will have to put in more work, but it also puts you more in control." Pallas continues to support Triton as an alternative, with the per-call `backend` argument letting users pick whichever lowering best fits their needs.[^3] As of JAX 0.9.0 the default GPU lowering path moved to Mosaic GPU, with Triton kept available behind an explicit flag.[^4] Mosaic GPU exposes hardware-specific primitives such as `wgmma` on Hopper, `tcgen05_mma` for Blackwell tensor memory, and TMA-based asynchronous load/store, along with Blackwell-specific features such as cluster launch control and barrier orderings for TensorCore operations.[^3]

## Flagship example: Splash Attention

The canonical TPU demonstration of Pallas is **Splash Attention**, the name the JAX team uses for "Sparse Flash Attention," a general-purpose attention kernel where users can specify an attention mask using NumPy and the kernel skips fully masked tiles automatically.[^14][^15] The implementation lives in the JAX repository at `jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py` and is released under the Apache 2.0 license with a "Copyright 2023 The JAX Authors" header. Its module docstring describes the file as the "Implementation of Sparse Flash Attention, a.k.a. 'Splash' attention."[^15]

Beyond Splash Attention, the same directory contains a more traditional Pallas TPU `flash_attention.py` that implements block-wise attention with custom VJP rules for the backward pass, configurable `BlockSizes`, optional segment-id masking, and causal masking.[^16] Together these kernels make Pallas the default attention path on TPUs in Google's open MaxText training stack, where Splash Attention serves as the default attention kernel for training Transformer models such as DeepSeek, [Gemma](/wiki/gemma), and [LLaMA](/wiki/llama) derivatives.[^17]

The kernel exploits Pallas features that map well to TPU hardware. A two-dimensional grid `(num_q_blocks, num_kv_blocks)` gives Mosaic full visibility into the iteration pattern so that fully causally-masked tiles are skipped, and double-buffered DMA lets the matrix-multiply unit work on the current tile while the DMA engine fetches the next, hiding HBM latency behind compute.[^17][^18] On GPUs there is a separate Pallas implementation under `jax/experimental/pallas/ops/gpu/attention.py`, and an external `jax-flash-attn2` package wraps both Pallas and Triton-based attention kernels for JAX models.[^19][^20]
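The tile-skipping idea can be sketched in plain Python. This is not the Splash Attention implementation: `live_blocks` is a hypothetical helper that assumes a purely causal mask (query position `q` may attend to key positions `kv <= q`) and a sequence length divisible by both block sizes.

```python
def live_blocks(seq_len, block_q, block_kv):
    """Return (q_block, kv_block) pairs that contain an unmasked entry."""
    num_q = seq_len // block_q
    num_kv = seq_len // block_kv
    live = []
    for i in range(num_q):
        last_q = (i + 1) * block_q - 1  # largest query position in the tile
        for j in range(num_kv):
            first_kv = j * block_kv     # smallest key position in the tile
            if first_kv <= last_q:      # at least one entry is unmasked
                live.append((i, j))
    return live


# 8 positions in 2x2 tiles: a 4x4 grid of tiles, of which 10 need computing.
blocks = live_blocks(seq_len=8, block_q=2, block_kv=2)
```

In this toy case 6 of the 16 tiles lie entirely above the diagonal and can be skipped; the fraction of skippable tiles approaches one half as the grid grows.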

## Relationship to Mosaic and Triton

Pallas is intentionally a thin programming-language layer that delegates code generation to two separate backend families, Mosaic and Triton, both of which are useful systems in their own right.

**Mosaic** is Google's MLIR-based code generator for TPU. Pallas is not Mosaic, but on TPU Pallas is the front end users typically interact with, and Mosaic is the compiler that turns the lowered MLIR into TPU instructions. The OpenXLA Dev Lab 2024 post described both as "novel kernel programming languages, Pallas and Mosaic, which empower developers to write highly optimized code for specialized hardware."[^21] For the TPU path, the Pallas TPU details page documents that the resulting code uses VMEM/SMEM for working data and lets the compiler manage HBM transfers in the background.[^7]

**Triton**, by contrast, is OpenAI's GPU-focused compiler. Pallas's GPU backend originally consisted of a translator from JAX primitives to Triton IR; the design note explains the relationship by saying Triton "demonstrates that an array programming language can be practical for writing GPU kernels and JAX is just that," and that JAX additionally provides "a flexible front-end for compilers and program transformations."[^2] The newer Mosaic GPU backend is described in the Pallas GPU reference as similar in spirit to Triton but more explicit, particularly with regard to pipelining and synchronization, in exchange for tighter control over the generated code.[^3]

The relationship between Pallas and these backends is therefore asymmetric: Pallas is a user-facing language that delegates code generation, while Mosaic (TPU and GPU) and Triton are compiler infrastructure that other tools, including Tokamax, can also target.[^22]

## Adoption and production use

### Inside Google

The MaxText reference training stack documents Pallas as the way it ships fused operators for production use, naming two concrete examples: a Pallas kernel for Mixture-of-Experts block-sparse matrix multiplication, and the Splash Attention kernel that is enabled by default for fused attention.[^17] These kernels are used in MaxText's open-source training recipes for DeepSeek, [Gemma](/wiki/gemma), and [LLaMA](/wiki/llama) models.[^17] Google's developer and cloud blogs identify the JAX AI Stack (JAX, XLA, Pallas, MaxText) as the production training pipeline for their flagship models, and Google publicly stated that [Gemini 3](/wiki/gemini_3) was trained "entirely using JAX" on TPUs.[^23][^24] [TPU Ironwood](/wiki/tpu_ironwood) is described by Google Cloud as the codesigned hardware target of the JAX/XLA stack that backs Gemini and related models.[^8]

### Outside of JAX

Pallas has also been integrated into other ML frameworks and libraries. PyTorch/XLA 2.4 (released 2024-07-31) added Python-side bindings that allow PyTorch users to call Pallas kernels (Flash Attention, Paged Attention, MegaBlocks) on Cloud TPU through PyTorch's autograd machinery.[^10] OpenXLA's **Tokamax** project, described as "a GPU and TPU kernel library," builds on top of Pallas to offer autotuned implementations of common attention and matmul kernels with multiple backends per kernel (XLA, Mosaic-TPU, Mosaic-GPU, Triton), letting users either pick a backend explicitly or accept the library's autotuned default.[^22] The vLLM inference engine added a unified TPU backend that exercises the JAX/Pallas path on TPU, and Keras documents a "Define a Custom TPU/GPU Kernel" guide that wraps Pallas as its mechanism for user kernels.[^25]

## Limitations

The official documentation begins with a prominent warning that "Pallas is experimental and is changing frequently. Expect to encounter errors and unimplemented cases," for example when lowering high-level JAX concepts that require emulation or unfinished features.[^1] Specific limitations called out in the TPU details page include block shapes whose ranks must be at least 1 with the last two dimensions divisible by 8 and 128 respectively, no support for `jnp.int4` or integer reductions, loop primitives that are "fully unrolled during compilation" (so trip counts must be small), and operations such as `sin`/`cos` being relatively expensive on TPU.[^7]

On the GPU side, the Mosaic GPU backend only runs on Hopper and newer GPUs; the Triton backend supports older Ampere hardware but is maintained "on a best-effort basis."[^3][^7] Practical bugs reported in the JAX issue tracker include cases such as Pallas FlashAttention failing with bfloat16 activations on TPU v3 (the implementation predates official v3 support) and forward GPU FlashAttention only working when sequence lengths are multiples of the block size.[^26][^27] The documentation also notes that `pallas_call` does not yet expose every JAX transformation: `checkify` and `custom_partitioning` integration are listed as future work in the design note.[^2]

A separate ergonomic limitation is that the explicit memory model adds non-trivial cost to learning: as the design note puts it, Pallas "operates at a lower level of abstraction" and "requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator," in contrast to ordinary JAX programs that treat the device as a single virtual array machine.[^1][^2]

## Comparison to related kernel languages

| System | Languages | GPU support | TPU support | Notes |
| --- | --- | --- | --- | --- |
| Pallas | Python (JAX traces) | Yes, via [Triton](/wiki/triton) IR or Mosaic GPU[^3] | Yes, via Mosaic (TPU)[^2] | Multi-target; integrates with `jit`, `vmap`, `grad`[^2] |
| [Triton (compiler)](/wiki/triton) | Python (AST-parsed DSL) | Yes (NVIDIA, AMD experimental) | No | Originated at OpenAI; used as torch.compile's GPU backend[^28] |
| XLA `CustomCall` | C++ plus CUDA | Yes | Yes | Requires lower-level expertise; predates Pallas[^2] |
| Tokamax | Python on top of Pallas, Triton, Mosaic-GPU, Mosaic-TPU, XLA | Yes | Yes | Multi-backend library with autotuning[^22] |

The recurring theme in published comparisons is that Pallas's distinguishing feature is its TPU support and its native fit into the JAX ecosystem, while Triton remains the more mature GPU-focused alternative. The Pallas design note frames the choice as complementary rather than competitive: Pallas borrows Triton's array-programming model on GPU while adding a TPU path that the Triton compiler does not target.[^2][^28]

## Significance

Pallas resolves a long-standing asymmetry in the deep learning kernel ecosystem. Before Pallas, custom-kernel development on NVIDIA GPUs had matured through CUDA, [Triton](/wiki/triton), and downstream frameworks such as torch.compile, but TPUs lacked a comparable user-accessible kernel language outside Google's internal toolchain. The JAX team's design note frames the absence of such tools as a meaningful disadvantage for TPUs in production settings, since models requiring custom operators that were either unsupported by the framework or implemented suboptimally would underperform on TPU relative to GPU.[^2] By giving the same Python source program a path to both Mosaic (TPU) and Triton/Mosaic GPU (GPU), Pallas makes it possible to develop a single attention or Mixture-of-Experts kernel that runs on either family of accelerator.

The Splash Attention/MaxText combination demonstrates this in practice: a single Pallas-based attention kernel serves as the default fused attention used in production-grade training of large models on Google's TPU fleet, while equivalent Pallas GPU kernels target [H100](/wiki/nvidia_h100)/[Blackwell](/wiki/nvidia_blackwell) hardware.[^17][^3] The integration of Pallas into PyTorch/XLA further extends its reach beyond JAX users.[^10]

## Distinct from other "Pallas" projects

The name "Pallas" is reused in unrelated software projects, including robotics frameworks and Solidity tools. This article concerns only the JAX kernel language documented at `docs.jax.dev/en/latest/pallas` and developed in the `jax-ml/jax` GitHub repository.[^1][^4]

## See also

- [JAX](/wiki/jax)
- [XLA (Accelerated Linear Algebra)](/wiki/xla)
- [Triton (compiler)](/wiki/triton)
- [FlashAttention](/wiki/flashattention)
- [Flash Attention 3](/wiki/flash_attention_3)
- [Tensor Processing Unit (TPU)](/wiki/tensor_processing_unit_tpu)
- [Trillium (TPU v6e)](/wiki/google_trillium)
- [TPU Ironwood](/wiki/tpu_ironwood)
- [NVIDIA Hopper](/wiki/nvidia_hopper)
- [NVIDIA Blackwell](/wiki/nvidia_blackwell)
- [Gemini 3](/wiki/gemini_3)
- [Google DeepMind](/wiki/google_deepmind)
- [Automatic Differentiation](/wiki/automatic_differentiation)
- [PagedAttention](/wiki/paged_attention)

## References

[^1]: The JAX Authors, "Pallas: a JAX kernel language", JAX documentation, 2026. https://docs.jax.dev/en/latest/pallas/index.html. Accessed 2026-05-21.
[^2]: The JAX Authors, "Pallas Design", JAX documentation, 2026. https://docs.jax.dev/en/latest/pallas/design/design.html. Accessed 2026-05-21.
[^3]: The JAX Authors, "Writing Mosaic GPU kernels with Pallas", JAX documentation, 2026. https://docs.jax.dev/en/latest/pallas/gpu/reference.html. Accessed 2026-05-21.
[^4]: The JAX Authors, "Pallas Changelog", JAX documentation, 2026. https://docs.jax.dev/en/latest/pallas/CHANGELOG.html. Accessed 2026-05-21.
[^5]: Sharad Vikram, "Sharad Vikram (CV)", sharadvikram.com, 2025. https://sharadvikram.com/pdf/cv-sharadvikram.pdf. Accessed 2026-05-21.
[^6]: JAX maintainers, "How to get started with Pallas? Discussion #17367", GitHub jax-ml/jax, 2023-08-31. https://github.com/jax-ml/jax/discussions/17367. Accessed 2026-05-21.
[^7]: The JAX Authors, "Writing TPU kernels with Pallas", JAX documentation, 2026. https://docs.jax.dev/en/latest/pallas/tpu/details.html. Accessed 2026-05-21.
[^8]: Google Cloud, "Inside the Ironwood TPU codesigned AI stack", Google Cloud Blog, 2025. https://cloud.google.com/blog/products/compute/inside-the-ironwood-tpu-codesigned-ai-stack. Accessed 2026-05-21.
[^9]: JAX team, "Moving our GitHub repository to `jax-ml` (Discussion #23319)", GitHub jax-ml/jax, 2024-09. https://github.com/jax-ml/jax/discussions/23319. Accessed 2026-05-21.
[^10]: Google Cloud, "PyTorch/XLA 2.4 improves Pallas and adds 'eager mode'", Google Cloud Blog, 2024-07-31. https://cloud.google.com/blog/products/ai-machine-learning/pytorch-xla-2-4-improves-pallas-and-adds-eager-mode/. Accessed 2026-05-21.
[^11]: The JAX Authors, "Pallas Quickstart", JAX documentation, 2026. https://docs.jax.dev/en/latest/pallas/quickstart.html. Accessed 2026-05-21.
[^12]: The JAX Authors, "jax.experimental.pallas.pallas_call", JAX documentation, 2026. https://docs.jax.dev/en/latest/_autosummary/jax.experimental.pallas.pallas_call.html. Accessed 2026-05-21.
[^13]: The JAX Authors, "jax.experimental.pallas.BlockSpec", JAX documentation, 2026. https://docs.jax.dev/en/latest/_autosummary/jax.experimental.pallas.BlockSpec.html. Accessed 2026-05-21.
[^14]: JAX maintainers, "Open source 'Splash Attention' (Sparse Flash Attention)", GitHub commit, google/jax 4e6bef6, 2024. https://github.com/google/jax/actions/runs/7414633687. Accessed 2026-05-21.
[^15]: The JAX Authors, "splash_attention_kernel.py", GitHub jax-ml/jax, 2023. https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py. Accessed 2026-05-21.
[^16]: The JAX Authors, "flash_attention.py (Pallas TPU)", GitHub jax-ml/jax, 2023. https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py. Accessed 2026-05-21.
[^17]: MaxText Authors, "JAX, XLA, and Pallas", MaxText documentation, 2026. https://maxtext.readthedocs.io/en/latest/reference/core_concepts/jax_xla_and_pallas.html. Accessed 2026-05-21.
[^18]: MaxText Authors, "Optimizing with Pallas kernels", MaxText documentation, 2026. https://maxtext.readthedocs.io/en/latest/guides/optimization/pallas_kernels_performance.html. Accessed 2026-05-21.
[^19]: The JAX Authors, "attention.py (Pallas GPU)", GitHub jax-ml/jax, 2024. https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/attention.py. Accessed 2026-05-21.
[^20]: jax-flash-attn2 maintainers, "jax-flash-attn2", PyPI, 2025. https://pypi.org/project/jax-flash-attn2/. Accessed 2026-05-21.
[^21]: Google Open Source, "OpenXLA Dev Lab 2024: Building Groundbreaking ML Systems Together", Google Open Source Blog, 2024-05. https://opensource.googleblog.com/2024/05/openxla-dev-lab-2024-building-grouundbreaking-systems-together.html. Accessed 2026-05-21.
[^22]: OpenXLA, "Tokamax: A GPU and TPU kernel library", GitHub openxla/tokamax, 2025. https://github.com/openxla/tokamax. Accessed 2026-05-21.
[^23]: Google Developers, "Building production AI on Google Cloud TPUs with JAX", Google Developers Blog, 2024. https://developers.googleblog.com/building-production-ai-on-google-cloud-tpus-with-jax/. Accessed 2026-05-21.
[^24]: Into AI, "Google Trained Gemini 3 Entirely Using JAX on Its TPUs: Here Is Why It Matters", intoai.pub, 2025. https://www.intoai.pub/p/google-jax-ai-stack. Accessed 2026-05-21.
[^25]: vLLM team, "vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU", vLLM Blog, 2025-10-16. https://blog.vllm.ai/2025/10/16/vllm-tpu.html. Accessed 2026-05-21.
[^26]: JAX maintainers, "Pallas FlashAttention fails with bfloat16 activations on TPU v3 (Issue #18595)", GitHub jax-ml/jax, 2023. https://github.com/jax-ml/jax/issues/18595. Accessed 2026-05-21.
[^27]: JAX maintainers, "Pallas GPU FlashAttention forward only works when sequence lengths are multiples of block size (Issue #27224)", GitHub jax-ml/jax, 2024. https://github.com/jax-ml/jax/issues/27224. Accessed 2026-05-21.
[^28]: PyTorch/XLA Authors, "Custom Kernels via Pallas", PyTorch/XLA documentation, 2025. https://docs.pytorch.org/xla/master/features/pallas.html. Accessed 2026-05-21.

