Add Triton sampling backends alongside FlashInfer#280
Conversation
| @@ -18,14 +18,48 @@ | |||
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |||
| # SOFTWARE. | |||
|
|
|||
| """Triton sampling helper kernels.""" | |||
| """TokenSpeed-native Triton sampling kernels. | |||
There was a problem hiding this comment.
Please add vLLM's copyright here
There was a problem hiding this comment.
Added SPDX-FileCopyrightText: Copyright contributors to the vLLM project to the Triton sampling kernel header in commit c3c25815.
a76a3e9 to
c3c2581
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c3c2581533
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| top_k_ok = (top_k <= 0) | (before_count < top_k) | ||
| top_p_ok = (top_p >= 1.0 - 1.0e-6) | (before_prob < top_p * total_prob) | ||
| ok = top_k_ok & top_p_ok |
There was a problem hiding this comment.
Use top-k-renormalized mass for top-p rejection
When a request has both finite top_k and top_p < 1 and falls into the generic rejection path (for example top_k 65–127, mixed batches with such a row, or the full Triton backend), this predicate compares before_prob with top_p * total_prob where total_prob is the full-vocab mass. The direct sampler and build_top_k_top_p_probs_from_logits() used by speculative verify apply top-p after top-k renormalization, so normal sampling and verify can admit different nuclei and corrupt speculative decoding probabilities for those sampling params.
Useful? React with 👍 / 👎.
c3c2581 to
981f940
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 981f9400f2
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| block_id = tl.min( | ||
| tl.where((row_vals == selected_value) & mask, cols, 2147483647), |
There was a problem hiding this comment.
Resolve top-k samples by token id, not logit value
When the direct top-k/top-p path is used (top_k <= 64) and two nucleus candidates have identical logits, the Gumbel draw selects a rank but this resolution step only carries the sampled logit value and then returns the lowest vocab id with that value. For uniform rows or tied scores this collapses all probability mass for tied candidates onto the smallest token id instead of sampling among them, so finite-top-k requests can produce biased/incorrect tokens; the non-compact helper has the same value-only resolution pattern.
Useful? React with 👍 / 👎.
981f940 to
6813968
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 681396814e
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| finite_top_k = (top_k > 0) & (top_k <= TOP_K_PAD) | ||
| top_p_total_prob = tl.where(finite_top_k, top_k_total_prob, total_prob) | ||
|
|
||
| selected = tl.load(candidate_ids_ptr + row * candidate_row_stride + RETRIES) |
There was a problem hiding this comment.
Avoid falling back to argmax after only four top-p retries
For top-p-only requests (the default top_k=-1 with top_p<1) this generic rejection path draws candidates from the unfiltered distribution and only accepts candidates inside the nucleus; if all four retries miss, selected remains initialized to the argmax here. With low-but-valid top_p values such as 0.1 or 0.2, that miss probability is large ((1-top_p)^4), so sampling frequently collapses to greedy instead of the requested nucleus distribution, unlike the previous exact renormalized sampler.
Useful? React with 👍 / 👎.
6813968 to
28fa045
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 28fa045441
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if not self.config.enable_output_logprobs and num_tokens_per_req == 1: | ||
| variants = ( | ||
| _CUDA_GRAPH_VARIANT_DEFAULT, | ||
| _CUDA_GRAPH_VARIANT_NO_FILTER, | ||
| _CUDA_GRAPH_VARIANT_TOP_K_TOP_P, | ||
| ) | ||
| return variants | ||
| return (_CUDA_GRAPH_VARIANT_DEFAULT,) |
There was a problem hiding this comment.
Capture seeded sampler variants when logprobs are enabled
When --enable-output-logprobs is set, this returns only the default graph, so _graph_key_for_replay() falls back to the generic rejection graph for requests whose prepared mode is no_filter or top_k_top_p. Those same seeded requests use the direct Gumbel/top-k path in eager or uncaptured batches, while the rejection path derives different random streams (e.g. offset * RETRIES + retry instead of offset), so the same prompt and seed can produce different tokens depending only on whether the batch replays a captured graph. Either capture the same variants with logprobs or force replay/eager through the same sampler mode.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7179e82b15
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if num_tokens_per_req != 1: | ||
| return _TRITON_GUMBEL_MODE_NONE |
There was a problem hiding this comment.
Handle mixed speculative prefill batches
When speculative decoding is enabled, ModelExecutor.prepare_step() passes self.config.output_length (the draft-token count, usually >1) for the whole batch, even if the batch also contains prefill rows; _run_sampling() then calls sample() for those prefill rows before verify() for the decode rows. For any non-greedy sampling params in such a mixed prefill+decode step, this branch leaves _step_triton_gumbel_mode as NONE, so the subsequent sample() path raises RuntimeError("did not select a sampling mode") instead of sampling the newly admitted request. The mode selection needs to keep a valid single-token sample mode for mixed speculative batches, or sample() must resolve its mode independently of the verify width.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d81cab179f
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| top_k_ok = (top_k <= 0) | (before_count < top_k) | ||
| top_p_ok = (top_p >= 1.0 - 1.0e-6) | (before_prob < top_p * top_p_total_prob) | ||
| min_p_ok = (min_p <= 0.0) | (tl.exp(candidate_logit - max_scaled) >= min_p) | ||
| ok = top_k_ok & top_p_ok & min_p_ok |
There was a problem hiding this comment.
Preserve top-k distribution in rejection sampling
When triton_full serves a finite-top_k request (and the plain triton backend falls into generic rejection for top_k 65–127), candidates are still drawn from the full vocabulary and this predicate only accepts them if they land in the top-k set. With top_p=1 and a typical large vocab, the probability that all four retries miss the top-k set is very high (e.g. uniform logits with top_k=100), so the sampler falls back to the initialized argmax instead of sampling from the requested top-k distribution.
Useful? React with 👍 / 👎.
27b7316 to
288f13f
Compare
dc66686 to
c121645
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9a76e165f6
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| min_p = tl.load(min_p_pool_ptr + pool_idx).to(tl.float32) | ||
| min_p_log_threshold = tl.log(tl.maximum(min_p, 1.0e-20)) |
There was a problem hiding this comment.
Treat min_p=0 rows as unfiltered
When a triton_full batch contains any request with min_p > 0, _full_has_min_p sends the whole batch through the min-p kernels, including rows whose own min_p is 0. For those rows this clamp turns the disabled cutoff into row_max + log(1e-20), so tokens more than about 46 scaled-logit units below the max are silently removed; with low temperatures or large logit_bias/penalties this changes a request that explicitly disabled min-p. Use -inf for min_p <= 0 (the same issue appears in the parallel min-p path below).
Useful? React with 👍 / 👎.
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
Signed-off-by: FlamingoPg <1106310035@qq.com>
633932c to
baab8fc
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: baab8fce85
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| accepted_found = tl.where(accepted, 1, accepted_found) | ||
| attempt += 1 | ||
|
|
||
| token = tl.where(accepted_found != 0, accepted_token, row_argmax) |
There was a problem hiding this comment.
Avoid falling back to greedy after top-p repair misses
For top-p-only requests with low but valid top_p values, this repair path still gives up after _TOP_P_REPAIR_NUM_ATTEMPTS and returns row_argmax when every rejection draw lands outside the nucleus. That happens with probability (1 - top_p)^8 (about 43% for top_p=0.1), so the optimized Triton top-p route frequently emits greedy tokens instead of sampling from the requested nucleus distribution; use an exact nucleus sampler or keep drawing until acceptance rather than defaulting to argmax.
Useful? React with 👍 / 👎.
| self._seed_pool, | ||
| offsets_pool, | ||
| out[:rows], | ||
| min_p_pool=self._min_p_pool, |
There was a problem hiding this comment.
Disable min-p in the generic full route when absent
When a triton_full batch has no min_p requests but falls into the generic mixed route (for example mixing top_k=-1 and finite top_k), _full_has_min_p is false but this still passes _min_p_pool into the generic sampler. Rows with the default min_p=0 then use log(max(0, 1e-20)) as a cutoff, silently filtering tokens more than ~46 scaled-logit units below the max even though min-p is disabled; pass None here unless _full_has_min_p is true.
Useful? React with 👍 / 👎.
Summary
tritonandtriton_fullsampling backends alongside the existingflashinfer/flashinfer_fullprobability-route backends.flashinfer; this PR does not remove FlashInfer sampling or change the default route.tokenspeed-kernelboundary; attention/MoE/quantization FlashInfer paths are untouched.Changes
tritonandtriton_fullruntime backends with separate MRV2-style Gumbel routes.PoolSamplingBackendso FlashInfer and Triton share TokenSpeed request-pool state without Triton inheriting FlashInfer probability/coin state.test_sampling.py, while Triton-specific coverage lives intest_sampling_triton.pyandtest_sampling_triton_full.py.Benchmark
These are the latest focused sampling-path results I could trace from the current branch. Source artifacts:
/tmp/tokenspeed_sampling_mr/focused_sampling_path_bench.csv/tmp/tokenspeed_sampling_mr/focused_sampling_path_bench_151936.csv/tmp/tokenspeed_sampling_mr/current_sampling_ops.csvEnvironment from the benchmark logs:
NVIDIA H100 80GB HBM3. Timing uses CUDA events. All numbers are milliseconds, shown asmedian / p95. Speedup isFlashInfer baseline / Triton, so higher is better.Important scope notes:
flashinfer.sample()core behavior. That route does not callfused_topk_topp_renorm; it usessoftmax -> top_k_top_p_sampling_from_probs.flashinfer_full-style probability behavior. That route does callfused_topk_topp_renormon NVIDIA beforemin_p_sampling_from_probs.current_triton_pool_opis the latest optimized pool-aware Triton op path.current_runtime_sampleis the full TokenSpeed runtime sampling backend call and includes route/backend overhead.A. Normal Sample: Triton vs
flashinfer.sample()Baseline route:
Candidate route:
Current Pool-Aware Triton Op vs Old FlashInfer Core
Full Runtime Sample Call vs Old FlashInfer Core
B. Full/min-p Path: Triton Full vs
flashinfer_fullBaseline route on NVIDIA:
Candidate route:
This table is the focused full/min-p operator comparison from
current_sampling_ops.csv.Benchmark Interpretation
flashinfer.sample()does not usefused_topk_topp_renorm;flashinfer_fulldoes.2.02x-3.24xon no-filter and1.71x-2.12xon finite top-k + top-p except 151936/bs32, which is still1.74x.1.05x.flashinfer_fullroute that includesfused_topk_topp_renorm, Triton full also wins in these focused full/min-p operator rows. The tightest case is 32768/bs1min_p, where median improves slightly while p95 is roughly comparable.Validation
pre-commit run --all-files(passed)python -m pytest test/runtime/test_sampling_backend_pool.py test/runtime/test_sampling_backend_registry.py test/runtime/test_cli_config_compat.py(90 passed,18 warnings)python -m pytest tokenspeed-kernel/test/ops/test_sampling.py tokenspeed-kernel/test/ops/test_sampling_triton.py tokenspeed-kernel/test/ops/test_sampling_triton_full.py(60 passed,18 warnings)Notes
flashinfer/flashinfer_fullremain available, and NVIDIA default remainsflashinfer.