Skip to content

fix(rl): unbreak MultiSteps + add Hydra integration test for train()#23

Merged
NorbertRop merged 2 commits into
mainfrom
tests/train-hydra-integration
May 26, 2026
Merged

fix(rl): unbreak MultiSteps + add Hydra integration test for train()#23
NorbertRop merged 2 commits into
mainfrom
tests/train-hydra-integration

Conversation

@NorbertRop
Copy link
Copy Markdown
Contributor

Summary

Discovers and fixes a latent production regression while closing the orchestrator-coverage gap called out in #19.

The bug

CheckpointManager.start() always wraps the base optimizer in `optax.MultiSteps`. PR #6 nnx.pure'd `opt_state` for optax interop but left `params` Param-wrapped. MultiSteps' update path does `jax.tree.map(updates, state.acc_grads)` — and since `updates = grad_fn(params)` inherits params' Param wrappers while `state.acc_grads` was stripped by `nnx.pure`, the tree shapes don't align. The result: a `ValueError: Custom node type mismatch` at trace time on every `train_selfplay.py` run.

This sat latent because:

  • `test_trainer_step.py` uses raw `optax.sgd` (no MultiSteps wrapper).
  • `test_train_integration.py` uses raw `optax.sgd` too.
  • No test exercised the production optimizer construction (MultiSteps + adamw + LR schedule).

Confirmed independently by running `train_selfplay.py experiment=tic_tac_toe/colab` against `main` — same crash.

The fix

`nnx.pure` params alongside opt_state so the jit body sees a uniform bare-array pytree across params, grads, and opt_state's accumulators. Write the bare results back into model's Param containers via `nnx.replace_by_pure_dict` (same pattern `league.restore` uses).

`src/jaxpot/rl/trainer.py`, ~15 lines net change.

The test

`tests/test_train_hydra_integration.py` composes the real `config/train_selfplay.yaml` with `experiment=tic_tac_toe/colab` + tiny-scale overrides (16 envs, 2 iters), then calls `train_selfplay.train(...)` end-to-end. Asserts the loop completes and writes a well-formed checkpoint.

Runs in ~10s on CPU. Patches `HydraConfig.get()` to provide an output dir without a real Hydra runtime context.

Test plan

  • `uv run pytest tests/` — 227 passed, 2 skipped (was 226 + 1 new).
  • `tests/test_train_hydra_integration.py` passes.
  • All existing trainer tests (`test_trainer_step.py`, `test_train_integration.py`, `test_checkpoint_roundtrip.py`) still pass — proves the fix doesn't regress the simpler optimizer path.
  • Manual: `uv run python train_selfplay.py experiment=tic_tac_toe/colab logger=none total_iters=1 ...` completes end-to-end.

Commits

  1. `fix(rl): pure params alongside opt_state for optax MultiSteps interop`
  2. `test(train): Hydra-composed integration smoke test for train()`

The fix lands first so the test's regression-gate value is real: without commit 1, commit 2's test would have failed. With both, the test is now a permanent gate against future drift between params/opt_state structures.

🤖 Generated with Claude Code

NorbertRop and others added 2 commits May 26, 2026 12:22
train_selfplay's production optimizer wraps the base optax chain in
optax.MultiSteps (see checkpoints.CheckpointManager.start). MultiSteps'
update path does a tree.map(updates, state.acc_grads). Since PR #6
nnx.pure'd opt_state but left params Param-wrapped, updates (= grads
from grad_fn) carried Param nodes that state.acc_grads (now bare)
didn't — causing a Custom node type mismatch at trace time. The
existing test_trainer_step and test_train_integration both use raw
optax.sgd (no MultiSteps wrapper) and never hit this path, so the
regression sat latent until the Hydra-composed integration test
in the follow-up commit surfaced it.

Pure params too so the jit body sees a uniform bare-array pytree.
Write the bare results back into the model's Param containers via
nnx.replace_by_pure_dict (same pattern league.restore uses), then the
existing nnx.update flow merges with non_params and writes to model.

Verified by running train_selfplay.py end-to-end on the colab
tic_tac_toe profile for 1 iteration — full pipeline succeeds.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
End-to-end test that composes the real config/train_selfplay.yaml
(experiment=tic_tac_toe/colab + tiny overrides), then calls
train_selfplay.train(...) for 2 iterations. Asserts the loop completes
without exception and writes a well-formed checkpoint (metadata.json,
state/, league.json, league_states/).

Closes the coverage gap called out in #19: the existing
test_train_integration.py exercises underlying pieces directly but
bypasses Hydra config composition AND the production-shaped optimizer
(MultiSteps + adamw + LR schedule). This test catches both surfaces.

Runs in <30s on CPU. Patches HydraConfig.get() so the test reaches the
train() body without needing a real Hydra runtime context.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@NorbertRop NorbertRop merged commit 796a0d5 into main May 26, 2026
1 of 4 checks passed
@NorbertRop NorbertRop deleted the tests/train-hydra-integration branch May 26, 2026 13:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant