Conversation
…perature scaling

### Changes:
* Add `num_particles`, `gradient_clip_norm`, and `kl_tau` to VI `update_kwargs`:
  - `num_particles` is passed to `TraceMeanField_ELBO`/`Trace_ELBO` for multi-particle gradient estimates
  - `gradient_clip_norm` chains `optax.clip_by_global_norm` before the base optimizer (sketched below)
  - `kl_tau` scales the prior log-prob by `tau * N_data / N_neurons` via `numpyro.handlers.scale` (Huix et al. 2022)
* Switch all optimizers from numpyro-native to optax via `optax_to_numpyro`, removing the dual-path logic
* Remove `clipped_adam`, `momentum`, and the `reduce_on_plateau` scheduler
* Remove the `batch_size` parameter from `create_update_model`; read it from `_update_kwargs` instead
* Add `optax` to `pyproject.toml` dependencies
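A minimal sketch of the optimizer wiring described above, assuming a hypothetical builder name; `optax.chain`, `optax.clip_by_global_norm`, and `numpyro.optim.optax_to_numpyro` are real APIs, everything else is illustrative:

```python
import optax
from numpyro.optim import optax_to_numpyro

def build_vi_optimizer(learning_rate=1e-3, gradient_clip_norm=None):
    """Hypothetical helper: chain optional global-norm clipping before Adam."""
    base = optax.adam(learning_rate)
    if gradient_clip_norm is not None:
        # Clipping runs first in the chain; the base optimizer then applies the update.
        base = optax.chain(optax.clip_by_global_norm(gradient_clip_norm), base)
    return optax_to_numpyro(base)  # wrap the optax transformation for NumPyro SVI
```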
* Introduces `kl_annealing_fraction` as a new VI update kwarg for `BayesianNeuralNetwork`. When set, the KL divergence term in the ELBO is linearly warmed up from 0 → 1 over the first `kl_annealing_fraction * total_steps` training steps, then held at 1.0 for the remainder.
* The annealing factor is threaded as a traced JAX scalar through the `model` and `guide` callables so the warm-up schedule runs entirely inside `lax.scan` without Python-loop overhead or recompilation. The factor composes multiplicatively with the existing `kl_tau` temperature scale when both are active (see the sketch below).
* Validation rejects values outside [0, 1]. Hypothesis-based tests cover the full cross-product with existing VI options (`num_particles`, `gradient_clip_norm`, `kl_tau`, `lr_scheduler`).
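A hedged sketch of the warm-up schedule and its composition with `kl_tau`, following the 0 → 1 behavior described above; the helper names and the exact site scaling are assumptions, not the package's internals:

```python
import jax.numpy as jnp
import numpy as np
import numpyro

def kl_annealing_factors(total_steps, kl_annealing_fraction):
    # Linear warm-up over the first fraction of steps, then held at 1.0.
    warmup_steps = max(1, int(np.ceil(kl_annealing_fraction * total_steps)))
    steps = jnp.arange(total_steps, dtype=jnp.float32)
    return jnp.minimum(1.0, steps / warmup_steps)

def scaled_prior_sample(name, prior, kl_base_factor, kl_annealing_factor):
    # The warm-up factor multiplies the kl_tau-derived base scale on prior sites;
    # during training the factor is threaded in as a traced JAX scalar.
    with numpyro.handlers.scale(scale=kl_base_factor * kl_annealing_factor):
        return numpyro.sample(name, prior)
```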
📝 Walkthrough

Adds optax-based optimizer construction and LR-scheduler/gradient-clipping support, extends the VI config with KL control fields and validation, implements per-step KL annealing in the NumPyro SVI training loop, updates guide/model builder usage, bumps the package version to 6.1.0, and adds the optax dependency; tests are updated accordingly.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
participant Trainer as Training Loop
participant SVI as SVI Engine
participant Model as NumPyro Model/Guide
participant KL as KL Scale Handler
participant Optax as Optax Scheduler/Optimizer
Trainer->>Optax: Build optimizer (+scheduler, clipping)
Trainer->>SVI: Init SVI with raw_guide, optax->numpyro optimizer
Trainer->>Trainer: Compute per-step kl_annealing_factor array
Trainer->>SVI: For each step (jax.lax.scan) call svi.update(batch, kl_annealing_factor)
SVI->>Model: Trace model/guide with kl_annealing_factor
Model->>KL: Apply scaling to prior/kl sites using numpyro.handlers.scale
Model->>SVI: Return loss contribution
SVI-->>Trainer: Return loss scalar
Trainer->>Trainer: Update params / early-stop checks
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (warning)
Actionable comments posted: 3
🧹 Nitpick comments (4)
tests/test_model.py (3)
1849-1852: `assert_array_equal` on float losses across two SVI runs is brittle.

Even though both branches follow the same code path (`kl_annealing_active = False`), comparing `bnn_none._approx_history` and `bnn_zero._approx_history` with `assert_array_equal` requires bit-identical floats across two independent JAX traces/compiles. JAX/XLA generally honors this for the same code on the same device, but trivial differences (e.g. CPU vs. GPU CI runners, future XLA layout changes, parallel reduction orderings) can produce ULP-level deltas and make the test flaky.

`np.testing.assert_allclose(..., rtol=0, atol=0)` is identical in semantics today but lets you loosen later without rewriting the assertion; alternatively `assert_allclose(..., rtol=1e-6)` would be a safer floor.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_model.py` around lines 1849 - 1852, Replace the brittle bit-exact comparison between bnn_none._approx_history and bnn_zero._approx_history with a tolerance-based comparison: change the np.testing.assert_array_equal(...) call to np.testing.assert_allclose(...) (or assert_allclose(..., rtol=1e-6, atol=0) / at least assert_allclose(..., rtol=0, atol=0) to allow future loosening) so floating-point ULP differences from separate JAX SVI runs won’t make the test flaky.
1822-1894: KL annealing tests don't exercise the probability path.

These four tests cover construction, equivalence of inactive paths, divergence vs. `kl_tau`-only, and serialization — but none of them call `bnn.sample_proba(...)` after `bnn.update(...)` to verify the trained model still produces well-formed [0, 1] probabilities. Since KL annealing directly affects the posterior the predictive distribution is drawn from, a smoke check on the probability output would catch regressions where annealing leaves `mu`/`sigma` in a bad state (NaN, infinite, out-of-range).

Consider augmenting at least `test_kl_annealing_interacts_with_kl_tau` (and ideally the round-trip test) with the same probability/finiteness assertions used in `test_vi_training_options` (lines 900-902):

```python
result = bnn_tau_anneal.sample_proba(context=context)
assert all(0 <= p[0] <= 1 for p in result)
assert all(np.isfinite(p[1]) for p in result)
```

As per coding guidelines: `**/*test*.py`: Always apply callable cost and probability logic testing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_model.py` around lines 1822 - 1894, Add probability/finiteness smoke checks after updating models in the KL-annealing tests: call sample_proba on the trained instances (e.g., bnn_tau_anneal and bnn_tau_only in test_kl_annealing_interacts_with_kl_tau, and bnn or bnn2 in test_kl_annealing_fraction_serialization_round_trip) and assert outputs are valid probabilities and finite (e.g., each p in result has 0<=p[0]<=1 and p[1] is finite). This mirrors the checks in test_vi_training_options and ensures kl_annealing doesn't produce NaN/inf or out-of-range probabilities. Ensure you call sample_proba after the update() or after reconstruction so the smoke check validates the trained/post-serialization model.
793-811: `test_lr_scheduler_valid` only checks the optimizer is constructed; no end-to-end exercise.

The test confirms `bnn._obj_optimizer is not None` but never runs `bnn.update(...)` or `bnn.sample_proba(...)`. That means a scheduler that builds successfully but produces NaN gradients, mis-shaped learning rates, or otherwise breaks SVI updates would still pass. Given the coding guidelines (`**/*test*.py`: Always apply callable cost and probability logic testing) and the fact that `test_vi_training_options` already exercises a tiny update + `sample_proba` round-trip, mirroring that pattern here would noticeably strengthen the LR-scheduler coverage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_model.py` around lines 793 - 811, Extend test_lr_scheduler_valid to perform a small end-to-end exercise after constructing the BNN: call BayesianNeuralNetwork.cold_start as currently done, then run a minimal update (e.g., bnn.update with a tiny synthetic batch or the same fixtures used in test_vi_training_options) and invoke bnn.sample_proba on a small input to ensure training and inference succeed; assert the update returns a finite loss (not NaN/inf) and that sample_proba returns probabilities of the expected shape and finite values. This uses the existing test_lr_scheduler_valid, bnn.update, and bnn.sample_proba symbols so the scheduler is validated during an actual SVI step and inference round-trip.

pybandits/model.py (1)
1530-1530: `n_samples: PositiveInt` is misleading — callers pass a traced JAX scalar.

`_create_update_model.model` invokes `kl_scale(n_samples, kl_annealing_factor)` with `n_samples = x.shape[0]`, which is a traced JAX integer when the function is JIT-compiled (e.g. inside `lax.scan`). A traced scalar is not a Python `int` and won't satisfy any `PositiveInt` runtime validator. There is no `validate_call` decorator here so it does not actually fail today, but the annotation is misleading and will trip up anyone who later adds Pydantic validation.

Consider widening the annotation to something like `Union[PositiveInt, jax.Array]` or simply `int` with a comment noting traced scalars are accepted.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pybandits/model.py` at line 1530, The type annotation for _kl_scale_ctx's parameter n_samples is misleading because callers pass JAX-traced scalars (e.g. from _create_update_model.model calling kl_scale with x.shape[0]), so change the annotation on def _kl_scale_ctx to accept traced JAX values—either widen it to Union[PositiveInt, jax.Array] or to plain int—and add a short comment on the parameter stating that JAX traced scalars (jax.Array) are allowed to avoid future Pydantic/validation surprises; ensure related uses such as kl_scale and calls in _create_update_model.model remain type-compatible.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pybandits/model.py`:
- Around line 2056-2060: The JIT for _run_epoch currently marks the epoch length
argument `n` static (static_argnums=(1,)), so when the final epoch uses a
different step count (because epoch_steps_list can include a smaller `remaining`
epoch) it triggers an extra XLA compile; change to use a non-static,
numeric-length argument or make all epochs the same length. Concretely: either
remove `n` from static_argnums on the jax.jit of _run_epoch (so pass an
int32/array length like the active path does) or normalize `epoch_steps_list`
(drop/absorb the remainder so every entry equals steps_per_epoch) so _run_epoch
is always called with the same `n`; update references to _run_epoch, _svi_body,
epoch_steps_list, num_steps and steps_per_epoch accordingly.
- Around line 1135-1150: The current annotation-based checks in
_resolve_optax_fn (which references self._optax_return_types and optax types
like optax.Schedule / GradientTransformationExtraArgs) are brittle; change the
function to first tolerate missing or string annotations (treat
inspect.Signature.empty or str annotations as "unreliable") and then perform a
pragmatic duck-typed validation fallback: try to call the candidate fn with
small dummy inputs appropriate to the kind (e.g., for "schedule" call fn(0) or
fn(learning_rate=0.01) for optimizer factories) and verify the returned object
supports the expected protocol (e.g., has .init and .update for
GradientTransformation-like objects, or is callable/returns numbers for
schedules), catching any exceptions and raising a clear ValueError indicating
either "missing/unparseable annotation" or "callable did not produce expected
interface" and returning fn on success; keep using the name checks and callable
guard but remove strict issubclass/== checks against return_annotation.
📒 Files selected for processing (3)
pybandits/model.py, pyproject.toml, tests/test_model.py
```python
svi_state = svi.init(subkey, x_jnp, y_jnp, 1.0)
if kl_annealing_active:
    kl_warmup_f = jnp.float32(kl_warmup_steps)

    def _svi_body(state, step):  # step is a traced int32 scalar (global index)
        kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)
        state, loss = svi.update(state, x_jnp, y_jnp, kl_annealing_factor)
        return state, loss

    _run_epoch = jax.jit(
        lambda state, steps: jax.lax.scan(_svi_body, state, steps),
    )  # steps is a 1-D int32 array; scan feeds each element as xs to _svi_body
```
KL annealing factor never reaches 0; starts at 1/kl_warmup_steps.

The PR description states that the annealing factor "linearly warms the KL term in the ELBO from 0 to 1 over the first kl_annealing_fraction * total_steps steps". The current formula is

```python
kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)
```

So at step=0 the factor is 1/kl_warmup_steps (and at kl_warmup_steps == 1 the factor is already 1.0, i.e. effectively no warmup). If the goal is "from 0 to 1", drop the + 1.0 so the first step starts at exactly 0 and the factor reaches 1.0 only after kl_warmup_steps updates:

♻️ Suggested fix

```diff
-    def _svi_body(state, step):  # step is a traced int32 scalar (global index)
-        kl_annealing_factor = jnp.minimum(1.0, (step.astype(jnp.float32) + 1.0) / kl_warmup_f)
+    def _svi_body(state, step):  # step is a traced int32 scalar (global index)
+        kl_annealing_factor = jnp.minimum(1.0, step.astype(jnp.float32) / kl_warmup_f)
         state, loss = svi.update(state, x_jnp, y_jnp, kl_annealing_factor)
         return state, loss
```

If the off-by-one is intentional, please update the PR description / _kl_scale_ctx docstring to reflect that the factor is in (0, 1] and starts at 1/N.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pybandits/model.py` around lines 2037 - 2048, The kl annealing currently
computes kl_annealing_factor in _svi_body as (step.astype(jnp.float32) + 1.0) /
kl_warmup_f which makes the first step > 0; change the formula to
(step.astype(jnp.float32)) / kl_warmup_f so the first step yields 0 and the
factor linearly reaches 1.0 after kl_warmup_f steps; update the expression
computing kl_annealing_factor in _svi_body (which is passed into svi.update)
accordingly and ensure jnp.minimum(1.0, ...) is still applied.
```python
total_steps = sum(epoch_steps_list)
kl_annealing_fraction = self._update_kwargs.get("kl_annealing_fraction")
kl_annealing_active = kl_annealing_fraction not in (None, 0.0)
kl_warmup_steps = max(1, int(np.ceil(kl_annealing_fraction * total_steps))) if kl_annealing_active else None
```
I think you can drop the np.ceil and just directly apply int().
I'd actually prefer to keep `np.ceil` here. With plain `int(...)` truncation we'd silently give fewer warmup steps than the user requested (e.g. `0.05 * 195 = 9.75` truncates to 9 steps instead of 10), and for small fractions on small training runs it can truncate to 0, leaning entirely on the `max(1, ...)` floor. Rounding up feels like the more honest reading of "warm up over `kl_annealing_fraction * total_steps` steps". Can drop it if you feel strongly though.
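A quick worked check of the two behaviors discussed above (values taken from the example in this thread):

```python
import numpy as np

frac, total = 0.05, 195
frac * total                # 9.75
int(frac * total)           # 9  -> truncation under-counts the requested warmup
int(np.ceil(frac * total))  # 10 -> ceil preserves "at least frac * total" steps
```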
Model (pybandits/model.py):
* Unify the SVI training loop on a single precomputed kl_factors_full
array fed as xs to lax.scan, collapsing the active/inactive branch
fork (two _svi_body defs, two _run_epoch JITs, two epoch-caller
paths) into one. Drops static_argnums=(1,), removing the trailing-
partial-epoch XLA recompile.
* Always wrap raw_guide so the guide's call signature matches the
model's (svi.update unconditionally forwards kl_annealing_factor).
_kl_scale_ctx returns nullcontext() when both kl_tau and KL annealing
are inactive, so the wrapper is a no-op on the inactive path.
* Widen annotations on values that become traced scalars under JIT:
_kl_scale_ctx(n_samples: Union[PositiveInt, jax.Array],
kl_annealing_factor: Union[Float01, jax.Array] = 1.0); same Union
shape on the inner model and inner guide signatures.
* Fix _kl_scale_ctx docstring: kl_annealing_factor range is (0, 1],
not [0, 1] (lower bound is 1 / warmup_steps from the (step + 1) /
warmup_steps schedule).
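A hedged sketch of the unified loop the first Model bullet describes; the function name and argument list are assumptions, and the real code in pybandits/model.py may differ in detail:

```python
import jax
import jax.numpy as jnp

def run_svi_training_loop(svi, svi_state, x_jnp, y_jnp, total_steps,
                          kl_warmup_steps=None, kl_annealing_active=False):
    # Precompute one KL factor per global step; the inactive path is all ones,
    # so a single scan body covers both cases.
    if kl_annealing_active:
        steps = jnp.arange(total_steps, dtype=jnp.float32)
        kl_factors_full = jnp.minimum(1.0, (steps + 1.0) / kl_warmup_steps)
    else:
        kl_factors_full = jnp.ones(total_steps, dtype=jnp.float32)

    def _svi_body(state, kl_annealing_factor):
        state, loss = svi.update(state, x_jnp, y_jnp, kl_annealing_factor)
        return state, loss

    # One jitted scan over the factor array; no static epoch length, so the
    # trailing partial epoch no longer forces an extra XLA recompile.
    run = jax.jit(lambda state: jax.lax.scan(_svi_body, state, kl_factors_full))
    return run(svi_state)
```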
Tests (tests/test_model.py):
* Group the KL annealing tests under a TestKLAnnealing class with a
shared N_FEATURES class const. Move the boundary-derived invalid
fractions to a module-level _INVALID_KL_ANNEALING_FRACTIONS tuple.
* Replace the brittle "None vs 0.0 produce identical SVI loss
histories" check with two trace-based tests using
numpyro.handlers.trace + numpyro.handlers.seed:
- test_kl_annealing_inactive_no_scale_on_prior_sites: parametrized
over (None, 0.0); asserts prior sample sites carry no scale
annotation when both gates are off.
- test_kl_annealing_active_scales_prior_sites: parametrized over a
few factor values; asserts the cumulative scale at prior sites
equals the runtime factor passed into model_fn.
These don't run SVI, so they're fast, deterministic, and free of
the seed-dependent flakiness the equivalence assertion exhibited.
* test_kl_annealing_interacts_with_kl_tau: drop the seed parametrize.
random_seed, kl_tau, and kl_annealing_fraction are sampled per run
from their valid ranges; only the divergence assertion is under
test.
* test_kl_annealing_fraction_serialization_round_trip: sample
kl_annealing_fraction per run instead of parametrizing over four
values.
* Generate synthetic context/rewards inline with bare np.random.* (no
seeding, no helper) to match the file-wide convention in
tests/test_model.py.
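A hedged sketch of the trace-based prior-site check described above; model_fn, its arguments, and the site filtering are assumptions about the test setup rather than the actual test code:

```python
import jax
import numpyro

def prior_site_scales(model_fn, x, y, factor):
    # Seed the model so sample sites draw deterministically, then trace it with
    # the runtime annealing factor and collect the scale attached to each
    # latent (non-observed) sample site.
    seeded = numpyro.handlers.seed(model_fn, jax.random.PRNGKey(0))
    tr = numpyro.handlers.trace(seeded).get_trace(x, y, factor)
    return {
        name: site["scale"]
        for name, site in tr.items()
        if site["type"] == "sample" and not site["is_observed"]
    }

# Inactive path: expect scale to be None (no handlers.scale applied).
# Active path: expect scale == kl_base_factor * factor at every prior site.
```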
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/test_model.py (1)
1866-1927: ⚡ Quick win: Add one regression against the generated warmup schedule.

These KL tests inject `factor` directly into `model_fn`, so they never exercise the factors built inside `_run_svi_training_loop`. A small assertion around the first few scan factors would make the 0 → 1 warmup contract explicit and would catch schedule bugs even when the prior-site trace tests still pass.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_model.py` around lines 1866 - 1927, Add a small regression asserting the warmup schedule produced inside the training loop increases from 0→1: call the model's internal SVI schedule generator (invoke BayesianNeuralNetwork._run_svi_training_loop or the method that returns the per-step KL scan factors using the same update_kwargs used in test_kl_annealing_interacts_with_kl_tau, capture the returned per-step kl_factors/scan schedule for num_steps=10 and kl_annealing_fraction=kl_annealing_fraction, then assert the first few factors start near 0 and monotonically increase (e.g. first factor ~=0, middle factor between 0 and 1, final factor ~=1) to make the 0→1 warmup explicit and catch schedule regressions.
📒 Files selected for processing (2)
pybandits/model.py, tests/test_model.py
```python
kl_annealing_fraction = self.update_kwargs.get("kl_annealing_fraction")
if kl_annealing_fraction is not None and not (0.0 <= kl_annealing_fraction <= 1.0):
    raise ValueError(f"kl_annealing_fraction must be in [0, 1] or None, got {kl_annealing_fraction}")
```
Validate the other new VI knobs at construction time.
`kl_annealing_fraction` is range-checked here, but `num_particles`, `gradient_clip_norm`, and `kl_tau` are still forwarded unchecked into NumPyro/optax. Inputs like `num_particles=0` or a negative `kl_tau`/`gradient_clip_norm` will either fail later with opaque library errors or silently turn the KL term into the wrong objective.
Suggested guardrails

```diff
 kl_annealing_fraction = self.update_kwargs.get("kl_annealing_fraction")
 if kl_annealing_fraction is not None and not (0.0 <= kl_annealing_fraction <= 1.0):
     raise ValueError(f"kl_annealing_fraction must be in [0, 1] or None, got {kl_annealing_fraction}")
+
+num_particles = self.update_kwargs.get("num_particles")
+if num_particles < 1:
+    raise ValueError(f"num_particles must be >= 1, got {num_particles}")
+
+gradient_clip_norm = self.update_kwargs.get("gradient_clip_norm")
+if gradient_clip_norm is not None and gradient_clip_norm < 0:
+    raise ValueError(f"gradient_clip_norm must be >= 0, got {gradient_clip_norm}")
+
+kl_tau = self.update_kwargs.get("kl_tau")
+if kl_tau is not None and kl_tau < 0:
+    raise ValueError(f"kl_tau must be >= 0, got {kl_tau}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pybandits/model.py` around lines 1462 - 1465, At construction time, add the
same input validation used for kl_annealing_fraction to the other VI knobs
carried in update_kwargs: ensure num_particles is an integer >= 1,
gradient_clip_norm is either None or a number >= 0, and kl_tau is either None or
a positive number > 0; if any check fails, raise a ValueError with a clear
message (similar style to the existing kl_annealing_fraction check). Locate the
validation near where kl_annealing_fraction is checked (the constructor code
that reads update_kwargs) and perform these checks before forwarding values into
NumPyro/optax.
Add KL annealing to BNN VI training
Summary
* Introduces `kl_annealing_fraction` as a new VI `update_kwargs` parameter for `BayesianNeuralNetwork`. When set, the KL divergence term in the ELBO is linearly warmed up from 0 → 1 over the first `kl_annealing_fraction * total_steps` training steps, then held at 1.0 for the remainder.
* The annealing factor is threaded as a traced JAX scalar through the `model` and `guide` callables, so the warm-up schedule runs entirely inside `lax.scan` with no Python-loop overhead or XLA recompilation per step.
* The factor composes multiplicatively with the existing `kl_tau` temperature scale when both are active: `effective_scale = kl_base_factor * kl_annealing_factor`.
* `kl_annealing_fraction=None` and `=0.0` are both treated as inactive (no-op), validated at construction time.
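A hypothetical usage sketch tying the new kwarg to the existing VI options; the constructor arguments and values shown are illustrative placeholders, only the `update_kwargs` keys come from this PR and its predecessor:

```python
update_kwargs = {
    "num_particles": 4,            # multi-particle ELBO gradient estimates
    "gradient_clip_norm": 10.0,    # optax.clip_by_global_norm before the base optimizer
    "kl_tau": 0.5,                 # KL temperature scale on prior sites
    "kl_annealing_fraction": 0.1,  # warm the KL term up over the first 10% of steps
}
bnn = BayesianNeuralNetwork.cold_start(n_features=5, update_kwargs=update_kwargs)
```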
Test plan
* Validation rejects values outside [0, 1].
* `None` and `0.0` produce equivalent training behavior (both inactive paths).
* Cross-product coverage with the existing VI options: `num_particles`, `gradient_clip_norm`, `kl_tau`, `lr_scheduler`.

Summary by CodeRabbit
New Features
Chores
Tests