diff --git a/examples/specdec_bench/run.py b/examples/specdec_bench/run.py
index f4fbf06c0e..1a6cdb8238 100644
--- a/examples/specdec_bench/run.py
+++ b/examples/specdec_bench/run.py
@@ -155,6 +155,7 @@ def run_simple(args):
         draft_model_dir=args.draft_model_dir,
         speculative_num_steps=args.draft_length,
         tensor_parallel_size=args.tp_size,
+        data_parallel_size=args.data_parallel_size,
         moe_expert_parallel_size=args.ep_size,
         trust_remote_code=args.trust_remote_code,
         tokenizer_path=args.tokenizer,
diff --git a/examples/specdec_bench/specdec_bench/datasets/speed.py b/examples/specdec_bench/specdec_bench/datasets/speed.py
index fe544bb353..22879848eb 100644
--- a/examples/specdec_bench/specdec_bench/datasets/speed.py
+++ b/examples/specdec_bench/specdec_bench/datasets/speed.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 # mypy: disable-error-code="index"
 
+import math
 import random
 import re
 from enum import Enum
@@ -737,10 +738,44 @@ def _load_dataset(self, config_name_or_dataset_path: config_type | str) -> "Data
         }
         table = table.replace_schema_metadata(new_meta or None)
         dataset = HFDataset(table)
-        if self.num_samples is not None:
-            dataset = dataset.select(range(self.num_samples))
+        if self.num_samples is not None and self.num_samples < len(dataset):
+            dataset = self._stratified_select(dataset, self.num_samples)
         return dataset
 
+    @staticmethod
+    def _stratified_select(dataset: "Dataset", n: int) -> "Dataset":
+        """Select up to ``n`` samples evenly across the ``category`` column.
+
+        With multiple categories present, each contributes
+        ``ceil(n / num_categories)`` rows (capped by category size); the
+        result is truncated to ``n`` rows (fewer only when small
+        categories cannot supply their share) by round-robin
+        interleaving the per-category samples, so any prefix slice
+        stays balanced. Falls back to ``range(n)`` when ``category`` is
+        absent or only one category exists. Indices come from
+        ``range(category_size)``, not random draws, so selection is
+        deterministic.
+        """
+        if "category" not in dataset.column_names:
+            return dataset.select(range(n))
+        cat_to_rows: dict[str, list[int]] = {}
+        for i, c in enumerate(dataset["category"]):
+            cat_to_rows.setdefault(c, []).append(i)
+        num_cats = len(cat_to_rows)
+        if num_cats <= 1:
+            return dataset.select(range(n))
+        per_cat = math.ceil(n / num_cats)
+        # Take the first ``per_cat`` rows from each category (parquet order
+        # within a category is treated as the canonical sample order).
+        cat_samples = [rows[:per_cat] for rows in cat_to_rows.values()]
+        # Round-robin interleave so the first N rows are balanced.
+        interleaved: list[int] = []
+        for i in range(per_cat):
+            for samples in cat_samples:
+                if i < len(samples):
+                    interleaved.append(samples[i])
+        return dataset.select(interleaved[:n])
+
     def _resolve_external_data(
         self, dataset: "Dataset", speed_config: config_type | str
     ) -> "Dataset":
diff --git a/examples/specdec_bench/specdec_bench/models/vllm.py b/examples/specdec_bench/specdec_bench/models/vllm.py
index 2e312e7aec..c7fc82c5c6 100644
--- a/examples/specdec_bench/specdec_bench/models/vllm.py
+++ b/examples/specdec_bench/specdec_bench/models/vllm.py
@@ -75,6 +75,7 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
             tokenizer=kwargs.get("tokenizer_path"),
             trust_remote_code=kwargs.get("trust_remote_code", False),
             tensor_parallel_size=kwargs.get("tensor_parallel_size", 1),
+            data_parallel_size=kwargs.get("data_parallel_size", 1),
             enable_expert_parallel=kwargs.get("moe_expert_parallel_size", 1) > 1,
             enable_prefix_caching=kwargs.get("prefix_cache", False),
             speculative_config=specdec,
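
As a sanity check on the selection logic above, here is a minimal, self-contained sketch of the round-robin behavior. It mirrors `_stratified_select` as a standalone function; the toy data, the function name, and the column values are illustrative and not part of this change:

```python
import math

from datasets import Dataset


def stratified_select(dataset: Dataset, n: int) -> Dataset:
    # Standalone mirror of _stratified_select from the diff above.
    if "category" not in dataset.column_names:
        return dataset.select(range(n))
    cat_to_rows: dict[str, list[int]] = {}
    for i, c in enumerate(dataset["category"]):
        cat_to_rows.setdefault(c, []).append(i)
    if len(cat_to_rows) <= 1:
        return dataset.select(range(n))
    per_cat = math.ceil(n / len(cat_to_rows))
    cat_samples = [rows[:per_cat] for rows in cat_to_rows.values()]
    # Visit every category once before repeating any of them.
    interleaved = [
        samples[i]
        for i in range(per_cat)
        for samples in cat_samples
        if i < len(samples)
    ]
    return dataset.select(interleaved[:n])


ds = Dataset.from_dict(
    {
        "prompt": [f"p{i}" for i in range(9)],
        "category": ["math"] * 3 + ["code"] * 3 + ["chat"] * 3,
    }
)
# n=4 -> per_cat = ceil(4 / 3) = 2; the interleave takes one row per
# category before repeating, so every prefix of the result is balanced.
print(stratified_select(ds, 4)["category"])  # ['math', 'code', 'chat', 'math']
```

The interleaving is the point of the design: because the output is ordered round-robin rather than category-by-category, a later `select(range(k))` prefix slice of the result remains roughly balanced for any `k`.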
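
The `data_parallel_size` additions in `run.py` and `vllm.py` are a straight pass-through from the benchmark configuration into the engine arguments, defaulting to 1 alongside the existing `tensor_parallel_size` knob. A hedged sketch of what the vLLM side ultimately receives (whether the field exists depends on the installed vLLM version, and the model name is a placeholder):

```python
from vllm.engine.arg_utils import AsyncEngineArgs

# data_parallel_size is forwarded verbatim from kwargs; with DP > 1,
# vLLM runs multiple engine replicas that split incoming requests.
engine_args = AsyncEngineArgs(
    model="my-org/my-model",  # placeholder
    tensor_parallel_size=1,
    data_parallel_size=2,
)
```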