
Add KL annealing to BNN VI training #137

Open
Moritesh wants to merge 3 commits into develop from feature/kl_annealing

Conversation

Collaborator

@Moritesh Moritesh commented Apr 26, 2026

Add KL annealing to BNN VI training

Note: This PR is based on the VI training enhancements from PR #132 (num_particles, gradient clipping, kl_tau, lr scheduler) and should be rebased & merged only after #132 lands on develop.

Summary

  • Introduces kl_annealing_fraction as a new VI update_kwargs parameter for BayesianNeuralNetwork. When set, the KL divergence term in the ELBO is linearly warmed up from 0 → 1 over the first kl_annealing_fraction * total_steps training steps, then held at 1.0 for the remainder (see the sketch after this list).
  • The annealing factor is threaded as a traced JAX scalar through the model and guide callables, so the warm-up schedule runs entirely inside lax.scan with no Python-loop overhead or XLA recompilation per step.
  • Composes multiplicatively with the existing kl_tau temperature scale when both are active: effective_scale = kl_base_factor * kl_annealing_factor.
  • kl_annealing_fraction=None and =0.0 are both treated as inactive (no-op), validated at construction time.
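
A minimal numpy sketch of the described warm-up and its composition with kl_tau (variable names are illustrative, and the exact value at the very first step is discussed in a review thread further down):

import numpy as np

total_steps = 200
kl_annealing_fraction = 0.1
kl_warmup_steps = max(1, int(np.ceil(kl_annealing_fraction * total_steps)))

steps = np.arange(total_steps)
kl_annealing_factor = np.minimum(1.0, steps / kl_warmup_steps)  # ramps 0 -> 1, then stays at 1.0

kl_base_factor = 0.5  # stand-in for the kl_tau-derived scale
effective_scale = kl_base_factor * kl_annealing_factor  # multiplicative composition when both are active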

Test plan

  • Validation rejects values outside [0, 1]
  • None and 0.0 produce equivalent training behavior (both inactive paths)
  • Hypothesis-based test covers the full cross-product with existing VI options: num_particles, gradient_clip_norm, kl_tau, lr_scheduler

Summary by CodeRabbit

  • New Features

    • Optax optimizer support with optional LR scheduling, gradient clipping, and configurable num_particles.
    • Per-step KL annealing for variational inference with kl_tau and kl_annealing_fraction controls.
  • Chores

    • Version bumped to 6.1.0.
    • Added Optax as a runtime dependency.
  • Tests

    • Expanded test coverage for optimizer/scheduler combinations, VI training options, and KL-annealing behavior.

shaharbar1 and others added 2 commits April 6, 2026 14:22
…perature scaling

### Changes:
* Add `num_particles`, `gradient_clip_norm`, and `kl_tau` to VI `update_kwargs`
  - `num_particles` is passed to `TraceMeanField_ELBO`/`Trace_ELBO` for multi-particle gradient estimates
  - `gradient_clip_norm` chains `optax.clip_by_global_norm` before the base optimizer
  - `kl_tau` scales prior log-prob by `tau * N_data / N_neurons` via `numpyro.handlers.scale` (Huix et al. 2022)
* Switch all optimizers from numpyro-native to optax via `optax_to_numpyro`, removing the dual-path logic (see the sketch after this list)
* Remove `clipped_adam`, `momentum`, and `reduce_on_plateau` scheduler
* Remove `batch_size` parameter from `create_update_model`; read from `_update_kwargs` instead
* Add `optax` to `pyproject.toml` dependencies
* Introduces `kl_annealing_fraction` as a new VI update kwarg for BayesianNeuralNetwork. When set, the KL divergence term in the ELBO is linearly warmed up from 0 → 1 over the first `kl_annealing_fraction * total_steps` training steps, then held at 1.0 for the remainder.
* The annealing factor is threaded as a traced JAX scalar through the `model` and `guide` callables so the warm-up schedule runs entirely inside `lax.scan` without Python-loop overhead or recompilation. The factor composes multiplicatively with the existing `kl_tau` temperature scale when both are active.
* Validation rejects values outside [0, 1]. Hypothesis-based tests cover the full cross-product with existing VI options (num_particles, gradient_clip_norm, kl_tau, lr_scheduler).
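
A minimal sketch of how these pieces compose, using only public optax/NumPyro APIs (the actual wiring and internal names in pybandits/model.py may differ):

import optax
from numpyro import handlers
from numpyro.infer import TraceMeanField_ELBO
from numpyro.optim import optax_to_numpyro

# gradient_clip_norm: chain global-norm clipping in front of the base optimizer,
# then hand the resulting optax transformation to NumPyro.
base_opt = optax.adam(learning_rate=1e-2)
optimizer = optax_to_numpyro(optax.chain(optax.clip_by_global_norm(1.0), base_opt))

# num_particles: multi-particle ELBO gradient estimates.
elbo = TraceMeanField_ELBO(num_particles=4)

# kl_tau: scale the prior log-prob by tau * N_data / N_neurons inside the model.
tau, n_data, n_neurons = 0.1, 1000, 64
prior_scale_handler = handlers.scale(scale=tau * n_data / n_neurons)
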
@Moritesh Moritesh requested a review from shaharbar1 April 26, 2026 14:39

coderabbitai Bot commented Apr 26, 2026

📝 Walkthrough


Adds optax-based optimizer construction and LR-scheduler/gradient-clipping support, extends VI config with KL control fields and validation, implements per-step KL annealing in the NumPyro SVI training loop, updates guide/model builder usage, bumps package version to 6.1.0 and adds optax dependency; tests updated accordingly.

Changes

  • Core Model Implementation (pybandits/model.py): Replaces the fixed NumPyro optimizer mapping with optax composition (_resolve_optax_fn, _get_obj_optimizer), adds LR-scheduler and gradient-clipping support, extends the VI config with num_particles, gradient_clip_norm, kl_tau, and kl_annealing_fraction (with validation), implements _kl_scale_ctx and _create_update_model to apply KL scaling, updates the SVI setup to run per-step KL annealing via jax.lax.scan, returns the raw guide, and adjusts early-stopping loss usage.
  • Project Configuration (pyproject.toml): Bumps the package version 6.0.1 → 6.1.0 and adds optax = "^0.1" to dependencies.
  • Tests (tests/test_model.py): Switches tests to _create_update_model() usage, standardizes the reward helper, extends the categorical-BNN factory to accept update_kwargs, refactors MCMC test parametrization, and adds extensive tests for optax scheduler validity, VI option combinations (particles, clipping, KL params), KL-annealing behavior, and serialization preservation.

Sequence Diagram

sequenceDiagram
    participant Trainer as Training Loop
    participant SVI as SVI Engine
    participant Model as NumPyro Model/Guide
    participant KL as KL Scale Handler
    participant Optax as Optax Scheduler/Optimizer

    Trainer->>Optax: Build optimizer (+scheduler, clipping)
    Trainer->>SVI: Init SVI with raw_guide, optax->numpyro optimizer
    Trainer->>Trainer: Compute per-step kl_annealing_factor array
    Trainer->>SVI: For each step (jax.lax.scan) call svi.update(batch, kl_annealing_factor)
    SVI->>Model: Trace model/guide with kl_annealing_factor
    Model->>KL: Apply scaling to prior/kl sites using numpyro.handlers.scale
    Model->>SVI: Return loss contribution
    SVI-->>Trainer: Return loss scalar
    Trainer->>Trainer: Update params / early-stop checks

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 Optax hopped in with clipping and tune,
Schedules that sway beneath the moon,
KL gently grows, then finds its place,
SVI scans onward, keeping the pace,
Small hops, big leaps — a rabbit's embrace.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 65.79%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check (✅ Passed): The title accurately summarizes the main change: introducing KL annealing to BNN VI training, which is clearly the primary objective of this PR.
  • Description check (✅ Passed): The description covers the key objectives (kl_annealing_fraction parameter, behavior, composition with kl_tau, validation) and mentions the test plan, though some test items are unchecked. It exceeds the template's minimal requirements.
  • Linked Issues check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): Check skipped because no linked issues were found for this pull request.





@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (4)
tests/test_model.py (3)

1849-1852: assert_array_equal on float losses across two SVI runs is brittle.

Even though both branches follow the same code path (kl_annealing_active = False), comparing bnn_none._approx_history and bnn_zero._approx_history with assert_array_equal requires bit-identical floats across two independent JAX traces/compiles. JAX/XLA generally honors this for the same code on the same device, but trivial differences (e.g. CPU vs. GPU CI runners, future XLA layout changes, parallel reduction orderings) can produce ULP-level deltas and make the test flaky.

np.testing.assert_allclose(..., rtol=0, atol=0) is identical in semantics today but lets you loosen later without rewriting the assertion; alternatively assert_allclose(..., rtol=1e-6) would be a safer floor.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_model.py` around lines 1849 - 1852, Replace the brittle bit-exact
comparison between bnn_none._approx_history and bnn_zero._approx_history with a
tolerance-based comparison: change the np.testing.assert_array_equal(...) call
to np.testing.assert_allclose(...) (or assert_allclose(..., rtol=1e-6, atol=0) /
at least assert_allclose(..., rtol=0, atol=0) to allow future loosening) so
floating-point ULP differences from separate JAX SVI runs won’t make the test
flaky.

1822-1894: KL annealing tests don't exercise the probability path.

These four tests cover construction, equivalence of inactive paths, divergence vs. kl_tau-only, and serialization — but none of them call bnn.sample_proba(...) after bnn.update(...) to verify the trained model still produces well-formed [0, 1] probabilities. Since KL annealing directly affects the posterior the predictive distribution is drawn from, a smoke check on the probability output would catch regressions where annealing leaves mu/sigma in a bad state (NaN, infinite, out-of-range).

Consider augmenting at least test_kl_annealing_interacts_with_kl_tau (and ideally the round-trip test) with the same probability/finiteness assertions used in test_vi_training_options (lines 900-902):

result = bnn_tau_anneal.sample_proba(context=context)
assert all(0 <= p[0] <= 1 for p in result)
assert all(np.isfinite(p[1]) for p in result)

As per coding guidelines: **/*test*.py: Always apply callable cost and probability logic testing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_model.py` around lines 1822 - 1894, Add probability/finiteness
smoke checks after updating models in the KL-annealing tests: call sample_proba
on the trained instances (e.g., bnn_tau_anneal and bnn_tau_only in
test_kl_annealing_interacts_with_kl_tau, and bnn or bnn2 in
test_kl_annealing_fraction_serialization_round_trip) and assert outputs are
valid probabilities and finite (e.g., each p in result has 0<=p[0]<=1 and p[1]
is finite). This mirrors the checks in test_vi_training_options and ensures
kl_annealing doesn't produce NaN/inf or out-of-range probabilities. Ensure you
call sample_proba after the update() or after reconstruction so the smoke check
validates the trained/post-serialization model.

793-811: test_lr_scheduler_valid only checks the optimizer is constructed; no end-to-end exercise.

The test confirms bnn._obj_optimizer is not None but never runs bnn.update(...) or bnn.sample_proba(...). That means a scheduler that builds successfully but produces NaN gradients, mis-shaped learning rates, or otherwise breaks SVI updates would still pass. Given the coding guidelines (**/*test*.py: Always apply callable cost and probability logic testing) and the fact that test_vi_training_options already exercises a tiny update + sample_proba round-trip, mirroring that pattern here would noticeably strengthen the LR-scheduler coverage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_model.py` around lines 793 - 811, Extend test_lr_scheduler_valid
to perform a small end-to-end exercise after constructing the BNN: call
BayesianNeuralNetwork.cold_start as currently done, then run a minimal update
(e.g., bnn.update with a tiny synthetic batch or the same fixtures used in
test_vi_training_options) and invoke bnn.sample_proba on a small input to ensure
training and inference succeed; assert the update returns a finite loss (not
NaN/inf) and that sample_proba returns probabilities of the expected shape and
finite values. This uses the existing test_lr_scheduler_valid, bnn.update, and
bnn.sample_proba symbols so the scheduler is validated during an actual SVI step
and inference round-trip.
pybandits/model.py (1)

1530-1530: n_samples: PositiveInt is misleading — callers pass a traced JAX scalar.

_create_update_model.model invokes kl_scale(n_samples, kl_annealing_factor) with n_samples = x.shape[0], which is a traced JAX integer when the function is JIT-compiled (e.g. inside lax.scan). A traced scalar is not a Python int and won't satisfy any PositiveInt runtime validator. There is no validate_call decorator here so it does not actually fail today, but the annotation is misleading and will trip up anyone who later adds Pydantic validation.

Consider widening the annotation to something like Union[PositiveInt, jax.Array] or simply int with a comment noting traced scalars are accepted.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pybandits/model.py` at line 1530, The type annotation for _kl_scale_ctx's
parameter n_samples is misleading because callers pass JAX-traced scalars (e.g.
from _create_update_model.model calling kl_scale with x.shape[0]), so change the
annotation on def _kl_scale_ctx to accept traced JAX values—either widen it to
Union[PositiveInt, jax.Array] or to plain int—and add a short comment on the
parameter stating that JAX traced scalars (jax.Array) are allowed to avoid
future Pydantic/validation surprises; ensure related uses such as kl_scale and
calls in _create_update_model.model remain type-compatible.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pybandits/model.py`:
- Around line 2037-2048: The kl annealing currently computes kl_annealing_factor
in _svi_body as (step.astype(jnp.float32) + 1.0) / kl_warmup_f which makes the
first step > 0; change the formula to (step.astype(jnp.float32)) / kl_warmup_f
so the first step yields 0 and the factor linearly reaches 1.0 after kl_warmup_f
steps; update the expression computing kl_annealing_factor in _svi_body (which
is passed into svi.update) accordingly and ensure jnp.minimum(1.0, ...) is still
applied.
- Around line 2056-2060: The JIT for _run_epoch currently marks the epoch length
argument `n` static (static_argnums=(1,)), so when the final epoch uses a
different step count (because epoch_steps_list can include a smaller `remaining`
epoch) it triggers an extra XLA compile; change to use a non-static,
numeric-length argument or make all epochs the same length. Concretely: either
remove `n` from static_argnums on the jax.jit of _run_epoch (so pass an
int32/array length like the active path does) or normalize `epoch_steps_list`
(drop/absorb the remainder so every entry equals steps_per_epoch) so _run_epoch
is always called with the same `n`; update references to _run_epoch, _svi_body,
epoch_steps_list, num_steps and steps_per_epoch accordingly.
- Around line 1135-1150: The current annotation-based checks in
_resolve_optax_fn (which references self._optax_return_types and optax types
like optax.Schedule / GradientTransformationExtraArgs) are brittle; change the
function to first tolerate missing or string annotations (treat
inspect.Signature.empty or str annotations as "unreliable") and then perform a
pragmatic duck-typed validation fallback: try to call the candidate fn with
small dummy inputs appropriate to the kind (e.g., for "schedule" call fn(0) or
fn(learning_rate=0.01) for optimizer factories) and verify the returned object
supports the expected protocol (e.g., has .init and .update for
GradientTransformation-like objects, or is callable/returns numbers for
schedules), catching any exceptions and raising a clear ValueError indicating
either "missing/unparseable annotation" or "callable did not produce expected
interface" and returning fn on success; keep using the name checks and callable
guard but remove strict issubclass/== checks against return_annotation.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 343828ea-f7fd-4f54-8772-c6257cdf38ff

📥 Commits

Reviewing files that changed from the base of the PR and between 7af2bb0 and 5a7421e.

📒 Files selected for processing (3)
  • pybandits/model.py
  • pyproject.toml
  • tests/test_model.py

Comment thread pybandits/model.py
Comment thread pybandits/model.py Outdated
Comment on lines +2037 to +2048
svi_state = svi.init(subkey, x_jnp, y_jnp, 1.0)
if kl_annealing_active:
    kl_warmup_f = jnp.float32(kl_warmup_steps)

    def _svi_body(state, step):  # step is a traced int32 scalar (global index)
        kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)
        state, loss = svi.update(state, x_jnp, y_jnp, kl_annealing_factor)
        return state, loss

    _run_epoch = jax.jit(
        lambda state, steps: jax.lax.scan(_svi_body, state, steps),
    )  # steps is a 1-D int32 array; scan feeds each element as xs to _svi_body


⚠️ Potential issue | 🟡 Minor

KL annealing factor never reaches 0; starts at 1/kl_warmup_steps.

The PR description states that the annealing factor "linearly warms the KL term in the ELBO from 0 to 1 over the first kl_annealing_fraction * total_steps steps". The current formula is

kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)

So at step=0 the factor is 1/kl_warmup_steps (and at kl_warmup_steps == 1 the factor is already 1.0, i.e. effectively no warmup). If the goal is "from 0 to 1", drop the + 1.0 so the first step starts at exactly 0 and the factor reaches 1.0 only after kl_warmup_steps updates:

♻️ Suggested fix
-            def _svi_body(state, step):  # step is a traced int32 scalar (global index)
-                kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)
+            def _svi_body(state, step):  # step is a traced int32 scalar (global index)
+                kl_annealing_factor = jnp.minimum(1.0, step.astype(jnp.float32) / kl_warmup_f)
                 state, loss = svi.update(state, x_jnp, y_jnp, kl_annealing_factor)
                 return state, loss

If the off-by-one is intentional, please update the PR description / _kl_scale_ctx docstring to reflect that the factor is in (0, 1] and starts at 1/N.
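
For concreteness, the first few factors under both formulas with kl_warmup_steps = 4 (a standalone sketch, not code from the PR):

import jax.numpy as jnp

kl_warmup_f = jnp.float32(4)
steps = jnp.arange(6, dtype=jnp.float32)

current = jnp.minimum(1.0, (steps + 1.0) / kl_warmup_f)  # 0.25, 0.50, 0.75, 1.00, 1.00, 1.00
proposed = jnp.minimum(1.0, steps / kl_warmup_f)          # 0.00, 0.25, 0.50, 0.75, 1.00, 1.00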

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pybandits/model.py` around lines 2037 - 2048, The kl annealing currently
computes kl_annealing_factor in _svi_body as (step.astype(jnp.float32) + 1.0) /
kl_warmup_f which makes the first step > 0; change the formula to
(step.astype(jnp.float32)) / kl_warmup_f so the first step yields 0 and the
factor linearly reaches 1.0 after kl_warmup_f steps; update the expression
computing kl_annealing_factor in _svi_body (which is passed into svi.update)
accordingly and ensure jnp.minimum(1.0, ...) is still applied.

Comment thread pybandits/model.py Outdated
total_steps = sum(epoch_steps_list)
kl_annealing_fraction = self._update_kwargs.get("kl_annealing_fraction")
kl_annealing_active = kl_annealing_fraction not in (None, 0.0)
kl_warmup_steps = max(1, int(np.ceil(kl_annealing_fraction * total_steps))) if kl_annealing_active else None
Collaborator


I think you can drop the np.ceil and just directly apply int().

Collaborator Author


I'd actually prefer to keep np.ceil here. With plain int(...) truncation we'd silently give fewer warmup steps than the user requested (e.g. 0.05 * 195 = 9.75 rounds to 9 steps instead of 10), and for small fractions on small training runs it can truncate to 0, leaning entirely on the max(1, ...) floor. Rounding up feels like the more honest reading of "warm up over kl_annealing_fraction * total_steps steps". Can drop it if you feel strongly though.
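
Numerically, with the numbers above (a quick sketch of the two roundings):

import numpy as np

int(0.05 * 195)                     # 9  (truncates 9.75 down)
max(1, int(np.ceil(0.05 * 195)))    # 10 (rounds up to the requested warmup length)

int(0.02 * 30)                      # 0  (would lean entirely on the max(1, ...) floor)
max(1, int(np.ceil(0.02 * 30)))     # 1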

Comment thread pybandits/model.py Outdated
Comment thread tests/test_model.py Outdated
Model (pybandits/model.py):
  * Unify the SVI training loop on a single precomputed kl_factors_full
    array fed as xs to lax.scan, collapsing the active/inactive branch
    fork (two _svi_body defs, two _run_epoch JITs, two epoch-caller
    paths) into one. Drops static_argnums=(1,), removing the trailing-
    partial-epoch XLA recompile (see the sketch after this message).
  * Always wrap raw_guide so the guide's call signature matches the
    model's (svi.update unconditionally forwards kl_annealing_factor).
    _kl_scale_ctx returns nullcontext() when both kl_tau and KL annealing
    are inactive, so the wrapper is a no-op on the inactive path.
  * Widen annotations on values that become traced scalars under JIT:
    _kl_scale_ctx(n_samples: Union[PositiveInt, jax.Array],
    kl_annealing_factor: Union[Float01, jax.Array] = 1.0); same Union
    shape on the inner model and inner guide signatures.
  * Fix _kl_scale_ctx docstring: kl_annealing_factor range is (0, 1],
    not [0, 1] (lower bound is 1 / warmup_steps from the (step + 1) /
    warmup_steps schedule).

Tests (tests/test_model.py):
  * Group the KL annealing tests under a TestKLAnnealing class with a
    shared N_FEATURES class const. Move the boundary-derived invalid
    fractions to a module-level _INVALID_KL_ANNEALING_FRACTIONS tuple.
  * Replace the brittle "None vs 0.0 produce identical SVI loss
    histories" check with two trace-based tests using
    numpyro.handlers.trace + numpyro.handlers.seed:
      - test_kl_annealing_inactive_no_scale_on_prior_sites: parametrized
        over (None, 0.0); asserts prior sample sites carry no scale
        annotation when both gates are off.
      - test_kl_annealing_active_scales_prior_sites: parametrized over a
        few factor values; asserts the cumulative scale at prior sites
        equals the runtime factor passed into model_fn.
    These don't run SVI, so they're fast, deterministic, and free of
    the seed-dependent flakiness the equivalence assertion exhibited.
  * test_kl_annealing_interacts_with_kl_tau: drop the seed parametrize.
    random_seed, kl_tau, and kl_annealing_fraction are sampled per run
    from their valid ranges; only the divergence assertion is under
    test.
  * test_kl_annealing_fraction_serialization_round_trip: sample
    kl_annealing_fraction per run instead of parametrizing over four
    values.
  * Generate synthetic context/rewards inline with bare np.random.* (no
    seeding, no helper) to match the file-wide convention in
    tests/test_model.py.
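
A self-contained sketch of the unified loop shape described above (the sizes and the toy scan body are stand-ins; the real loop threads the SVI state and calls svi.update(state, x_jnp, y_jnp, kl_factor)):

import jax
import jax.numpy as jnp

total_steps = 200      # stand-in for sum(epoch_steps_list)
kl_warmup_steps = 20   # stand-in for the derived warm-up length
steps = jnp.arange(total_steps, dtype=jnp.float32)
kl_factors_full = jnp.minimum(1.0, (steps + 1.0) / kl_warmup_steps)  # in the real loop: all ones when annealing is inactive

def _svi_body(state, kl_factor):
    # Stand-in body: the real one returns (new_state, loss) from svi.update.
    return state, kl_factor

run_epoch = jax.jit(lambda state, factors: jax.lax.scan(_svi_body, state, factors))
state, losses = run_epoch(jnp.float32(0.0), kl_factors_full[:50])  # one epoch's slice of the precomputed factors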

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/test_model.py (1)

1866-1927: ⚡ Quick win

Add one regression against the generated warmup schedule.

These KL tests inject factor directly into model_fn, so they never exercise the factors built inside _run_svi_training_loop. A small assertion around the first few scan factors would make the 0 → 1 warmup contract explicit and would catch schedule bugs even when the prior-site trace tests still pass.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_model.py` around lines 1866 - 1927, Add a small regression
asserting the warmup schedule produced inside the training loop increases from
0→1: call the model's internal SVI schedule generator (invoke
BayesianNeuralNetwork._run_svi_training_loop or the method that returns the
per-step KL scan factors using the same update_kwargs used in
test_kl_annealing_interacts_with_kl_tau, capture the returned per-step
kl_factors/scan schedule for num_steps=10 and
kl_annealing_fraction=kl_annealing_fraction, then assert the first few factors
start near 0 and monotonically increase (e.g. first factor ~=0, middle factor
between 0 and 1, final factor ~=1) to make the 0→1 warmup explicit and catch
schedule regressions.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1128e9a1-7cc9-4855-9642-9400dca00f0a

📥 Commits

Reviewing files that changed from the base of the PR and between 5a7421e and 7b994fe.

📒 Files selected for processing (2)
  • pybandits/model.py
  • tests/test_model.py

Comment thread pybandits/model.py
Comment on lines +1462 to +1465
kl_annealing_fraction = self.update_kwargs.get("kl_annealing_fraction")
if kl_annealing_fraction is not None and not (0.0 <= kl_annealing_fraction <= 1.0):
    raise ValueError(f"kl_annealing_fraction must be in [0, 1] or None, got {kl_annealing_fraction}")



⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate the other new VI knobs at construction time.

kl_annealing_fraction is range-checked here, but num_particles, gradient_clip_norm, and kl_tau are still forwarded unchecked into NumPyro/optax. Inputs like num_particles=0 or negative kl_tau/gradient_clip_norm will either fail later with opaque library errors or silently turn the KL term into the wrong objective.

Suggested guardrails
             kl_annealing_fraction = self.update_kwargs.get("kl_annealing_fraction")
             if kl_annealing_fraction is not None and not (0.0 <= kl_annealing_fraction <= 1.0):
                 raise ValueError(f"kl_annealing_fraction must be in [0, 1] or None, got {kl_annealing_fraction}")
+
+            num_particles = self.update_kwargs.get("num_particles")
+            if num_particles < 1:
+                raise ValueError(f"num_particles must be >= 1, got {num_particles}")
+
+            gradient_clip_norm = self.update_kwargs.get("gradient_clip_norm")
+            if gradient_clip_norm is not None and gradient_clip_norm < 0:
+                raise ValueError(f"gradient_clip_norm must be >= 0, got {gradient_clip_norm}")
+
+            kl_tau = self.update_kwargs.get("kl_tau")
+            if kl_tau is not None and kl_tau < 0:
+                raise ValueError(f"kl_tau must be >= 0, got {kl_tau}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pybandits/model.py` around lines 1462 - 1465, At construction time, add the
same input validation used for kl_annealing_fraction to the other VI knobs
carried in update_kwargs: ensure num_particles is an integer >= 1,
gradient_clip_norm is either None or a number >= 0, and kl_tau is either None or
a positive number > 0; if any check fails, raise a ValueError with a clear
message (similar style to the existing kl_annealing_fraction check). Locate the
validation near where kl_annealing_fraction is checked (the constructor code
that reads update_kwargs) and perform these checks before forwarding values into
NumPyro/optax.
