Medusa is an inference acceleration framework for large language models that replaces the single sequential token prediction of standard autoregressive decoding with a set of parallel prediction heads attached to the existing model. Rather than deploying a separate smaller draft model to propose candidate tokens -- the strategy used in conventional speculative decoding -- Medusa adds multiple lightweight feed-forward layers, called Medusa heads, directly on top of the base model's final hidden states. Each head simultaneously predicts a different future token position, and a tree-based attention mechanism then verifies the candidates in a single forward pass. On average, multiple tokens are accepted per decoding step, reducing wall-clock generation time by a factor of 2.2 to 3.6 while preserving output quality.
The framework was developed by Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, and Tri Dao, with institutional affiliations spanning Princeton University, the University of Illinois Urbana-Champaign, Carnegie Mellon University, the University of Connecticut, and Together AI. The work appeared first as arXiv preprint 2401.10774 in January 2024 and was subsequently published at the 41st International Conference on Machine Learning (ICML 2024). The official implementation is maintained at github.com/FasterDecoding/Medusa under a Creative Commons BY-4.0 license.
Transformer-based language models generate text one token at a time. At each decoding step the model performs a full forward pass through all its transformer layers, attending to every previously generated token via the key-value cache, and produces a probability distribution over the vocabulary. Exactly one token is sampled or taken greedily, the token is appended to the sequence, and the process repeats. Because each step depends on the output of the previous step, the decoding loop is inherently sequential and cannot be parallelized across the output dimension.
This design creates a pronounced latency bottleneck, particularly on modern GPUs. The large matrix multiplications that dominate transformer inference can saturate GPU arithmetic units at large batch sizes, but at batch size 1 -- the typical configuration for interactive or on-demand generation -- the same operations are heavily memory-bandwidth bound. The GPU reads model parameters from DRAM to perform each forward pass, and those reads consume far more time than the arithmetic itself. The result is that the GPU is underutilized: arithmetic throughput is not the limiting factor; bandwidth to model weights is. Because each decoding step takes roughly the same time whether it produces one token or could have produced several, any method that generates multiple tokens per step at negligible additional cost directly reduces latency.
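As a rough illustration (assuming, for concreteness, a 7-billion-parameter model stored in FP16 and roughly 2 TB/s of HBM bandwidth -- figures not taken from the paper), every decoding step must stream the full weight set from memory:

$$t_{\text{step}} \gtrsim \frac{7 \times 10^9\ \text{params} \times 2\ \text{bytes}}{2 \times 10^{12}\ \text{bytes/s}} = 7\ \text{ms},$$

regardless of whether that step ultimately yields one token or several verified tokens.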
Speculative decoding was the dominant method for exploiting this observation before Medusa. In speculative decoding a small draft model generates a candidate sequence of several tokens in a single pass. Those candidates are then verified by the target model in a single batched forward pass that evaluates all candidate positions in parallel, using rejection sampling to decide how many tokens to accept. The expected number of accepted tokens per verification pass exceeds 1.0 if the draft model's predictions correlate well with the target model's, reducing the total number of target model forward passes required.
Speculative decoding can be highly effective, but it introduces several practical difficulties. The draft model must be architecturally compatible with the target model so that probability distributions can be compared under the rejection sampling criterion. Finding or training an appropriate draft model for a custom fine-tuned target model requires significant effort. The system must run two models concurrently, doubling memory management complexity and complicating multi-GPU or multi-node deployments. Furthermore, the gains from speculative decoding are sensitive to how closely the draft model's output distribution matches the target model's: poorly matched models yield low acceptance rates and may produce no speedup at all. For fine-tuned or domain-adapted models, there is often no ready-made draft model available.
The Medusa paper benchmarks standard speculative decoding on the same Vicuna model family used in its own evaluation and reports speedup factors of 1.47 times on Vicuna-7B, 1.56 times on Vicuna-13B, and 1.60 times on Vicuna-33B. These figures are considerably below what Medusa achieves, motivating the question of whether the separate draft model can be eliminated entirely.
Medusa's central insight is that the base model's final hidden state already contains rich information about what token sequences are likely to follow. The original language model head maps this hidden state to a next-token probability distribution. There is no fundamental barrier to adding additional heads that map the same hidden state to probability distributions over the token two positions ahead, three positions ahead, and so on. These additional heads are structurally independent of each other and of the base model weights; they can be trained while the base model is held frozen, or trained jointly with the base model under a modified training recipe.
Formally, let $h_t$ denote the hidden state produced by the last transformer layer at decoding step $t$. The $k$-th Medusa head computes:
$$p_t^{(k)} = \text{softmax}\left(W_2^{(k)} \cdot \left(\text{SiLU}\left(W_1^{(k)} \cdot h_t\right) + h_t\right)\right)$$
where $W_1^{(k)} \in \mathbb{R}^{d \times d}$ and $W_2^{(k)} \in \mathbb{R}^{d \times V}$ are the learned parameters of head $k$, $d$ is the model's hidden dimension, and $V$ is the vocabulary size. The residual connection ($+ h_t$) is important for stability: it allows the head to start from a position close to zero correction and gradually learn to refine the base model's implicit belief about future tokens. The base model's original language model head is treated as head 0, predicting the next token, while heads 1 through $K$ predict tokens at offsets 2 through $K{+}1$. Five heads is the configuration found empirically most useful; adding more yields diminishing returns because prediction accuracy for positions further in the future declines.
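A minimal PyTorch sketch of a single head following this equation (module and variable names are illustrative, not taken from the reference implementation; softmax is deferred to the loss or sampling step):

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One Medusa head: a SiLU residual block followed by a vocabulary projection."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, hidden_size, bias=False)  # W1^(k)
        self.w2 = nn.Linear(hidden_size, vocab_size, bias=False)   # W2^(k)
        self.act = nn.SiLU()
        # Zero-initializing W1 makes the residual block start as an identity,
        # consistent with the "close to zero correction" motivation above.
        nn.init.zeros_(self.w1.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq, hidden_size) from the last transformer layer
        return self.w2(self.act(self.w1(hidden_states)) + hidden_states)
```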
The next-next-token prediction accuracy of the first additional head (head 1) is approximately 60% top-1 and over 80% top-5 in practice, which is high enough to make tree-based candidate verification highly efficient.
Given the top-$s_k$ predictions from each of the $K$ heads, the candidate set for a single decoding step consists of all possible continuations formed by picking one token per head. With $K = 5$ heads and $s = 2$ top candidates per head, the Cartesian product produces $2^5 = 32$ five-token candidate sequences. Verifying each candidate independently would mean processing 32 separate five-token sequences, repeating the computation for every shared prefix and erasing most of the potential gain.
Tree attention solves this by recognizing that many candidate sequences share a common prefix. Instead of 32 independent sequences, the candidates can be arranged as a tree in which the root is the current token, the next level contains the top-$s_1$ predictions from head 1, each node at that level branches into the top-$s_2$ predictions from head 2, and so on. Any two paths through the tree that share a prefix node can share the key-value computations for those prefix tokens. A single forward pass processes all unique tree nodes simultaneously by using a custom attention mask that restricts each node's attention to only its direct ancestors in the tree, rather than to all other nodes in the batch.
This modified attention mask is the key technical mechanism. If the candidate tokens were simply appended to the context and processed with an ordinary causal mask, tokens from one branch would attend to tokens from other branches, letting candidates incorrectly influence each other. By masking out tokens that are not ancestors of the current node, tree attention ensures that each candidate position sees exactly the context it would see if that path through the tree were the true generated sequence. The resulting forward pass is semantically equivalent to running $s_1 \times s_2 \times \ldots \times s_K$ separate forward passes but performs only as much compute as a single pass over the unique tree nodes.
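A sketch of how such an ancestor-only mask can be built from a list of candidate paths (the path representation -- a tuple of per-head top-$k$ indices -- is assumed for illustration and need not match the reference implementation):

```python
import torch

def build_tree_mask(paths):
    """Build an attention mask over the unique tree nodes implied by candidate paths.

    Each path is a tuple like (1, 0, 2): take the 2nd top candidate from head 1,
    the 1st from head 2, the 3rd from head 3.  Node i may attend to node j only
    if j's path is a prefix of i's path, i.e. j is an ancestor of i in the tree.
    """
    # Every unique prefix of every path is a tree node (the root is handled separately).
    nodes = sorted({path[:i] for path in paths for i in range(1, len(path) + 1)},
                   key=lambda p: (len(p), p))
    index = {node: i for i, node in enumerate(nodes)}
    mask = torch.zeros(len(nodes), len(nodes), dtype=torch.bool)
    for node, i in index.items():
        for depth in range(1, len(node) + 1):   # all ancestors, including the node itself
            mask[i, index[node[:depth]]] = True
    return nodes, mask  # mask[i, j] == True: node i may attend to node j

# Example: 2 heads, top-2 each -> 4 paths sharing prefixes, 6 unique nodes.
nodes, mask = build_tree_mask([(0, 0), (0, 1), (1, 0), (1, 1)])
```

In the full forward pass this tree mask is appended to the ordinary causal mask over the already-generated context, which every tree node can attend to.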
The tree need not be dense. With five heads and two top candidates per head, the full Cartesian product yields 62 candidate nodes (63 counting the root). The paper evaluates different tree configurations and finds that a sparse tree of approximately 64 nodes achieves a better trade-off between coverage and computational overhead than either the full dense tree or a very small tree. The tree configuration is treated as a hyperparameter that can be tuned per model and use case. In TensorRT-LLM's implementation the tree structure is a runtime parameter, allowing operators to adjust it without recompiling the model.
Once the tree attention forward pass completes, the system must decide how many of the proposed tokens to accept. Standard speculative decoding uses rejection sampling, which requires comparing the target model's probability of each token against the draft model's probability, accepting with probability equal to the minimum of their ratio and 1. This scheme guarantees that the accepted sequence follows exactly the target model's distribution, which is important for correctness at non-zero sampling temperatures.
Medusa introduces a different scheme called typical acceptance, which is appropriate given that the proposal heads are part of the same model rather than a separate draft model. Typical acceptance accepts a candidate token from position $k$ if its probability under the base model exceeds a threshold that depends on the entropy of the base model's distribution at that position:
$$p^{(0)}\left(x_{n+k} \mid \ldots\right) > \min\left(\epsilon,\ \delta \cdot \exp\left(-H\left(p^{(0)}\right)\right)\right)$$
where $H$ denotes the entropy of the distribution, $\epsilon$ is a hard lower bound, and $\delta$ is a scaling factor. At low entropy (confident predictions) the threshold is tight, meaning only very probable tokens are accepted. At high entropy (uncertain predictions) the threshold is looser, accepting a broader range of plausible tokens. This scheme is inspired by typical sampling from information theory, where tokens near the entropy of the distribution -- rather than only the highest-probability tokens -- are considered typical outputs. The paper reports that typical acceptance provides approximately 10% speedup over greedy acceptance by approving longer candidate sequences, while maintaining generation quality closely comparable to random sampling.
At temperature zero (greedy decoding), typical acceptance reduces to simple comparison against the argmax, and no stochasticity is involved. At higher temperatures, the entropy-adaptive threshold allows more liberal acceptance without requiring rejection sampling against a separate draft model's probabilities.
After acceptance, the longest valid prefix of the candidate tree that satisfies the acceptance criterion is appended to the context, the KV cache is updated, and decoding continues. The average number of new tokens accepted per decoding step, called the acceleration rate, ranges from about 3.0 to 3.5 for Medusa-2 configurations, meaning each forward pass produces three to three-and-a-half tokens on average instead of one.
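A hedged sketch of the entropy-adaptive threshold and the longest-accepted-prefix rule described above (the per-path iteration, tensor shapes, and the $\epsilon$, $\delta$ defaults are illustrative assumptions, not the reference implementation):

```python
import torch

def typical_threshold(probs, epsilon=0.09, delta=0.3):
    """Entropy-adaptive acceptance threshold for one base-model distribution.

    probs: (vocab,) base-model probabilities at a candidate position.
    epsilon and delta are illustrative defaults, not values from the paper.
    """
    entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum()
    return torch.minimum(torch.tensor(epsilon), delta * torch.exp(-entropy))

def accepted_prefix_length(candidate_tokens, base_probs):
    """Accept the longest prefix of one candidate path that clears the threshold.

    candidate_tokens: (k,) proposed token ids along one tree path.
    base_probs: (k, vocab) base-model distributions at those positions,
                taken from the tree-attention verification pass.
    """
    length = 0
    for step, token in enumerate(candidate_tokens):
        if base_probs[step, token] > typical_threshold(base_probs[step]):
            length += 1
        else:
            break
    return length
```

The base model's own next-token prediction (head 0) is always emitted, so each decoding step produces at least one token even if every speculative candidate is rejected.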
Medusa offers two distinct training regimes, chosen based on available compute and the desired trade-off between training cost and inference speedup.
In the Medusa-1 regime the base language model's weights are held completely frozen, and only the additional heads are trained. The training objective is a weighted sum of cross-entropy losses across the $K$ additional heads:
$$\mathcal{L}_{\text{Medusa-1}} = -\sum_{k=1}^{K} \lambda_k \log p_t^{(k)}\left(y_{t+k+1}\right)$$
where $\lambda_k = 0.8^k$ is a decaying weight that places more emphasis on near-future positions, which are easier to predict and more likely to be accepted. The frozen backbone ensures that the base model's output distribution is completely unaffected; Medusa-1 provides lossless inference acceleration in the sense that the set of possible outputs is unchanged.
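A sketch of this weighted cross-entropy objective over the $K$ head outputs (tensor shapes and function names are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

def medusa1_loss(head_logits, labels, decay=0.8):
    """Weighted sum of cross-entropy losses over K Medusa heads.

    head_logits: list of K tensors of shape (batch, seq, vocab); head k predicts
                 the token k+1 positions ahead of the current position.
    labels:      (batch, seq) token ids of the training sequence.
    """
    loss = 0.0
    for k, logits in enumerate(head_logits, start=1):
        offset = k + 1                        # head k's target sits k+1 positions ahead
        shifted_logits = logits[:, :-offset, :]
        shifted_labels = labels[:, offset:]
        ce = F.cross_entropy(
            shifted_logits.reshape(-1, shifted_logits.size(-1)),
            shifted_labels.reshape(-1),
        )
        loss = loss + (decay ** k) * ce       # lambda_k = 0.8^k
    return loss
```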
Training Medusa-1 requires minimal compute. The reference implementation trains the heads for Vicuna-7B on approximately 60,000 ShareGPT conversation samples in about five hours on a single NVIDIA A100 PCIE GPU. The heads contain only a small fraction of total parameters compared to the base model, so memory requirements are modest and training can be performed on hardware that cannot accommodate full fine-tuning of the base model. Medusa-1 achieves speedup factors of 2.18 times on Vicuna-7B and 2.33 times on Vicuna-13B.
Medusa-2 trains the heads and the base model backbone jointly, using a combined loss that includes both the standard next-token prediction loss and the Medusa head losses:
$$\mathcal{L}_{\text{Medusa-2}} = \mathcal{L}_{\text{LM}} + \lambda_0 \cdot \mathcal{L}_{\text{Medusa-1}}$$
Joint training allows the backbone to learn internal representations that are more informative for the heads' multi-step predictions, producing higher head accuracy and thus higher acceptance rates. The additional training complexity is managed through two techniques. First, differential learning rates are used: the backbone is updated at a lower rate (around $1 \times 10^{-4}$) while the heads are updated at a substantially higher rate (around $2 \times 10^{-3}$), allowing the heads to adapt quickly to the evolving backbone representations. Second, a two-stage warmup procedure trains the heads alone for an initial period before enabling backbone updates, preventing the joint optimization from destabilizing head training early in the run.
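A sketch of the differential-learning-rate setup using PyTorch parameter groups (the function name and the warmup toggling are illustrative; the rates mirror the values quoted above):

```python
import torch

def build_medusa2_optimizer(backbone, medusa_heads, backbone_lr=1e-4, head_lr=2e-3):
    """One optimizer with separate learning rates for backbone and Medusa heads."""
    return torch.optim.AdamW([
        {"params": backbone.parameters(),     "lr": backbone_lr},
        {"params": medusa_heads.parameters(), "lr": head_lr},
    ])

# Stage 1 warmup: train the heads alone with the backbone frozen.
#   for p in backbone.parameters(): p.requires_grad_(False)
# Stage 2: re-enable backbone gradients and continue with the joint loss.
#   for p in backbone.parameters(): p.requires_grad_(True)
```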
The LoRA (Low-Rank Adaptation) adapter technique is used for the backbone component of Medusa-2 training to limit memory consumption. A rank-32 LoRA with $\alpha = 16$ and dropout of 0.05 is applied to the backbone, meaning the actual parameter updates are factored into low-rank matrices and do not require storing full-rank gradient tensors. This makes Medusa-2 training feasible on hardware with limited memory, though it remains more demanding than Medusa-1.
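A sketch of a corresponding adapter configuration using the `peft` library (the target module list is an assumption that depends on the backbone architecture; the checkpoint name is one of the models evaluated in the paper):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")

lora_config = LoraConfig(
    r=32,             # rank of the low-rank update matrices
    lora_alpha=16,    # scaling factor applied to the update
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # illustrative for a Llama-style backbone
)
backbone = get_peft_model(base_model, lora_config)
```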
Medusa-2 achieves substantially higher speedups: 2.83 times on both Vicuna-7B and Vicuna-13B, 2.66 times on Zephyr-7B-Beta, and 2.35 times on Vicuna-33B. On certain task categories the gains are larger: the paper reports 3.29 times speedup on coding tasks and 3.62 times on extraction tasks with Vicuna-7B under Medusa-2, reflecting higher head prediction accuracy when the output is more structured and predictable.
A practical challenge arises when applying Medusa to a model that has been fine-tuned on proprietary or unavailable data. If the original training data cannot be used to train the heads, standard supervised training is not possible. Medusa addresses this through a self-distillation mechanism: the model itself is used to generate training data from freely available seed prompts (such as ShareGPT or UltraChat). The generated outputs, which reflect the model's own behavior, are used as the supervision signal for head training.
For Medusa-2 self-distillation, the backbone loss is replaced with a KL divergence loss that keeps the backbone's next-token distribution close to the original model's distribution:
$$\mathcal{L}_{\text{LM-distill}} = \mathrm{KL}\left(p_{\text{original},t}^{(0)} \,\|\, p_t^{(0)}\right)$$
This prevents the backbone from drifting away from its original behavior while still allowing it to develop internal representations useful for the heads. Self-distillation makes it practical to add Medusa to any instruction-tuned or RLHF-trained model without access to the original fine-tuning data, which is important for models trained with proprietary instruction datasets or human preference annotations.
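A sketch of the distillation term, assuming the original model's next-token logits are produced by a frozen copy (names and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits):
    """KL(teacher || student) between the original model's next-token distribution
    and the jointly trained backbone's distribution.

    Both tensors have shape (batch, seq, vocab); teacher_logits come from a frozen
    copy of the original model and carry no gradients.
    """
    teacher_log_probs = F.log_softmax(teacher_logits.detach(), dim=-1)
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    # F.kl_div expects the student's log-probs as input and, with log_target=True,
    # the teacher's log-probs as target, computing KL(teacher || student).
    return F.kl_div(student_log_probs, teacher_log_probs,
                    log_target=True, reduction="batchmean")
```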
The benchmark results in the paper use the MT-bench evaluation suite, which covers eight task categories: writing, roleplay, reasoning, math, coding, extraction, STEM, and humanities. Speedup is measured as wall-clock tokens generated per second relative to vanilla autoregressive decoding at batch size 1. MT-bench quality scores are used to verify that output quality is preserved.
| Model | Medusa-1 speedup | Medusa-2 speedup |
|---|---|---|
| Vicuna-7B (v1.5) | 2.18x | 2.83x |
| Vicuna-13B (v1.5) | 2.33x | 2.83x |
| Vicuna-33B (v1.3) | -- | 2.35x |
| Zephyr-7B-Beta | -- | 2.66x |

The paper also reports the per-step statistics behind these end-to-end figures:

| Model | Acceleration rate (tokens/step) | Per-step overhead |
|---|---|---|
| Vicuna-7B (Medusa-2) | 3.47 | 1.22x |
| Zephyr-7B (Medusa-2) | 3.14 | 1.18x |
| Vicuna-13B (Medusa-2) | 3.51 | 1.23x |
| Vicuna-33B (Medusa-2) | 3.01 | 1.27x |
The acceleration rate column reports the average number of new tokens accepted per decoding step (including the base model's own token). The overhead column reports the cost of a Medusa forward pass relative to a standard forward pass. An acceleration rate of 3.47 with 1.22x overhead yields a net speedup of roughly $3.47 / 1.22 \approx 2.84$, consistent with the reported 2.83 times speedup for Vicuna-7B.
Medusa-2 models achieve MT-Bench scores within a small margin of the original models, confirming that the joint training does not meaningfully degrade output quality:
| Model | Original score | Medusa-2 score | Delta |
|---|---|---|---|
| Vicuna-7B | 6.17 | 6.18 | +0.01 |
| Zephyr-7B-Beta | 7.32 | 7.25 | -0.07 |
| Vicuna-13B | 6.57 | 6.43 | -0.14 |
| Vicuna-33B | 7.12 | 7.17 | +0.05 |

Per-category speedups for Vicuna-7B under Medusa-2 show how strongly the gains depend on task type:

| Task category | Speedup |
|---|---|
| Coding | 3.29x |
| Extraction | 3.62x |
| Math | 2.53x |
| Reasoning | 2.71x |
| Writing | 2.68x |
Structured tasks like coding and extraction, where the model's next-token predictions are more deterministic and the Medusa heads achieve higher accuracy, yield the largest speedups. Open-ended creative writing and mathematical reasoning, where token probabilities are more diffuse, yield somewhat lower but still substantial gains.
Medusa significantly outperforms standard speculative decoding (using a separately trained draft model) on the same model family in the paper's evaluation:
| Model | Speculative decoding | Medusa-2 |
|---|---|---|
| Vicuna-7B | 1.47x | 2.83x |
| Vicuna-13B | 1.56x | 2.83x |
| Vicuna-33B | 1.60x | 2.35x |
The comparatively modest gains from speculative decoding in this benchmark reflect the difficulty of finding a well-matched draft model for instruction-tuned Vicuna models. The draft model used in the paper's speculative decoding baseline is a smaller general-purpose language model, and the distribution mismatch between draft and target leads to low acceptance rates. Medusa avoids this problem by deriving its proposals directly from the target model's own hidden states, which are by construction well-correlated with the target model's output distribution.
EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) is a subsequent speculative decoding method that, like Medusa, adds prediction machinery to the base model to avoid requiring a separate draft model. EAGLE was published in January 2024 alongside the Medusa paper. Rather than predicting multiple future tokens from the current hidden state, EAGLE trains a small autoregressive draft head that operates on the base model's feature sequence, producing speculative tokens that are contextually conditioned on the preceding hidden states.
Comparative evaluations conducted as part of the EAGLE-2 paper (arXiv 2406.16858, 2024) show EAGLE and its successors achieving higher speedups than Medusa on a shared benchmark:
| Method | Vicuna-7B speedup | Vicuna-13B speedup |
|---|---|---|
| Medusa | 1.91x | 2.07x |
| EAGLE | 2.90x | 3.07x |
| Hydra | 2.69x | 2.88x |
| EAGLE-2 | 3.62x | 4.26x |
| Standard speculative sampling | -- | 1.93x |
These EAGLE-2 benchmark numbers use temperature 0 on the MT-bench dataset and are measured on the same hardware under the same conditions, providing a fair cross-method comparison. The results indicate that EAGLE's autoregressive draft head, by conditioning each speculative token on prior hidden states rather than predicting all positions from a single state snapshot, achieves higher acceptance rates and thus higher practical speedup. Medusa's simpler per-position feed-forward heads are easier to train and integrate but cannot match the richer conditional structure of EAGLE's draft model.
The Medusa-2 numbers in the original Medusa paper (2.83 times on Vicuna-7B) are higher than the EAGLE-2 benchmark numbers for Medusa (1.91 times) because they measure different things: Medusa's own evaluation uses Medusa-2 jointly trained models with typical acceptance, while the EAGLE-2 benchmark uses Medusa in a configuration closer to Medusa-1 with greedy acceptance. The discrepancy illustrates how significantly training regime, acceptance scheme, and evaluation methodology affect reported speedup figures.
Medusa's practical advantage over EAGLE is its simplicity: training requires only a few hours on a single GPU, the implementation is a minimal extension of the base model with no autoregressive draft head training loop, and integration into existing serving frameworks is straightforward. For practitioners who need a working acceleration solution quickly and are not optimizing for maximum possible speedup, Medusa's ease of use is a meaningful benefit.
NVIDIA's TensorRT-LLM library includes native support for Medusa decoding. The implementation provides a convert_checkpoint.py script that loads Medusa head weights alongside base model weights and produces a TensorRT engine that runs the combined forward pass. The Medusa tree configuration is exposed as a runtime parameter (medusa_choices), allowing operators to vary the candidate tree structure without rebuilding the engine. The TensorRT-LLM integration supports greedy decoding (temperature 0) and has been tested with Vicuna-7B, Vicuna-13B, and the quantized Llama-3.1-8B-Medusa-FP8 model, combining Medusa acceleration with FP8 weight quantization for further efficiency. Beam search and certain other decoding strategies are incompatible with the Medusa tree structure and are not supported in this implementation.
Hugging Face's Text Generation Inference server added support for Medusa in its inference path. TGI's integration allows users running Medusa-augmented checkpoints stored on the Hugging Face Hub to transparently benefit from tree-based acceleration without code changes beyond specifying the model path. This integration extended Medusa's practical reach to the large community of practitioners using TGI for model deployment.
Ant Group's RTP-LLM inference framework, used in large-scale production deployments, also incorporates Medusa support. RTP-LLM targets high-throughput serving at scale, and the Medusa integration follows the same principle of adding the head-based candidate tree to the serving pipeline.
Together AI, one of the institutional sponsors of the research, deployed Medusa in its inference API to accelerate generation for customers. The Together AI blog noted that the 33B Vicuna model with Medusa acceleration could operate at speeds comparable to the 13B Vicuna baseline, effectively doubling throughput for larger models where memory bandwidth pressure is most acute. Together AI's involvement in the research ensured that the implementation was tested at production scale from early in the project's life.
vLLM supports Medusa as one of several speculative decoding strategies. The vLLM implementation integrates Medusa heads into the scheduler and token generation pipeline alongside draft-model-based speculative decoding and prompt lookup decoding. As of 2024, vLLM's Medusa support is described as preliminary relative to the draft-model pathway, with ongoing kernel optimization work aimed at improving throughput at larger batch sizes where the tree attention mechanism is more complex to schedule efficiently. vLLM uses PagedAttention for KV cache management, and coordinating block allocation for tree-based speculative candidates adds additional complexity compared to single-sequence decoding.
Batch size dependency. Medusa's speedup is most pronounced at batch size 1 (single concurrent request), where the bottleneck is memory bandwidth per token and the cost of the extra forward pass for tree verification is well amortized. At large batch sizes, where the model is already operating near peak arithmetic efficiency, the overhead of verifying multiple candidate sequences does not yield as much net improvement. Most serving applications involve mixed batch sizes, and the benefit diminishes as concurrency increases.
Temperature and acceptance rates. At very low generation temperatures, token predictions are highly concentrated, and Medusa heads achieve high acceptance rates. At high temperatures, where output distributions are flatter and more random, the heads predict incorrectly more often, reducing acceptance rates and shrinking the effective speedup. The typical acceptance scheme partially mitigates this but does not fully recover the speedup at high temperatures that rejection-sampling-based speculative decoding with a well-matched draft model can provide.
EAGLE and successors outperform Medusa. As shown in the comparison section, EAGLE-2 and subsequent methods that use richer draft models have surpassed Medusa's speedup numbers on shared benchmarks. For practitioners whose primary goal is maximum speedup and who are willing to accept greater implementation complexity, EAGLE-family methods are generally preferred as of 2024 and 2025. Medusa occupies a different point on the complexity-speedup trade-off curve: easier to train, integrate, and operate, but not the highest-speedup option available.
Greedy-only support in some implementations. TensorRT-LLM's Medusa integration supports only temperature 0 (greedy decoding) as of its initial release, limiting its applicability to sampling-based generation. This is a constraint of the specific implementation rather than the method itself, but it affects users of that framework.
Limited benefit for short generations. Medusa's heads predict future tokens based on the current hidden state, which encodes the full context. The quality of those predictions does not systematically degrade with context length, and the overhead of tree attention scales with the number of tree nodes rather than with context length. For very short prompts where generation terminates quickly, however, the training and deployment overhead may not be worth the speedup. Medusa is most valuable for workloads generating hundreds to thousands of tokens per request.
Memory overhead. Each Medusa head adds parameters proportional to the model's hidden dimension times the vocabulary size. For a model with hidden dimension $d = 4096$ and vocabulary size $V = 32000$, the vocabulary projection $W_2^{(k)}$ alone contributes $4096 \times 32000 \approx 131$ million parameters, and the $d \times d$ block adds another 17 million, for roughly 148 million parameters (about 300 MB at float16) per head. Five heads add roughly 740 million parameters -- a small fraction of a large model's total size, but not negligible for memory-constrained deployments.
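A quick parameter-count check under the stated dimensions (pure arithmetic, not tied to any particular implementation):

```python
d, V = 4096, 32000          # hidden dimension and vocabulary size
heads = 5
per_head = d * d + d * V    # W1 (d x d) plus W2 (V x d) parameters
total = heads * per_head
print(f"{per_head / 1e6:.0f}M params/head, {total / 1e6:.0f}M total, "
      f"{total * 2 / 1e9:.1f} GB at FP16")   # ~148M/head, ~740M total, ~1.5 GB
```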
Medusa's publication contributed to a productive period of research into draft-model-free speculative decoding. The Hydra method (2024) extends Medusa's multi-head concept with additional conditioning mechanisms. EAGLE (2024) and EAGLE-2 (2024) replace the per-position feed-forward heads with an autoregressive draft head operating on features from the base model, achieving higher acceptance rates at the cost of a more complex training procedure. EAGLE-3 (2025) further scales this approach with training-time modifications to the base model's feature representations. The Lookahead Decoding method (2024) takes a different approach, using Jacobi iteration to generate candidate continuations from a fixed-point approximation rather than trained heads.
Medusa itself underwent revisions. Later versions of the work (arXiv 2401.10774, which reached v3) incorporated the Medusa-2 training recipe, the self-distillation mechanism, the typical acceptance scheme, and expanded evaluation across additional models, addressing several limitations of the initial release.
The framework's code repository, github.com/FasterDecoding/Medusa, served as a reference implementation that was studied and extended by multiple inference framework teams. The FasterDecoding GitHub organization also hosts subsequent related projects.