refactor(training): split train_selfplay.py into focused modules#19

Merged: NorbertRop merged 11 commits into main from refactor/train-selfplay-split on May 26, 2026
@NorbertRop

Summary

Splits the 562-line train_selfplay.py god-script into three phase-grouped modules under src/jaxpot/training/ plus a slimmed orchestrator at the repo root. Closes the train-selfplay-split follow-up item from #14.

No algorithmic behavior change. Same RNG splits, same log key names, same checkpoint format, same Hydra CLI invocation.

Modules

| File | Owns |
| --- | --- |
| src/jaxpot/training/setup.py | make_env_fns, build_aux_target_hooks, TrainingComponents dataclass (16 fields), build_training_components (one-shot construction: device_puts, trainer/agent/evaluators/target wiring, lr_schedule, progress, json_logger). |
| src/jaxpot/training/iteration.py | CollectionResult/IterationBatches/IterationMetricsHost dataclasses, collect_iteration_batches (4 rollout flavors under one call), aggregate_iteration_metrics (single per-iter jax.device_get sync), build_iteration_log_payload (wandb/tensorboard payload assembly). |
| src/jaxpot/training/checkpointing.py | save_checkpoint_now, maybe_save_periodic, maybe_save_milestone, save_emergency: the "when to save" decisions, which delegate disk writes to CheckpointManager/BestCheckpointManager. |
| train_selfplay.py | Hydra @main entry + slim orchestrator loop. The loop reads top-to-bottom as the phase sequence: rollout → train → metrics+sync → eval → log → persistence → league-scheduling → progress tick. |

Five small inline blocks (target-net soft update, debug dump, league standings logging, league freeze scheduling, final tick) stay in the orchestrator; each is a 3-8 line decision that is clearer at the call site than as a module call.

What did NOT change

  • wandb/tensorboard log key naming (iteration/*, timings/*, train_vs_random/reward, etc.)
  • Checkpoint format, paths, atomicity guarantees
  • RNG splitting strategy (single jax.random.split(key, 5) per iter — see commit 2393975 which fixed an earlier accidental restructure)
  • Collector signatures (collect_selfplay, collect_league, etc.)
  • Hydra config schema or CLI invocation
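
To illustrate why the shape of the RNG split matters, the sketch below compares one 5-way split against a 2-way split followed by a nested 4-way split. The seed, the variable names, and the assumption that the first key of the 5-way split is the carry key are illustrative only; none of this is copied from train_selfplay.py.

```python
import jax
import numpy as np

key = jax.random.PRNGKey(0)  # arbitrary seed, chosen for illustration

# Single 5-way split: one carry key plus four rollout subkeys.
# (Which position holds the carry key is an assumption.)
direct = jax.random.split(key, 5)
direct_rollout_keys = np.asarray(direct[1:])

# Restructured variant: a 2-way split, then a nested 4-way split.
carry, sub = jax.random.split(key, 2)
nested_rollout_keys = np.asarray(jax.random.split(sub, 4))

# The shapes agree, but the key material does not, so rollouts seeded
# from the two variants would diverge.
same_shape = direct_rollout_keys.shape == nested_rollout_keys.shape
same_values = bool(np.array_equal(direct_rollout_keys, nested_rollout_keys))
print(same_shape, same_values)
```

Because the four rollout subkeys come from a different parent key in the nested variant, a run restored across such a change would not reproduce the earlier random stream.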

Observability caveats

Two small shifts in timing semantics — both intentional, neither affects training math:

  1. timings/total and iteration/sps no longer include maybe_save_periodic, maybe_save_milestone, or league.add_from_model cost. The new orchestrator calls timer.stop("total") before persistence and league-freeze, where the original ran them inside the "total" window. On iterations where any of those fire, iteration/sps will tick up slightly. Worth knowing for anyone running reproducibility comparisons or dashboard alerts on those keys.
  2. maybe_save_periodic now sees the full log_payload including timings/* and iteration/sps when its best_selection.metric is consulted. The original code built the payload incrementally and ran the save before the timings update, so timing-based best-checkpoint metrics silently never triggered. This is strictly more correct; if anyone configured best_selection.metric="iteration/sps" it now actually works.
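
The first caveat can be made concrete with a small simulation. The Timer and FakeClock classes below are hypothetical stand-ins rather than the jaxpot timer, the durations are invented, and the sketch assumes iteration/sps is computed as steps divided by the "total" time.

```python
class FakeClock:
    """Deterministic clock so the example does not depend on wall time."""

    def __init__(self):
        self.now = 0.0

    def __call__(self):
        return self.now

    def advance(self, seconds):
        self.now += seconds


class Timer:
    """Hypothetical named start/stop timer; not the jaxpot implementation."""

    def __init__(self, clock):
        self._clock = clock
        self._started = {}
        self.elapsed = {}

    def start(self, name):
        self._started[name] = self._clock()

    def stop(self, name):
        self.elapsed[name] = self._clock() - self._started.pop(name)


def run_iteration(save_inside_total):
    clock = FakeClock()
    timer = Timer(clock)
    timer.start("total")
    clock.advance(1.0)       # rollout + train + metrics (invented duration)
    if save_inside_total:
        clock.advance(0.25)  # checkpoint write counted inside "total"
    timer.stop("total")
    if not save_inside_total:
        clock.advance(0.25)  # checkpoint write after the window closes
    return timer.elapsed["total"]


steps_per_iter = 1000  # invented value
old_total = run_iteration(save_inside_total=True)
new_total = run_iteration(save_inside_total=False)
old_sps = steps_per_iter / old_total
new_sps = steps_per_iter / new_total
print(old_total, new_total, old_sps, new_sps)
```

The same amount of work is done in both runs; only the position of the stop call changes which part of it is attributed to "total".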

Commit history

11 commits, all type(scope): summary + body. Three are clearly-labelled fix(training): ... commits that resolved issues caught by per-task review (RNG split preservation, dead parameter, log_payload ordering for the best-checkpoint lookup). No force-push, no merge commits.

Test plan

  • uv run pytest tests/ → 215 passed, 2 skipped, 2 xfailed (baseline was 212 + 2 + 2; +3 new smoke-import tests in tests/test_training_module_imports.py).
  • New smoke test imports every public name from each new module and asserts the dataclass field sets match the design spec — catches typos, circular imports, and field drift.
  • Manual CLI smoke (post-merge, when convenient): uv run python train_selfplay.py experiment=tic_tac_toe/fast logger=none total_iters=3 should complete with the same metric output as before.
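
A field-set check of the kind described in the second bullet could be written along these lines. The dataclass and its field names are hypothetical placeholders, since the real field lists live in the design spec and are not reproduced in this PR.

```python
from dataclasses import dataclass, fields


@dataclass
class IterationBatchesExample:
    """Hypothetical stand-in; the real field names are not shown in this PR."""

    selfplay: object = None
    league: object = None


# In a real test this set would be copied from the design spec.
EXPECTED_FIELDS = {"selfplay", "league"}


def field_drift(cls, expected):
    """Return (missing, extra) field names relative to the expected set."""
    actual = {f.name for f in fields(cls)}
    return expected - actual, actual - expected


missing, extra = field_drift(IterationBatchesExample, EXPECTED_FIELDS)
assert not missing and not extra, f"missing={missing}, extra={extra}"

# A stale spec shows up as a non-empty difference.
stale_missing, stale_extra = field_drift(IterationBatchesExample, {"selfplay"})
print(stale_missing, stale_extra)
```

Comparing sets rather than ordered lists keeps the check insensitive to field reordering while still flagging added, removed, or renamed fields.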

Known limitation

The orchestrator's train() function itself has no automated integration test — tests/test_train_integration.py calls the underlying pieces (collect_selfplay, PPOAgent) directly, bypassing the Hydra entry point. The new smoke test validates module structure and public API shape. Full orchestrator coverage would require a Hydra-config fixture and is out of scope for a mechanical refactor.

🤖 Generated with Claude Code

NorbertRop and others added 11 commits May 25, 2026 23:22
First step of the train_selfplay.py split. Two pure helpers moved
verbatim; no behavior change. The new src/jaxpot/training/ package
will hold the per-phase modules in subsequent commits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Submodules iteration and checkpointing are added by Tasks 3 and 2 of
the train_selfplay split; documenting them before they exist creates
dead references. The docstring grows back as each module lands.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pulls periodic/milestone/emergency save decisions out of train() into
src/jaxpot/training/checkpointing.py. Each helper takes the state it
needs explicitly — no shared context object — so the orchestrator
stays readable as a top-down sequence.

No behavior change; same save semantics, same paths, same prune rules.
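
As a rough illustration of helpers that take their state explicitly and leave the disk write to a manager, consider the sketch below. Everything in it is hypothetical: RecordingManager stands in for CheckpointManager, and the function name, parameters, and modulo rule are assumptions, not the signatures in checkpointing.py.

```python
class RecordingManager:
    """Hypothetical stand-in for CheckpointManager; it only records calls."""

    def __init__(self):
        self.saved = []

    def save(self, iteration):
        self.saved.append(iteration)


def maybe_save_periodic_sketch(manager, iteration, save_every):
    """Decide whether to save; the manager performs the actual write."""
    if save_every > 0 and iteration % save_every == 0:
        manager.save(iteration)
        return True
    return False


manager = RecordingManager()
for it in range(1, 7):
    maybe_save_periodic_sketch(manager, it, save_every=3)
print(manager.saved)
```

Passing the manager, the iteration, and the interval as plain arguments keeps the call site self-describing, which is the property the commit message attributes to avoiding a shared context object.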

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pulls the four rollout-collection blocks (selfplay, random warmup,
league, archive league) out of train() into a single
collect_iteration_batches function with an IterationBatches /
CollectionResult dataclass return.

Timer key names (collect_selfplay, collect_random, collect_league,
collect_archive_league) preserved exactly so dashboards don't break.
Local variables in train() are kept as a temporary adapter; they
collapse in the next commit when aggregate_iteration_metrics moves
into iteration.py too.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_batches

The previous commit (a7ceeb3) inadvertently restructured the per-iter
RNG split from one 5-way split into a 2-way + nested 4-way split.
Those produce different subkeys, breaking reproducibility for any run
restoring across this commit boundary.

Restore the original 5-way split in the orchestrator and have
collect_iteration_batches take the four subkeys (k_self, k_rand,
k_opp, k_arch) as separate keyword args. Bitwise equivalent to the
pre-refactor PRNG behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The function only ever used the pre-setup rollout_actor passed
alongside it; the agent reference was scaffolded but not consumed by
this or any planned downstream caller. Removing it shrinks the
already-large keyword surface by one and drops the
jaxpot.agents.BaseTrainingAgent import that became unused with it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aggregate_iteration_metrics produces host-side primitives under a
single jax.device_get (timed as transfer_metrics, preserving the
existing key). build_iteration_log_payload assembles the wandb /
tensorboard payload with the same iteration/*, timings/* key naming
as before.

train_vs_random/reward flows through the structured IterationMetricsHost
field instead of an inline dict mutation; orchestrator inserts it
into log_payload at a slightly different point (after eval), same wire
format.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous commit (65d3887) collapsed the two-stage log_payload
build into a single late call to build_iteration_log_payload, but
left maybe_save_periodic upstream of it. maybe_save_periodic reads
log_payload.get(best_selection.metric) to gate best-checkpoint
promotion; with the late build it only saw eval outputs and
train_vs_random/reward, so any best_selection.metric pointing at
an iteration/* or train-metric key silently failed to fire.

Move maybe_save_periodic, maybe_save_milestone, and the inline
league-freeze block to AFTER build_iteration_log_payload so they
see the complete payload. checkpoint I/O time is no longer counted
inside the "total" timer; that's a tiny observability change, not
a correctness change.
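
The ordering problem can be reproduced in miniature. The function below is a hypothetical reduction of the best-checkpoint gate, assuming a "higher is better" metric; only the log_payload.get(...) lookup is taken from the description above.

```python
def should_promote(log_payload, metric_key, best_so_far):
    """Return True when the payload's metric beats the best value so far."""
    value = log_payload.get(metric_key)
    if value is None:
        return False  # missing key: promotion silently never fires
    return best_so_far is None or value > best_so_far


metric = "iteration/sps"

# Payload as seen when the save runs before the full build (invented values).
early_payload = {"train_vs_random/reward": 0.4}
# Payload as seen when the save runs after build_iteration_log_payload.
full_payload = {**early_payload, "iteration/sps": 1200.0}

promoted_early = should_promote(early_payload, metric, best_so_far=None)
promoted_full = should_promote(full_payload, metric, best_so_far=None)
print(promoted_early, promoted_full)
```

Nothing raises in the early case, which is why the failure was silent: the lookup simply returns None and the gate stays closed.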

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
train_selfplay.train() now reads as a top-down phase sequence:
rollout -> train -> metrics+sync -> eval -> log -> persistence ->
league-scheduling -> progress tick. build_training_components() owns
all one-shot setup (device_puts, instantiate trainer/agent/evaluators,
target network wiring, lr_schedule, progress + json_logger, recurrent-
mode validation).

train_selfplay.py shrinks from 562 lines to ~155. No behavior change;
same RNG splits, same mesh blocks, same log key names. Side effect
(checkpoint.model/optimizer mutated in place) preserved and documented
in build_training_components.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… split

make_env_fns moved to jaxpot.training.setup; update the only caller.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Catches typos, circular imports, and drift between the design spec's
TrainingComponents / CollectionResult / IterationBatches /
IterationMetricsHost field lists and the actual implementation.

No behavior coverage — the underlying functions are already exercised
indirectly by the existing tests.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@NorbertRop NorbertRop merged commit b06e6c3 into main May 26, 2026
1 check passed
@NorbertRop NorbertRop deleted the refactor/train-selfplay-split branch May 26, 2026 12:03
NorbertRop added a commit that referenced this pull request May 26, 2026
…23)

* fix(rl): pure params alongside opt_state for optax MultiSteps interop

train_selfplay's production optimizer wraps the base optax chain in
optax.MultiSteps (see checkpoints.CheckpointManager.start). MultiSteps'
update path does a tree.map(updates, state.acc_grads). Since PR #6
nnx.pure'd opt_state but left params Param-wrapped, updates (= grads
from grad_fn) carried Param nodes that state.acc_grads (now bare)
didn't — causing a Custom node type mismatch at trace time. The
existing test_trainer_step and test_train_integration both use raw
optax.sgd (no MultiSteps wrapper) and never hit this path, so the
regression sat latent until the Hydra-composed integration test
in the follow-up commit surfaced it.

Pure params too so the jit body sees a uniform bare-array pytree.
Write the bare results back into the model's Param containers via
nnx.replace_by_pure_dict (same pattern league.restore uses), then the
existing nnx.update flow merges with non_params and writes to model.

Verified by running train_selfplay.py end-to-end on the colab
tic_tac_toe profile for 1 iteration — full pipeline succeeds.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* test(train): Hydra-composed integration smoke test for train()

End-to-end test that composes the real config/train_selfplay.yaml
(experiment=tic_tac_toe/colab + tiny overrides), then calls
train_selfplay.train(...) for 2 iterations. Asserts the loop completes
without exception and writes a well-formed checkpoint (metadata.json,
state/, league.json, league_states/).

Closes the coverage gap called out in #19: the existing
test_train_integration.py exercises underlying pieces directly but
bypasses Hydra config composition AND the production-shaped optimizer
(MultiSteps + adamw + LR schedule). This test catches both surfaces.

Runs in <30s on CPU. Patches HydraConfig.get() so the test reaches the
train() body without needing a real Hydra runtime context.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>