From cb8d44ae6efb3c2a81b942b2e85794c4452c1fed Mon Sep 17 00:00:00 2001 From: Alexandre Milesi Date: Mon, 4 May 2026 14:24:28 -0700 Subject: [PATCH] specdec_bench: stratify --num_requests, wire --data_parallel_size - speed.py: when --num_requests N is below the dataset size and the dataset has a `category` column with >1 distinct categories, take ceil(N / num_categories) rows from each category and round-robin interleave them so any prefix is balanced. Fixes a sampling bug where parquet rows are sorted by category, making `--num_requests 64` on throughput_8k pull 64 high_entropy prompts and zero from low_entropy / mixed. - run.py + models/vllm.py: forward --data_parallel_size to AsyncEngineArgs. The CLI flag was parsed but never passed to the engine, so DEP-style recipes only worked by accident (vLLM auto-DP from CUDA_VISIBLE_DEVICES). Explicit propagation makes DP=N reproducible and works under controlled CUDA_VISIBLE_DEVICES. Signed-off-by: Alexandre Milesi --- examples/specdec_bench/run.py | 1 + .../specdec_bench/datasets/speed.py | 39 ++++++++++++++++++- .../specdec_bench/models/vllm.py | 1 + 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/examples/specdec_bench/run.py b/examples/specdec_bench/run.py index f4fbf06c0e..1a6cdb8238 100644 --- a/examples/specdec_bench/run.py +++ b/examples/specdec_bench/run.py @@ -155,6 +155,7 @@ def run_simple(args): draft_model_dir=args.draft_model_dir, speculative_num_steps=args.draft_length, tensor_parallel_size=args.tp_size, + data_parallel_size=args.data_parallel_size, moe_expert_parallel_size=args.ep_size, trust_remote_code=args.trust_remote_code, tokenizer_path=args.tokenizer, diff --git a/examples/specdec_bench/specdec_bench/datasets/speed.py b/examples/specdec_bench/specdec_bench/datasets/speed.py index fe544bb353..22879848eb 100644 --- a/examples/specdec_bench/specdec_bench/datasets/speed.py +++ b/examples/specdec_bench/specdec_bench/datasets/speed.py @@ -14,6 +14,7 @@ # limitations under the License. 
 # mypy: disable-error-code="index"
+import math
 import random
 import re
 from enum import Enum
@@ -737,10 +738,44 @@ def _load_dataset(self, config_name_or_dataset_path: config_type | str) -> "Data
             }
             table = table.replace_schema_metadata(new_meta or None)
             dataset = HFDataset(table)
-        if self.num_samples is not None:
-            dataset = dataset.select(range(self.num_samples))
+        if self.num_samples is not None and self.num_samples < len(dataset):
+            dataset = self._stratified_select(dataset, self.num_samples)
         return dataset
 
+    @staticmethod
+    def _stratified_select(dataset: "Dataset", n: int) -> "Dataset":
+        """Select ``n`` samples uniformly across the ``category`` column.
+
+        When ``category`` is present, each category contributes
+        ``ceil(n / num_categories)`` rows (capped by category size); the
+        result is truncated to at most ``n`` rows by interleaving the
+        per-category samples round-robin so any further prefix slice
+        remains balanced. Falls back to ``range(n)`` when ``category`` is
+        absent or single-valued. Indices come from ``range(category_size)``
+        (not random) so behavior is deterministic.
+        """
+        if "category" not in dataset.column_names:
+            return dataset.select(range(n))
+        cat_to_rows: dict[str, list[int]] = {}
+        for i, c in enumerate(dataset["category"]):
+            cat_to_rows.setdefault(c, []).append(i)
+        num_cats = len(cat_to_rows)
+        if num_cats <= 1:
+            return dataset.select(range(n))
+        per_cat = math.ceil(n / num_cats)
+        # Take the first ``per_cat`` rows from each category (parquet order
+        # within a category is treated as the canonical sample order).
+        cat_samples = [
+            rows[: min(per_cat, len(rows))] for rows in cat_to_rows.values()
+        ]
+        # Round-robin interleave so the first N rows are balanced.
+ interleaved: list[int] = [] + for i in range(per_cat): + for samples in cat_samples: + if i < len(samples): + interleaved.append(samples[i]) + return dataset.select(interleaved[:n]) + def _resolve_external_data( self, dataset: "Dataset", speed_config: config_type | str ) -> "Dataset": diff --git a/examples/specdec_bench/specdec_bench/models/vllm.py b/examples/specdec_bench/specdec_bench/models/vllm.py index 2e312e7aec..c7fc82c5c6 100644 --- a/examples/specdec_bench/specdec_bench/models/vllm.py +++ b/examples/specdec_bench/specdec_bench/models/vllm.py @@ -75,6 +75,7 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs tokenizer=kwargs.get("tokenizer_path"), trust_remote_code=kwargs.get("trust_remote_code", False), tensor_parallel_size=kwargs.get("tensor_parallel_size", 1), + data_parallel_size=kwargs.get("data_parallel_size", 1), enable_expert_parallel=kwargs.get("moe_expert_parallel_size", 1) > 1, enable_prefix_caching=kwargs.get("prefix_cache", False), speculative_config=specdec,