-
Notifications
You must be signed in to change notification settings - Fork 379
specdec_bench: stratify --num_requests #1388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
| # limitations under the License. | ||
|
|
||
| # mypy: disable-error-code="index" | ||
| import math | ||
| import random | ||
| import re | ||
| from enum import Enum | ||
|
|
@@ -737,10 +738,44 @@ def _load_dataset(self, config_name_or_dataset_path: config_type | str) -> "Data | |
| } | ||
| table = table.replace_schema_metadata(new_meta or None) | ||
| dataset = HFDataset(table) | ||
| if self.num_samples is not None: | ||
| dataset = dataset.select(range(self.num_samples)) | ||
| if self.num_samples is not None and self.num_samples < len(dataset): | ||
| dataset = self._stratified_select(dataset, self.num_samples) | ||
| return dataset | ||
|
|
||
| @staticmethod | ||
| def _stratified_select(dataset: "Dataset", n: int) -> "Dataset": | ||
| """Select ``n`` samples uniformly across the ``category`` column. | ||
|
|
||
| When ``category`` is present, each category contributes | ||
| ``ceil(n / num_categories)`` rows (capped by category size); the | ||
| result is truncated to exactly ``n`` rows by interleaving the | ||
| per-category samples round-robin so any further prefix slice | ||
| remains balanced. Falls back to ``range(n)`` when ``category`` is | ||
| absent. Indices come from ``range(category_size)`` (not random) | ||
| so behavior is deterministic. | ||
| """ | ||
| if "category" not in dataset.column_names: | ||
| return dataset.select(range(n)) | ||
| cat_to_rows: dict[str, list[int]] = {} | ||
| for i, c in enumerate(dataset["category"]): | ||
| cat_to_rows.setdefault(c, []).append(i) | ||
| num_cats = len(cat_to_rows) | ||
| if num_cats <= 1: | ||
| return dataset.select(range(n)) | ||
| per_cat = math.ceil(n / num_cats) | ||
| # Take the first ``per_cat`` rows from each category (parquet order | ||
| # within a category is treated as the canonical sample order). | ||
| cat_samples = [ | ||
| rows[: min(per_cat, len(rows))] for rows in cat_to_rows.values() | ||
| ] | ||
| # Round-robin interleave so the first N rows are balanced. | ||
| interleaved: list[int] = [] | ||
| for i in range(per_cat): | ||
| for samples in cat_samples: | ||
| if i < len(samples): | ||
| interleaved.append(samples[i]) | ||
| return dataset.select(interleaved[:n]) | ||
|
Comment on lines
+765
to
+777
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
When multiple categories are small, capping each at `per_cat` rows can leave fewer than `n` total rows selected, even though the dataset has enough rows overall.

💡 Proposed fix

```diff
@@
- cat_samples = [
- rows[: min(per_cat, len(rows))] for rows in cat_to_rows.values()
- ]
+ cat_rows = list(cat_to_rows.values())
+ cat_samples = [rows[: min(per_cat, len(rows))] for rows in cat_rows]
@@
interleaved: list[int] = []
for i in range(per_cat):
for samples in cat_samples:
if i < len(samples):
interleaved.append(samples[i])
- return dataset.select(interleaved[:n])
+ if len(interleaved) < n:
+ leftovers = [rows[min(per_cat, len(rows)) :] for rows in cat_rows]
+ max_leftover = max((len(rows) for rows in leftovers), default=0)
+ for i in range(max_leftover):
+ for rows in leftovers:
+ if i < len(rows):
+ interleaved.append(rows[i])
+ if len(interleaved) == n:
+ return dataset.select(interleaved)
+ return dataset.select(interleaved[:n])🤖 Prompt for AI Agents |
||
|
|
||
| def _resolve_external_data( | ||
| self, dataset: "Dataset", speed_config: config_type | str | ||
| ) -> "Dataset": | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing CLI arg causes runtime crash on Line 158.
`args.data_parallel_size` is used, but `--data_parallel_size` is not defined in this parser, so CLI execution can fail with `AttributeError`.

💡 Proposed fix
📝 Committable suggestion
🤖 Prompt for AI Agents