From 880cbe9ff484ce9494dec4fbbbd5c08c3c857f09 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 25 May 2026 14:59:18 +0000 Subject: [PATCH] Add shuffle seed --- bergson/config.py | 6 +++++- bergson/magic/cli.py | 8 ++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/bergson/config.py b/bergson/config.py index 4f5a0fae..9855129f 100644 --- a/bergson/config.py +++ b/bergson/config.py @@ -296,7 +296,11 @@ class TrainingConfig(AttributionConfig, Serializable): """Number of full passes over the training data.""" seed: int = 42 - """Random seed for dataset shuffling.""" + """Random seed for training and validation randomness.""" + + shuffle_seed: int | None = None + """Random seed for shuffling the training dataset. + """ adam_beta1: float = 0.95 """Beta1 for AdamW optimizer.""" diff --git a/bergson/magic/cli.py b/bergson/magic/cli.py index e14313ac..e3bbaa7d 100644 --- a/bergson/magic/cli.py +++ b/bergson/magic/cli.py @@ -601,8 +601,12 @@ def run_magic(run_cfg: TrainingConfig, *, score_path: str = ""): train_ds, train_n = setup_data_pipeline(run_cfg) train_ds = attach_doc_ids_if_missing(train_ds) - # Shuffle the train_ds with the seed. - train_ds = train_ds.shuffle(seed=run_cfg.seed) + # Shuffle the train_ds with the shuffle seed, defaulting to seed for + # backward compatibility with existing MAGIC configs. + shuffle_seed = ( + run_cfg.seed if run_cfg.shuffle_seed is None else run_cfg.shuffle_seed + ) + train_ds = train_ds.shuffle(seed=shuffle_seed) if isinstance(run_cfg, ValidationConfig): query_ds, query_n = setup_data_pipeline(run_cfg, run_cfg.query)