specdec_bench: stratify --num_requests#1388

Closed
milesial wants to merge 1 commit intoNVIDIA:mainfrom
milesial:specdec-stratify-and-dp
Conversation

@milesial

@milesial milesial commented May 4, 2026

On SPEEDBench, with the dataset sorted by category, a low --num_requests would only choose samples from a single category. This PR balances the samples across categories when a category column is present in the dataset.

Summary by CodeRabbit

  • New Features
    • Added configurable data-parallel support for benchmark model initialization to enable optimized distributed inference.
    • Improved dataset sampling with a deterministic, stratified selection that preserves balanced category representation when downsampling.

@milesial milesial requested a review from a team as a code owner May 4, 2026 21:22
@milesial milesial requested a review from h-guo18 May 4, 2026 21:22
@copy-pr-bot

copy-pr-bot Bot commented May 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai Bot commented May 4, 2026

📝 Walkthrough

Walkthrough

Adds configurable data-parallel size through the run path into vLLM engine args and implements deterministic stratified sampling for SPEEDBench dataset selection when num_samples is smaller than the dataset.

Changes

Data Parallelism Configuration

Layer / File(s) Summary
Runtime Parameter Propagation
examples/specdec_bench/run.py
Model initialization now passes data_parallel_size=args.data_parallel_size into the selected model_class(...).
Engine Integration
examples/specdec_bench/specdec_bench/models/vllm.py
AsyncEngineArgs construction includes data_parallel_size=kwargs.get("data_parallel_size", 1) so the vLLM engine receives the parallelism config.

Stratified Dataset Sampling

Layer / File(s) Summary
Dependencies
examples/specdec_bench/specdec_bench/datasets/speed.py
Added import math to support sampling math operations.
Sampling Logic
examples/specdec_bench/specdec_bench/datasets/speed.py
Added SPEEDBench._stratified_select(dataset, n) which deterministically selects n rows by grouping indices by category, taking up to ceil(n / num_categories) per category, interleaving round-robin, and truncating to n; falls back to dataset.select(range(n)) when category is missing or has ≤1 category.
Dataset Loading Integration
examples/specdec_bench/specdec_bench/datasets/speed.py
_load_dataset uses _stratified_select only when self.num_samples < len(dataset); otherwise returns the full dataset.
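The selection described above can be sketched in isolation. This is a hypothetical standalone function operating on a plain category list rather than the actual `SPEEDBench._stratified_select` in speed.py, but it follows the same scheme: group indices by category, take the first ceil(n / num_categories) from each, and round-robin interleave so any prefix stays balanced.

```python
import math
from collections import OrderedDict

def stratified_select(categories, n):
    """Deterministically pick n row indices balanced across categories.

    `categories` is a per-row category value (e.g. the dataset's
    `category` column). Insertion order of first appearance is kept so
    the result is reproducible for a fixed dataset order.
    """
    cat_to_rows = OrderedDict()
    for idx, cat in enumerate(categories):
        cat_to_rows.setdefault(cat, []).append(idx)
    if len(cat_to_rows) <= 1:
        # Fallback mirroring the PR: nothing to stratify over.
        return list(range(n))
    per_cat = math.ceil(n / len(cat_to_rows))
    # First `per_cat` rows of each category, in dataset order.
    cat_samples = [rows[:per_cat] for rows in cat_to_rows.values()]
    # Round-robin interleave so the first k rows are balanced for any k.
    interleaved = []
    for i in range(per_cat):
        for samples in cat_samples:
            if i < len(samples):
                interleaved.append(samples[i])
    return interleaved[:n]
```

With 4 rows each of categories a, b, c in sorted order, `stratified_select([...], 6)` yields two indices per category instead of six from the first category, which is exactly the bug the PR fixes.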

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 60.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed The pull request modifications do not contain any of the six specified security anti-patterns. All security-sensitive parameters like trust_remote_code are properly exposed as configurable parameters with safe defaults (False), and the new code additions contain no dangerous operations.
Title check ✅ Passed The PR title 'specdec_bench: stratify --num_requests' accurately captures the primary change: implementing stratified sampling for the --num_requests parameter in the SPEEDBench dataset.



@milesial milesial force-pushed the specdec-stratify-and-dp branch 2 times, most recently from 9f96a3b to ce2a6ed on May 4, 2026 21:24
- speed.py: when --num_requests N is below the dataset size and the
  dataset has a `category` column with >1 distinct categories, take
  ceil(N / num_categories) rows from each category and round-robin
  interleave them so any prefix is balanced. Fixes a sampling bug
  where parquet rows are sorted by category, making `--num_requests
  64` on throughput_8k pull 64 high_entropy prompts and zero from
  low_entropy / mixed.

- run.py + models/vllm.py: forward --data_parallel_size to
  AsyncEngineArgs. The CLI flag was parsed but never passed to the
  engine, so DEP-style recipes only worked by accident (vLLM auto-DP
  from CUDA_VISIBLE_DEVICES). Explicit propagation makes DP=N
  reproducible and works under controlled CUDA_VISIBLE_DEVICES.

Signed-off-by: Alexandre Milesi <milesial@users.noreply.github.com>
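The run.py side of that wiring can be sketched with a minimal parser. This is a hypothetical reduction, not the actual argument set in examples/specdec_bench/run.py; it shows the pattern of defining the flag so `args.data_parallel_size` always exists, with a `getattr` fallback for parsers that lack it.

```python
import argparse

def build_parser():
    # Minimal stand-in for the benchmark's CLI: define the flag with a
    # safe default so the attribute exists even when the flag is omitted.
    p = argparse.ArgumentParser()
    p.add_argument("--tp_size", type=int, default=4,
                   help="Tensor parallel size")
    p.add_argument("--data_parallel_size", type=int, default=1,
                   help="Data parallel size")
    return p

args = build_parser().parse_args(["--data_parallel_size", "2"])
engine_kwargs = {
    "tensor_parallel_size": args.tp_size,
    # getattr fallback keeps callers safe if the flag was never defined.
    "data_parallel_size": getattr(args, "data_parallel_size", 1),
}
```

These kwargs would then be forwarded into the engine-args construction, making DP=N explicit rather than relying on vLLM inferring it from CUDA_VISIBLE_DEVICES.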
@milesial milesial changed the title specdec_bench: stratify --num_requests, wire --data_parallel_size specdec_bench: stratify --num_requests May 4, 2026
@milesial milesial force-pushed the specdec-stratify-and-dp branch from ce2a6ed to cb8d44a Compare May 4, 2026 21:24
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/specdec_bench/run.py`:
- Line 158: The code references args.data_parallel_size in run.py (used when
constructing the launcher via data_parallel_size=args.data_parallel_size) but
the CLI parser never defines that flag; add a CLI argument for
--data_parallel_size to the argparse parser (or the function that builds it) so
args has that attribute at runtime, specifying type=int and a sensible default
(e.g., 1) and a short help string; alternatively ensure the call uses
getattr(args, "data_parallel_size", 1) if you prefer a fallback instead of
adding a parser flag.

In `@examples/specdec_bench/specdec_bench/datasets/speed.py`:
- Around line 765-777: _stratified_select can under-select when interleaved has
fewer than n items; change the function to detect if len(interleaved) < n after
the round-robin build and then fill the remainder by appending additional row
indices until n is reached (e.g., iterate remaining rows in cat_to_rows or
dataset indices skipping already-picked ones, or sample from the union of
leftover rows) before calling dataset.select(interleaved[:n]). Update references
in the function (_stratified_select, per_cat, cat_samples, interleaved) to
ensure duplicates are avoided and selection stops exactly at n.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: ceef28a0-80c7-464c-904e-5c65fbd38c7b

📥 Commits

Reviewing files that changed from the base of the PR and between 06ef935 and 9dbc4fc.

📒 Files selected for processing (3)
  • examples/specdec_bench/run.py
  • examples/specdec_bench/specdec_bench/datasets/speed.py
  • examples/specdec_bench/specdec_bench/models/vllm.py

draft_model_dir=args.draft_model_dir,
speculative_num_steps=args.draft_length,
tensor_parallel_size=args.tp_size,
data_parallel_size=args.data_parallel_size,
Contributor


⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Missing CLI arg causes runtime crash on Line 158.

args.data_parallel_size is used, but --data_parallel_size is not defined in this parser, so CLI execution can fail with AttributeError.

💡 Proposed fix
@@
-        data_parallel_size=args.data_parallel_size,
+        data_parallel_size=getattr(args, "data_parallel_size", 1),
@@
     parser.add_argument(
         "--tp_size", type=int, required=False, default=4, help="Tensor parallel size"
     )
+    parser.add_argument(
+        "--data_parallel_size",
+        type=int,
+        required=False,
+        default=1,
+        help="Data parallel size",
+    )
     parser.add_argument(
         "--ep_size", type=int, required=False, default=2, help="Expert parallel size"
     )
📝 Committable suggestion


Suggested change
data_parallel_size=args.data_parallel_size,
data_parallel_size=getattr(args, "data_parallel_size", 1),
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/specdec_bench/run.py` at line 158, The code references
args.data_parallel_size in run.py (used when constructing the launcher via
data_parallel_size=args.data_parallel_size) but the CLI parser never defines
that flag; add a CLI argument for --data_parallel_size to the argparse parser
(or the function that builds it) so args has that attribute at runtime,
specifying type=int and a sensible default (e.g., 1) and a short help string;
alternatively ensure the call uses getattr(args, "data_parallel_size", 1) if you
prefer a fallback instead of adding a parser flag.

Comment on lines +765 to +777
        per_cat = math.ceil(n / num_cats)
        # Take the first ``per_cat`` rows from each category (parquet order
        # within a category is treated as the canonical sample order).
        cat_samples = [
            rows[: min(per_cat, len(rows))] for rows in cat_to_rows.values()
        ]
        # Round-robin interleave so the first N rows are balanced.
        interleaved: list[int] = []
        for i in range(per_cat):
            for samples in cat_samples:
                if i < len(samples):
                    interleaved.append(samples[i])
        return dataset.select(interleaved[:n])
Contributor


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

_stratified_select can return fewer than n rows in imbalanced datasets.

When multiple categories are small, capping each at per_cat can make len(interleaved) < n, so dataset.select(interleaved[:n]) under-selects instead of honoring the requested sample count.

💡 Proposed fix
@@
-        cat_samples = [
-            rows[: min(per_cat, len(rows))] for rows in cat_to_rows.values()
-        ]
+        cat_rows = list(cat_to_rows.values())
+        cat_samples = [rows[: min(per_cat, len(rows))] for rows in cat_rows]
@@
         interleaved: list[int] = []
         for i in range(per_cat):
             for samples in cat_samples:
                 if i < len(samples):
                     interleaved.append(samples[i])
-        return dataset.select(interleaved[:n])
+        if len(interleaved) < n:
+            leftovers = [rows[min(per_cat, len(rows)) :] for rows in cat_rows]
+            max_leftover = max((len(rows) for rows in leftovers), default=0)
+            for i in range(max_leftover):
+                for rows in leftovers:
+                    if i < len(rows):
+                        interleaved.append(rows[i])
+                        if len(interleaved) == n:
+                            return dataset.select(interleaved)
+        return dataset.select(interleaved[:n])
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/specdec_bench/specdec_bench/datasets/speed.py` around lines 765 -
777, _stratified_select can under-select when interleaved has fewer than n
items; change the function to detect if len(interleaved) < n after the
round-robin build and then fill the remainder by appending additional row
indices until n is reached (e.g., iterate remaining rows in cat_to_rows or
dataset indices skipping already-picked ones, or sample from the union of
leftover rows) before calling dataset.select(interleaved[:n]). Update references
in the function (_stratified_select, per_cat, cat_samples, interleaved) to
ensure duplicates are avoided and selection stops exactly at n.

Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
examples/specdec_bench/specdec_bench/datasets/speed.py (1)

768-777: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

_stratified_select can under-select and return fewer than n rows.
When some categories have fewer than per_cat rows, interleaved may stay shorter than n; then Line 777 returns too few samples (e.g., highly imbalanced category sizes). This breaks requested --num_requests semantics.

Proposed minimal fix
-        cat_samples = [
-            rows[: min(per_cat, len(rows))] for rows in cat_to_rows.values()
-        ]
+        cat_rows = list(cat_to_rows.values())
+        cat_samples = [rows[: min(per_cat, len(rows))] for rows in cat_rows]
         # Round-robin interleave so the first N rows are balanced.
         interleaved: list[int] = []
         for i in range(per_cat):
             for samples in cat_samples:
                 if i < len(samples):
                     interleaved.append(samples[i])
-        return dataset.select(interleaved[:n])
+        if len(interleaved) < n:
+            leftovers = [rows[min(per_cat, len(rows)) :] for rows in cat_rows]
+            max_leftover = max((len(rows) for rows in leftovers), default=0)
+            for i in range(max_leftover):
+                for rows in leftovers:
+                    if i < len(rows):
+                        interleaved.append(rows[i])
+                        if len(interleaved) == n:
+                            return dataset.select(interleaved)
+        return dataset.select(interleaved[:n])
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/specdec_bench/specdec_bench/datasets/speed.py` around lines 768 -
777, _stratified_select can under-select when some categories have fewer than
per_cat, producing interleaved shorter than n; fix by after building interleaved
from cat_samples, check if len(interleaved) < n and if so append additional row
indices from the original dataset (skipping indices already in interleaved)
until you reach n, then call dataset.select on the first n indices. Locate the
logic around variables per_cat, cat_samples, interleaved in _stratified_select
and update it to compute a set of chosen indices, iterate the dataset index
range to add missing indices (or flatten remaining rows from cat_to_rows) until
n is reached, and only then return dataset.select(interleaved[:n]).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Duplicate comments:
In `@examples/specdec_bench/specdec_bench/datasets/speed.py`:
- Around line 768-777: _stratified_select can under-select when some categories
have fewer than per_cat, producing interleaved shorter than n; fix by after
building interleaved from cat_samples, check if len(interleaved) < n and if so
append additional row indices from the original dataset (skipping indices
already in interleaved) until you reach n, then call dataset.select on the first
n indices. Locate the logic around variables per_cat, cat_samples, interleaved
in _stratified_select and update it to compute a set of chosen indices, iterate
the dataset index range to add missing indices (or flatten remaining rows from
cat_to_rows) until n is reached, and only then return
dataset.select(interleaved[:n]).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 9fb67722-71b3-4ec0-8de7-ded2a88df76b

📥 Commits

Reviewing files that changed from the base of the PR and between 9dbc4fc and cb8d44a.

📒 Files selected for processing (3)
  • examples/specdec_bench/run.py
  • examples/specdec_bench/specdec_bench/datasets/speed.py
  • examples/specdec_bench/specdec_bench/models/vllm.py
✅ Files skipped from review due to trivial changes (1)
  • examples/specdec_bench/specdec_bench/models/vllm.py

@codecov

codecov Bot commented May 4, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 77.29%. Comparing base (06ef935) to head (cb8d44a).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1388      +/-   ##
==========================================
+ Coverage   76.90%   77.29%   +0.39%     
==========================================
  Files         471      471              
  Lines       50565    50565              
==========================================
+ Hits        38885    39086     +201     
+ Misses      11680    11479     -201     
Flag Coverage Δ
examples 40.21% <ø> (-0.45%) ⬇️
unit 52.80% <ø> (ø)

Flags with carried forward coverage won't be shown.



@milesial milesial closed this May 4, 2026
@milesial milesial deleted the specdec-stratify-and-dp branch May 4, 2026 21:46