diff --git a/bergson/config.py b/bergson/config.py index 4f5a0fae..9855129f 100644 --- a/bergson/config.py +++ b/bergson/config.py @@ -296,7 +296,11 @@ class TrainingConfig(AttributionConfig, Serializable): """Number of full passes over the training data.""" seed: int = 42 - """Random seed for dataset shuffling.""" + """Random seed for training and validation randomness.""" + + shuffle_seed: int | None = None + """Random seed for shuffling the training dataset. + """ adam_beta1: float = 0.95 """Beta1 for AdamW optimizer.""" diff --git a/bergson/magic/cli.py b/bergson/magic/cli.py index 37da0878..dae33800 100644 --- a/bergson/magic/cli.py +++ b/bergson/magic/cli.py @@ -606,8 +606,12 @@ def run_magic(run_cfg: TrainingConfig, *, score_path: str = ""): train_ds, train_n = setup_data_pipeline(run_cfg) train_ds = attach_doc_ids_if_missing(train_ds) - # Shuffle the train_ds with the seed. - train_ds = train_ds.shuffle(seed=run_cfg.seed) + # Shuffle the train_ds with the shuffle seed, defaulting to seed for + # backward compatibility with existing MAGIC configs. + shuffle_seed = ( + run_cfg.seed if run_cfg.shuffle_seed is None else run_cfg.shuffle_seed + ) + train_ds = train_ds.shuffle(seed=shuffle_seed) if isinstance(run_cfg, ValidationConfig): query_ds, query_n = setup_data_pipeline(run_cfg, run_cfg.query)