refactor(training): split train_selfplay.py into focused modules #19
Merged
First step of the train_selfplay.py split. Two pure helpers moved verbatim; no behavior change. The new src/jaxpot/training/ package will hold the per-phase modules in subsequent commits. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Submodules iteration and checkpointing are added by Tasks 3 and 2 of the train_selfplay split; documenting them before they exist creates dead references. The docstring grows back as each module lands. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pulls periodic/milestone/emergency save decisions out of train() into src/jaxpot/training/checkpointing.py. Each helper takes the state it needs explicitly — no shared context object — so the orchestrator stays readable as a top-down sequence. No behavior change; same save semantics, same paths, same prune rules. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
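A minimal sketch of the "explicit state, no shared context object" shape this commit describes. The function name, the `save_every` parameter, and the `save_fn` callback are hypothetical stand-ins for illustration; the real helpers in checkpointing.py may take different arguments.

```python
from typing import Callable


def maybe_save_periodic_sketch(
    *,
    iteration: int,
    save_every: int,
    save_fn: Callable[[int], None],
) -> bool:
    """Call save_fn when `iteration` is a multiple of `save_every`.

    Returns True if a save was triggered. All inputs arrive as keyword
    arguments, so the call site shows exactly what the decision depends on.
    """
    if save_every <= 0 or iteration % save_every != 0:
        return False
    save_fn(iteration)
    return True


# Usage: record which iterations would have triggered a save.
saved: list[int] = []
for it in range(1, 7):
    maybe_save_periodic_sketch(iteration=it, save_every=3, save_fn=saved.append)
print(saved)  # [3, 6]
```

Keyword-only arguments keep the orchestrator call self-describing, at the cost of a wider signature than a context object would need.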
Pulls the four rollout-collection blocks (selfplay, random warmup, league, archive league) out of train() into a single collect_iteration_batches function with an IterationBatches / CollectionResult dataclass return. Timer key names (collect_selfplay, collect_random, collect_league, collect_archive_league) preserved exactly so dashboards don't break. Local variables in train() are kept as a temporary adapter; they collapse in the next commit when aggregate_iteration_metrics moves into iteration.py too. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
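One way to picture "one call, structured return, same timer keys" is the sketch below. Only the four timer key strings come from the commit message; the dataclass, its fields, and the `timed` helper are assumptions made for illustration and do not mirror the real CollectionResult.

```python
import time
from contextlib import contextmanager
from dataclasses import dataclass, field

# Timer keys quoted from the commit message; everything else is hypothetical.
TIMER_KEYS = (
    "collect_selfplay",
    "collect_random",
    "collect_league",
    "collect_archive_league",
)


@dataclass
class CollectionSketch:
    batches: dict[str, list[int]] = field(default_factory=dict)
    timings: dict[str, float] = field(default_factory=dict)


@contextmanager
def timed(timings: dict[str, float], key: str):
    """Record the wall-clock duration of the enclosed block under `key`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[key] = time.perf_counter() - start


def collect_sketch() -> CollectionSketch:
    result = CollectionSketch()
    for key in TIMER_KEYS:
        with timed(result.timings, key):
            result.batches[key] = [0]  # placeholder for a real rollout
    return result


result = collect_sketch()
print(sorted(result.timings))
```

Because the keys are plain strings shared between the collector and the dashboards, keeping them in one constant makes an accidental rename easier to spot in review.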
…_batches The previous commit (a7ceeb3) inadvertently restructured the per-iter RNG split from one 5-way split into a 2-way + nested 4-way split. Those produce different subkeys, breaking reproducibility for any run restoring across this commit boundary. Restore the original 5-way split in the orchestrator and have collect_iteration_batches take the four subkeys (k_self, k_rand, k_opp, k_arch) as separate keyword args. Bitwise equivalent to the pre-refactor PRNG behavior. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
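The difference between the two split layouts can be checked directly in JAX. This is a sketch under assumptions: the seed, and which of the five outputs is the carried key, are chosen here for illustration; only the subkey names and the 5-way versus 2-way + 4-way shapes come from the commit message.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)  # arbitrary seed for the demonstration

# Original layout: one 5-way split per iteration.
key_next, k_self, k_rand, k_opp, k_arch = jax.random.split(key, 5)

# Accidental restructure: a 2-way split followed by a nested 4-way split.
key_next_alt, sub = jax.random.split(key, 2)
k_self_alt, k_rand_alt, k_opp_alt, k_arch_alt = jax.random.split(sub, 4)

# The same logical subkey ends up with different key data in the two layouts.
same_rollout_key = bool(jnp.array_equal(k_self, k_self_alt))

# Splitting is deterministic, so repeating the original layout reproduces it.
repeatable = bool(jnp.array_equal(jax.random.split(key, 5), jax.random.split(key, 5)))

print(same_rollout_key, repeatable)
```

A run checkpointed under one layout and resumed under the other would therefore draw different random numbers from the first resumed iteration onward, which is why the fix passes the four subkeys in from the orchestrator rather than splitting inside the helper.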
The function only ever used the pre-setup rollout_actor passed alongside it; the agent reference was scaffolded but not consumed by this or any planned downstream caller. Removing it shrinks the already-large keyword surface by one and drops the jaxpot.agents.BaseTrainingAgent import that became unused with it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aggregate_iteration_metrics produces host-side primitives under a single jax.device_get (timed as transfer_metrics, preserving the existing key). build_iteration_log_payload assembles the wandb / tensorboard payload with the same iteration/*, timings/* key naming as before. train_vs_random/reward flows through the structured IterationMetricsHost field instead of an inline dict mutation; orchestrator inserts it into log_payload at a slightly different point (after eval), same wire format. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous commit (65d3887) collapsed the two-stage log_payload build into a single late call to build_iteration_log_payload, but left maybe_save_periodic upstream of it. maybe_save_periodic reads log_payload.get(best_selection.metric) to gate best-checkpoint promotion; with the late build it only saw eval outputs and train_vs_random/reward, so any best_selection.metric pointing at an iteration/* or train-metric key silently failed to fire. Move maybe_save_periodic, maybe_save_milestone, and the inline league-freeze block to AFTER build_iteration_log_payload so they see the complete payload. Checkpoint I/O time is no longer counted inside the "total" timer; that's a tiny observability change, not a correctness change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
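The failure mode is easy to reproduce in isolation. In this sketch, `should_promote_best` is a hypothetical stand-in for the gate inside maybe_save_periodic, and the payload values are made up; only the `log_payload.get(metric)` lookup pattern and the key names come from the commit message.

```python
from typing import Optional


def should_promote_best(
    log_payload: dict[str, float], metric: str, best_so_far: float
) -> bool:
    """Promote only if the metric is present and beats the best value so far."""
    value: Optional[float] = log_payload.get(metric)
    return value is not None and value > best_so_far


metric = "iteration/sps"  # example key; any iteration/* key behaves the same way

# What the gate saw when it ran before the full payload was built.
eval_only_payload = {"train_vs_random/reward": 0.4}
# What it sees once build_iteration_log_payload has run.
full_payload = {**eval_only_payload, "iteration/sps": 1200.0}

# Key missing: .get() returns None and the promotion silently never fires.
early = should_promote_best(eval_only_payload, metric, best_so_far=1000.0)
# Key present: the comparison actually happens.
late = should_promote_best(full_payload, metric, best_so_far=1000.0)
print(early, late)  # False True
```

Since a missing key and a metric that did not improve both yield "no promotion", the bug produced no error, which is why ordering the gate after payload assembly matters.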
train_selfplay.train() now reads as a top-down phase sequence: rollout -> train -> metrics+sync -> eval -> log -> persistence -> league-scheduling -> progress tick. build_training_components() owns all one-shot setup (device_puts, instantiate trainer/agent/evaluators, target network wiring, lr_schedule, progress + json_logger, recurrent-mode validation). train_selfplay.py shrinks from 562 lines to ~155. No behavior change; same RNG splits, same mesh blocks, same log key names. Side effect (checkpoint.model/optimizer mutated in place) preserved and documented in build_training_components. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… split make_env_fns moved to jaxpot.training.setup; update the only caller. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Catches typos, circular imports, and drift between the design spec's TrainingComponents / CollectionResult / IterationBatches / IterationMetricsHost field lists and the actual implementation. No behavior coverage — the underlying functions are already exercised indirectly by the existing tests. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
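A field-list drift check of this kind can be sketched with dataclasses.fields. The dataclass below and its field names are invented for illustration; a real test would import the class from jaxpot.training.iteration and copy the expected names from the design spec.

```python
from dataclasses import dataclass, fields


@dataclass
class IterationBatchesSketch:  # hypothetical stand-in; field names are invented
    selfplay: object = None
    league: object = None


# In a real test this set would be copied from the design spec.
EXPECTED_FIELDS = {"selfplay", "league"}


def field_names(cls) -> set[str]:
    """Return the declared dataclass field names of `cls`."""
    return {f.name for f in fields(cls)}


def test_fields_match_spec() -> None:
    # Fails if a field is added, removed, or renamed without updating the spec.
    assert field_names(IterationBatchesSketch) == EXPECTED_FIELDS


test_fields_match_spec()
```

Importing the module at the top of such a test file is itself the check for typos and circular imports, since either would raise at collection time.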
NorbertRop added a commit that referenced this pull request on May 26, 2026
…23)

* fix(rl): pure params alongside opt_state for optax MultiSteps interop

  train_selfplay's production optimizer wraps the base optax chain in optax.MultiSteps (see checkpoints.CheckpointManager.start). MultiSteps' update path does a tree.map(updates, state.acc_grads). Since PR #6 nnx.pure'd opt_state but left params Param-wrapped, updates (= grads from grad_fn) carried Param nodes that state.acc_grads (now bare) didn't, causing a Custom node type mismatch at trace time.

  The existing test_trainer_step and test_train_integration both use raw optax.sgd (no MultiSteps wrapper) and never hit this path, so the regression sat latent until the Hydra-composed integration test in the follow-up commit surfaced it.

  Pure params too so the jit body sees a uniform bare-array pytree. Write the bare results back into the model's Param containers via nnx.replace_by_pure_dict (same pattern league.restore uses), then the existing nnx.update flow merges with non_params and writes to model.

  Verified by running train_selfplay.py end-to-end on the colab tic_tac_toe profile for 1 iteration; full pipeline succeeds.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* test(train): Hydra-composed integration smoke test for train()

  End-to-end test that composes the real config/train_selfplay.yaml (experiment=tic_tac_toe/colab + tiny overrides), then calls train_selfplay.train(...) for 2 iterations. Asserts the loop completes without exception and writes a well-formed checkpoint (metadata.json, state/, league.json, league_states/).

  Closes the coverage gap called out in #19: the existing test_train_integration.py exercises underlying pieces directly but bypasses Hydra config composition AND the production-shaped optimizer (MultiSteps + adamw + LR schedule). This test catches both surfaces.

  Runs in <30s on CPU. Patches HydraConfig.get() so the test reaches the train() body without needing a real Hydra runtime context.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Splits the 562-line train_selfplay.py god-script into three phase-grouped modules under src/jaxpot/training/ plus a slimmed orchestrator at the repo root. Closes the train-selfplay-split follow-up item from #14.

No algorithmic behavior change. Same RNG splits, same log key names, same checkpoint format, same Hydra CLI invocation.
Modules
- src/jaxpot/training/setup.py: make_env_fns, build_aux_target_hooks, TrainingComponents dataclass (16 fields), build_training_components (one-shot construction: device_puts, trainer/agent/evaluators/target wiring, lr_schedule, progress, json_logger).
- src/jaxpot/training/iteration.py: CollectionResult / IterationBatches / IterationMetricsHost dataclasses, collect_iteration_batches (4 rollout flavors under one call), aggregate_iteration_metrics (single per-iter jax.device_get sync), build_iteration_log_payload (wandb/tensorboard payload assembly).
- src/jaxpot/training/checkpointing.py: save_checkpoint_now, maybe_save_periodic, maybe_save_milestone, save_emergency. These are "when to save" decisions that delegate disk writes to CheckpointManager / BestCheckpointManager.
- train_selfplay.py: @main entry + slim orchestrator loop. The loop reads top-to-bottom as the phase sequence: rollout → train → metrics+sync → eval → log → persistence → league-scheduling → progress tick.

Five small inline blocks (target-net soft update, debug dump, league standings logging, league freeze scheduling, final tick) stay in the orchestrator; each is a 3-8 line decision against it that's clearer at the call site than as a module call.

What did NOT change

- Log key names (iteration/*, timings/*, train_vs_random/reward, etc.)
- RNG splits (jax.random.split(key, 5) per iter; see commit 2393975, which fixed an earlier accidental restructure)
- Timer key names (collect_selfplay, collect_league, etc.)

Observability caveats
Two small shifts in timing semantics, both intentional; neither affects training math:

- timings/total and iteration/sps no longer include maybe_save_periodic, maybe_save_milestone, or league.add_from_model cost. The new orchestrator calls timer.stop("total") before persistence and league-freeze, where the original ran them inside the "total" window. On iterations where any of those fire, iteration/sps will tick up slightly. Worth knowing for anyone running reproducibility comparisons or dashboard alerts on those keys.
- maybe_save_periodic now sees the full log_payload including timings/* and iteration/sps when its best_selection.metric is consulted. The original code built the payload incrementally and ran the save before the timings update, so timing-based best-checkpoint metrics silently never triggered. This is strictly more correct; if anyone configured best_selection.metric="iteration/sps" it now actually works.

Commit history
11 commits, all in type(scope): summary + body form. Three are clearly-labelled fix(training): ... commits that resolved issues caught by per-task review (RNG split preservation, dead parameter, log_payload ordering for the best-checkpoint lookup). No force-push, no merge commits.

Test plan
- uv run pytest tests/ → 215 passed, 2 skipped, 2 xfailed (was 212 + 2 + 2 baseline; +3 new smoke-import tests in tests/test_training_module_imports.py).
- uv run python train_selfplay.py experiment=tic_tac_toe/fast logger=none total_iters=3 should complete with the same metric output as before.

Known limitation
The orchestrator's train() function itself has no automated integration test; tests/test_train_integration.py calls the underlying pieces (collect_selfplay, PPOAgent) directly, bypassing the Hydra entry point. The new smoke test validates module structure and public API shape. Full orchestrator coverage would require a Hydra-config fixture and is out of scope for a mechanical refactor.

🤖 Generated with Claude Code