1 change: 1 addition & 0 deletions examples/specdec_bench/run.py
@@ -155,6 +155,7 @@ def run_simple(args):
draft_model_dir=args.draft_model_dir,
speculative_num_steps=args.draft_length,
tensor_parallel_size=args.tp_size,
data_parallel_size=args.data_parallel_size,

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Missing CLI arg causes runtime crash on Line 158.

args.data_parallel_size is referenced here, but --data_parallel_size is never defined in this parser, so running this code path from the CLI raises AttributeError.

💡 Proposed fix
@@
-        data_parallel_size=args.data_parallel_size,
+        data_parallel_size=getattr(args, "data_parallel_size", 1),
@@
     parser.add_argument(
         "--tp_size", type=int, required=False, default=4, help="Tensor parallel size"
     )
+    parser.add_argument(
+        "--data_parallel_size",
+        type=int,
+        required=False,
+        default=1,
+        help="Data parallel size",
+    )
     parser.add_argument(
         "--ep_size", type=int, required=False, default=2, help="Expert parallel size"
     )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/specdec_bench/run.py` at line 158, the code references
args.data_parallel_size when constructing the launcher
(data_parallel_size=args.data_parallel_size), but the CLI parser never defines
that flag. Add a --data_parallel_size argument to the argparse parser (or the
function that builds it) so args has the attribute at runtime, giving it
type=int, a sensible default (e.g., 1), and a short help string; alternatively,
have the call use getattr(args, "data_parallel_size", 1) if you prefer a
fallback over adding a parser flag.
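
As a standalone illustration (separate from the proposed fix; the snippet is hypothetical and only mirrors the flag names used in run.py), a minimal sketch of why the unregistered flag raises AttributeError and how the getattr fallback behaves:

```python
import argparse

# Sketch: a parser that defines --tp_size but not --data_parallel_size,
# mirroring the situation described above.
parser = argparse.ArgumentParser()
parser.add_argument("--tp_size", type=int, default=4, help="Tensor parallel size")
args = parser.parse_args([])

# Direct attribute access fails because the flag was never registered.
try:
    _ = args.data_parallel_size
except AttributeError as exc:
    print(f"crashes: {exc}")

# getattr with a default degrades gracefully instead of crashing.
print(getattr(args, "data_parallel_size", 1))  # -> 1
```

Either route works; registering the flag is preferable when users should be able to set the value from the CLI, while the getattr fallback only papers over the missing attribute.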

moe_expert_parallel_size=args.ep_size,
trust_remote_code=args.trust_remote_code,
tokenizer_path=args.tokenizer,
39 changes: 37 additions & 2 deletions examples/specdec_bench/specdec_bench/datasets/speed.py
@@ -14,6 +14,7 @@
# limitations under the License.

# mypy: disable-error-code="index"
import math
import random
import re
from enum import Enum
@@ -737,10 +738,44 @@ def _load_dataset(self, config_name_or_dataset_path: config_type | str) -> "Data
}
table = table.replace_schema_metadata(new_meta or None)
dataset = HFDataset(table)
if self.num_samples is not None:
dataset = dataset.select(range(self.num_samples))
if self.num_samples is not None and self.num_samples < len(dataset):
dataset = self._stratified_select(dataset, self.num_samples)
return dataset

@staticmethod
def _stratified_select(dataset: "Dataset", n: int) -> "Dataset":
"""Select ``n`` samples uniformly across the ``category`` column.

When ``category`` is present, each category contributes
``ceil(n / num_categories)`` rows (capped by category size); the
result is truncated to exactly ``n`` rows by interleaving the
per-category samples round-robin so any further prefix slice
remains balanced. Falls back to ``range(n)`` when ``category`` is
absent. Indices come from ``range(category_size)`` (not random)
so behavior is deterministic.
"""
if "category" not in dataset.column_names:
return dataset.select(range(n))
cat_to_rows: dict[str, list[int]] = {}
for i, c in enumerate(dataset["category"]):
cat_to_rows.setdefault(c, []).append(i)
num_cats = len(cat_to_rows)
if num_cats <= 1:
return dataset.select(range(n))
per_cat = math.ceil(n / num_cats)
# Take the first ``per_cat`` rows from each category (parquet order
# within a category is treated as the canonical sample order).
cat_samples = [
rows[: min(per_cat, len(rows))] for rows in cat_to_rows.values()
]
# Round-robin interleave so the first N rows are balanced.
interleaved: list[int] = []
for i in range(per_cat):
for samples in cat_samples:
if i < len(samples):
interleaved.append(samples[i])
return dataset.select(interleaved[:n])
Comment on lines +765 to +777

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

_stratified_select can return fewer than n rows in imbalanced datasets.

When multiple categories are small, capping each at per_cat can make len(interleaved) < n, so dataset.select(interleaved[:n]) under-selects instead of honoring the requested sample count.

💡 Proposed fix
@@
-        cat_samples = [
-            rows[: min(per_cat, len(rows))] for rows in cat_to_rows.values()
-        ]
+        cat_rows = list(cat_to_rows.values())
+        cat_samples = [rows[: min(per_cat, len(rows))] for rows in cat_rows]
@@
         interleaved: list[int] = []
         for i in range(per_cat):
             for samples in cat_samples:
                 if i < len(samples):
                     interleaved.append(samples[i])
-        return dataset.select(interleaved[:n])
+        if len(interleaved) < n:
+            leftovers = [rows[min(per_cat, len(rows)) :] for rows in cat_rows]
+            max_leftover = max((len(rows) for rows in leftovers), default=0)
+            for i in range(max_leftover):
+                for rows in leftovers:
+                    if i < len(rows):
+                        interleaved.append(rows[i])
+                        if len(interleaved) == n:
+                            return dataset.select(interleaved)
+        return dataset.select(interleaved[:n])
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/specdec_bench/specdec_bench/datasets/speed.py` around lines
765-777, _stratified_select can under-select when interleaved ends up with
fewer than n items. After the round-robin build, check whether
len(interleaved) < n and, if so, fill the remainder by appending additional row
indices until n is reached (e.g., iterate the leftover rows per category in
cat_to_rows, or dataset indices not yet picked, or sample from the union of
leftover rows) before calling dataset.select(interleaved[:n]). Make sure the
fill step avoids duplicates and stops exactly at n.
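
To make the failure mode concrete, a self-contained sketch (toy indices only, no dependence on the datasets library) that replays the current round-robin logic with an imbalanced category split:

```python
import math

# Toy category -> row-index mapping: two tiny categories and one large one.
cat_to_rows = {"a": [0], "b": [1], "c": list(range(2, 12))}
n = 6
per_cat = math.ceil(n / len(cat_to_rows))  # ceil(6 / 3) = 2

# Same capping and interleaving as the code under review.
cat_samples = [rows[: min(per_cat, len(rows))] for rows in cat_to_rows.values()]
interleaved = []
for i in range(per_cat):
    for samples in cat_samples:
        if i < len(samples):
            interleaved.append(samples[i])

# Only 4 indices survive (1 + 1 + 2), so interleaved[:n] selects 4 rows
# instead of the requested 6.
print(len(interleaved), interleaved)  # -> 4 [0, 1, 2, 3]
```

With the leftover pass from the proposed fix, the remaining rows of the large category top the selection back up, so the same toy input yields the full 6 indices.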


def _resolve_external_data(
self, dataset: "Dataset", speed_config: config_type | str
) -> "Dataset":
1 change: 1 addition & 0 deletions examples/specdec_bench/specdec_bench/models/vllm.py
@@ -75,6 +75,7 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
tokenizer=kwargs.get("tokenizer_path"),
trust_remote_code=kwargs.get("trust_remote_code", False),
tensor_parallel_size=kwargs.get("tensor_parallel_size", 1),
data_parallel_size=kwargs.get("data_parallel_size", 1),
enable_expert_parallel=kwargs.get("moe_expert_parallel_size", 1) > 1,
enable_prefix_caching=kwargs.get("prefix_cache", False),
speculative_config=specdec,