Skip to content

chore: silence flax/optax deprecation warnings in trainer + tests#26

Merged
NorbertRop merged 1 commit into
mainfrom
chore/deprecations-cleanup
May 26, 2026
Merged

chore: silence flax/optax deprecation warnings in trainer + tests#26
NorbertRop merged 1 commit into
mainfrom
chore/deprecations-cleanup

Conversation

@NorbertRop
Copy link
Copy Markdown
Contributor

Summary

Bundled cleanup of the deprecation warnings CI was flagging on every run. Three migrations, all mechanical, no behavior change.

What Where
`optax.global_norm` → `optax.tree.norm` `src/jaxpot/rl/trainer.py` (2 sites)
`nnx.State.merge` → `nnx.merge_state` `src/jaxpot/rl/trainer.py`, `tests/test_checkpoint_roundtrip.py`
`variable.value` → `variable[...]` (Array) / `.get_value()` / `.set_value()` (struct) `src/jaxpot/rl/{trainer,ppo_trainer,alphazero_trainer}.py` (~9 sites)

The `.value` migration distinguishes two cases per flax's deprecation notice:

  • `self.iterations` and `self.training_steps` wrap `jnp.zeros((), ...)` (Variable[Array]) → use `variable[...]` getter/setter.
  • `self._metrics_accumulator` wraps a `PyTreeNode` struct (Variable[non-Array]) → use `.get_value()` / `.set_value()`.

Warning count

  • Before: 100 warnings across the suite.
  • After: 9 warnings, all upstream and out of scope for this PR:
    • `distrax/_src/utils/math.py` — `jax.core.get_aval` deprecation. Fixed when distrax catches up to a newer JAX.
    • `pgx/core.py` — `player_id` argument to `observe` deprecated. Replacement isn't obvious and the test sites genuinely need per-player observation; deferred for a focused PR if/when pgx provides a migration path.

Test plan

  • `uv run pytest tests/` — 233 passed, 2 skipped, 9 warnings.

🤖 Generated with Claude Code

Three migrations in one PR, all mechanical:

- optax.global_norm -> optax.tree.norm (2 sites in trainer.py).
- nnx.State.merge -> nnx.merge_state (trainer.py + test_checkpoint_roundtrip).
- nnx.Variable.value -> variable[...] (for Variable[Array] — iterations,
  training_steps) or variable.get_value()/set_value() (for non-Array
  Variable types — the _metrics_accumulator that wraps a PyTreeNode).
  ~9 call sites across trainer.py, ppo_trainer.py, alphazero_trainer.py.

Warning count drops from 100 to 9. The remaining 9 are upstream
(distrax's jax.core.get_aval and pgx's player_id-in-observe), not ours
to fix in this PR.

No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@NorbertRop NorbertRop merged commit 7b19502 into main May 26, 2026
1 check passed
@NorbertRop NorbertRop deleted the chore/deprecations-cleanup branch May 26, 2026 13:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant