Feature/spec decode draft model #24322
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically, and you can ask your reviewers to trigger select CI tests on top of those. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces support for speculative decoding using a draft model. The changes are comprehensive, touching configuration, model loading, scheduling, and the core speculative decoding logic. New tests and benchmark modifications are also included to validate and measure the new feature. The overall implementation appears solid. However, I've identified a critical issue in a refactoring of the `bind_kv_cache` utility function, which removes an important safety check and could lead to incorrect behavior for certain model architectures.
@tomasruizt - Thank you for the PR!
What is the TP you are using for Qwen3-32B? By default, the draft model TP is equal to the target model TP. Since Qwen3-1.7B is a small model, running it at high TP might incur NCCL communication cost. Try setting the draft TP to 1.
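For reference, a minimal sketch of what a lower draft TP could look like on the CLI. This assumes the `draft_tensor_parallel_size` field of the speculative config; the model pair and TP values are just the ones discussed above, not a verified setup from this PR.

```bash
# Sketch only: the target model runs at TP=4 while the small draft model runs at TP=1
# (assumes draft_tensor_parallel_size is the relevant speculative-config field).
vllm serve Qwen/Qwen3-32B \
  --tensor-parallel-size 4 \
  --speculative-config '{"model": "Qwen/Qwen3-1.7B", "method": "draft_model", "num_speculative_tokens": 3, "draft_tensor_parallel_size": 1}'
```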
I ran the benchmarks with TP=1 and num_draft_tokens=3. So we can rule out TP communication issues.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 7de2ae1 to 2e0fb65
@benchislett fair. I'll factor out the duplicated code. In terms of the extra decode: I'll try to eliminate it and reach out for help if needed.
Signed-off-by: Tomas Ruiz <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Tomas Ruiz <[email protected]>
Force-pushed from e3b85dd to 86d8040
Signed-off-by: Tomas Ruiz <[email protected]>
As we discussed in our call, I moved the
Numbers 3 to 5 will become unnecessary if we implement the optimal prefill we discussed, which would reduce the forward passes by one and improve runtimes. Nevertheless, it's useful to remember that the major factor affecting DraftModel speed is CUDA graph usage, which is now conveniently a single flag. At the moment the EAGLE file diff looks horrible, I guess because of the combination of extracting a superclass and introducing flags to the constructor. Let me know if you would prefer to review the EAGLE refactor as a separate PR to main (or as multiple small PRs) 👍 Edit: I managed to get a nice git diff for the EAGLE file by minimizing changes.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Tomas Ruiz <[email protected]>
Signed-off-by: Tomas Ruiz <[email protected]>
Signed-off-by: Tomas Ruiz <[email protected]>
The EAGLE code is frequently changed on main, so it is difficult to move EAGLE code around without painful merge conflicts.
Signed-off-by: Benjamin Chislett <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
fix next_token_ids issue
Signed-off-by: Tomas Ruiz <[email protected]>
Signed-off-by: Tomas Ruiz <[email protected]>
Signed-off-by: Tomas Ruiz <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Tomas Ruiz <[email protected]>
Purpose
Enabling draft models for speculative decoding (SD).
E.g. `Qwen3-1.7B` as the draft model and `Qwen3-32B` as the target model. This type of SD requires no specially trained heads (like EAGLE or Medusa).
Example usage:
```bash
vllm serve \
  --model=Qwen/Qwen3-4B \
  --speculative-config '{"model": "Qwen/Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 3, "max_model_len": 2000}' \
  --max-model-len 2000
```
Get a generation:
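The exact request isn't reproduced here; below is a minimal sketch against the OpenAI-compatible endpoint started by the command above. The prompt and sampling values are placeholders.

```bash
# Query the server started above; prompt and parameters are illustrative.
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "Qwen/Qwen3-4B",
        "prompt": "The capital of France is",
        "max_tokens": 64,
        "temperature": 0
      }'
```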
Status
Acceptance Length
As suggested by @ekagra-ranjan, I benchmarked acceptance length (AL) with the command below:
The AL values within the Qwen3 family seem good, both with temperatures of 0.0 (greedy) and 1.0.
As a sanity check, I benchmarked Llama-3.2-1B as both target and draft, which had an almost perfect AL (3.97/4), suggesting it's working as intended.
I have not run the default model `meta-llama/Llama-3.1-8B-Instruct`, because I didn't find a good draft model for it, but feel free to suggest one and I can run the benchmarks.

Temperature t=0:
Temperature t=1.0:
Using t=1.0, the AL metric degrades. However, spec-decode with probabilities, which is needed for lossless rejection sampling, is not yet implemented. This is being addressed in #20459. After that PR is merged, the AL for non-greedy spec-decode should improve.
All scripts and logs used for the benchmarks can be found in this Google Drive.
Online Throughput Metrics
I measured online throughput metrics using the commands below. Hardware was an RTX PRO 6000 96GB. After making sure the draft model also uses CUDA graphs, SD has higher throughput than not using SD. See the tables below.
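The exact commands are in the linked Google Drive; as a rough sketch, a typical server/load-generator pairing could look like the following. The dataset, prompt counts, and lengths are placeholders, and the benchmark flags are assumptions based on the stock `benchmarks/benchmark_serving.py` script rather than the exact invocation used for the tables below.

```bash
# Server under test with speculative decoding enabled; values are illustrative.
vllm serve Qwen/Qwen3-32B \
  --speculative-config '{"model": "Qwen/Qwen3-1.7B", "method": "draft_model", "num_speculative_tokens": 3}'

# Load generator reporting throughput and TPOT; flags assume the stock serving benchmark.
python benchmarks/benchmark_serving.py \
  --backend vllm \
  --model Qwen/Qwen3-32B \
  --dataset-name random \
  --num-prompts 100 \
  --max-concurrency 100 \
  --random-input-len 1024 \
  --random-output-len 256
```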
The metrics (lower is better) are:
Batch Size = 1
For Temperature = 0.0: Using SD, runtimes and TPOT are shorter by ~50%.
Batch Size = 100
For Temperature = 0.0:
For Temperature = 1.0:
This scenario with batch size 100 is a more realistic inference case.
Using SD, runtimes and TPOT are shorter.
Profiling
This section was removed, since using CUDA graphs on the draft model significantly improved its speed.
Profiling script
I used the command below to profile the generation process and identify that the draft model was previously running too slowly. Note: the command uses the `--profile` flag, which I introduce in this PR: #24575

Test Plan
The added unit test checks the correctness metrics. To run it:

```bash
cd tests/v1/e2e/
pytest test_spec_decode.py -k test_draft_model_correctness
```
EAGLE testing
I tested that the EAGLE implementation stays unaffected using the command below:
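The command itself is not preserved in this excerpt. For orientation only, a run of roughly this shape exercises the EAGLE path; the draft checkpoint and settings here are assumptions, not the exact setup behind the measurements below.

```bash
# Illustrative only: the EAGLE head and token count are assumptions.
vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --speculative-config '{"model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "method": "eagle", "num_speculative_tokens": 3}'
```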
The results are in line with previous measurements like #17504 (comment)
Follow-up Optimizations
Feed `next_token_ids` together with `target_token_ids` in the first forward pass of the draft model. This reduces the number of forward passes needed in each drafting phase by one, speeding up drafting.

(Optional) Documentation Update
Essential Elements of an Effective PR Description Checklist

Update `supported_models.md` and `examples` for a new model.