Add ParameterizedScaleAutoNormal ADVI guide with per-site scale init #139

shaharbar1 wants to merge 1 commit into `develop`
Conversation
📝 Walkthrough

Adds ParameterizedScaleAutoNormal with a per-site init_scale_fn, refactors SVI init to return init_loc_fn and init_scale_fn, wires the new guide into the SVI training loop for "advi", adds Hypothesis tests validating ADVI parameter extraction and zero-LR behavior, and includes minor CI/version/rule edits.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@pybandits/model.py`:
- Around line 2002-2004: The init_scale_fn currently returns exact sigmas (from
site_sigmas or avg_sigma) which can be zero and cause a zero-std initialization
for ParameterizedScaleAutoNormal / the full-rank scalar fallback; change
init_scale_fn to clamp returned scales to a small positive floor (e.g. epsilon =
1e-6 or 1e-3) so both avg_sigma and site_sigmas values are max(value, epsilon)
before being returned; update references to avg_sigma, site_sigmas, and the
lambda init_scale_fn used by BaseLocationScaleArray /
ParameterizedScaleAutoNormal to use this clamped value.
In `@pyproject.toml`:
- Line 3: The version string in pyproject.toml was bumped as a patch (6.0.9) but
this change adds features and must be a MINOR bump; update the version value
from "6.0.9" to "6.1.0" in pyproject.toml so the project semantic version
reflects added functionality (e.g., after introducing
ParameterizedScaleAutoNormal and refactors).
In `@tests/test_model.py`:
- Around line 1774-1782: The mean-tolerance uses sigma_init but should use the
learned posterior scale; change the mean absolute tolerance calculation so
atol_mu is computed from site_sigma[name] (e.g. atol_mu = n_sigma_tolerance *
site_sigma[name] / np.sqrt(n_predictive_samples)) and use that per-site inside
the loop before asserting np.testing.assert_allclose(site_mu[name],
np.mean(draw, axis=0), atol=atol_mu); keep the rtol_sigma calculation as-is and
reference symbols: atol_mu, site_sigma[name], n_sigma_tolerance,
n_predictive_samples, samples, site_mu.
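The tolerance change asked for above can be sketched as a small helper. This is a hedged reconstruction, not the actual pybandits test code: the dict and argument names mirror the symbols in the comment (site_mu, site_sigma, n_sigma_tolerance), and the key point is that the absolute tolerance on the mean is derived from the learned per-site scale divided by the square root of the sample count, i.e. the standard error of a Monte Carlo mean.

```python
import numpy as np

# Illustrative helper assuming scalar per-site sigmas; names mirror the
# review comment's symbols and are not pybandits API.
def assert_means_within_tolerance(site_mu, site_sigma, samples, n_sigma_tolerance=4.0):
    for name, draw in samples.items():
        n_predictive_samples = draw.shape[0]
        # Standard error of a Monte Carlo mean: sigma / sqrt(n)
        atol_mu = n_sigma_tolerance * site_sigma[name] / np.sqrt(n_predictive_samples)
        np.testing.assert_allclose(site_mu[name], np.mean(draw, axis=0), atol=atol_mu)
```

With a learned sigma of 0.5 and 10,000 predictive samples, the mean check tolerates deviations up to 4 * 0.5 / 100 = 0.02, rather than a tolerance tied to the initialization value.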
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 99a354a2-903b-42fa-9617-f2504fc8f224
📒 Files selected for processing (5)
- .cursor/rules/test-cost-co.mdc
- .github/workflows/pull_request_style_check.yml
- pybandits/model.py
- pyproject.toml
- tests/test_model.py
💤 Files with no reviewable changes (2)
- .cursor/rules/test-cost-co.mdc
- .github/workflows/pull_request_style_check.yml
```python
avg_sigma = float(np.mean(np.concatenate(all_sigmas)))
init_scale_fn = lambda name: site_sigmas.get(name, avg_sigma)  # noqa: E731
return init_loc_fn, init_scale_fn
```
Clamp guide init scales away from zero.
BaseLocationScaleArray accepts sigma=0, and this closure now forwards those exact values into both ParameterizedScaleAutoNormal and the full-rank scalar fallback. That makes otherwise valid deterministic priors initialize the variational Normal with a zero std, which can blow up SVI before the first update.
Suggested fix
```diff
- avg_sigma = float(np.mean(np.concatenate(all_sigmas)))
- init_scale_fn = lambda name: site_sigmas.get(name, avg_sigma)  # noqa: E731
+ min_sigma = self._min_vi_sigma
+ avg_sigma = float(np.mean(np.concatenate([np.maximum(s, min_sigma) for s in all_sigmas])))
+ init_scale_fn = lambda name: np.maximum(site_sigmas.get(name, avg_sigma), min_sigma)  # noqa: E731
  return init_loc_fn, init_scale_fn
```
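As a standalone, runnable sketch of the clamping idea: the factory name and the MIN_VI_SIGMA constant below are illustrative (the suggested fix reads the floor from self._min_vi_sigma instead), but the behavior is the same — no site can hand the guide a zero standard deviation.

```python
import numpy as np

# Illustrative floor; the suggested fix uses self._min_vi_sigma.
MIN_VI_SIGMA = 1e-6

def make_clamped_init_scale_fn(site_sigmas):
    # Clamp every stored sigma so deterministic priors (sigma == 0) cannot
    # seed the variational Normal with a zero standard deviation.
    clamped = {k: np.maximum(v, MIN_VI_SIGMA) for k, v in site_sigmas.items()}
    all_sigmas = np.concatenate([np.atleast_1d(v) for v in clamped.values()])
    avg_sigma = float(np.mean(all_sigmas))
    # Unknown site names (including "") fall back to the clamped average.
    return lambda name: clamped.get(name, avg_sigma)
```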
### Changes:
- Add ParameterizedScaleAutoNormal, a subclass of AutoNormal that accepts init_scale_fn(site_name) -> array, enabling per-parameter scale seeding instead of a single global mean.
- Refactor _build_svi_guide_init to return (init_loc_fn, init_scale_fn), with the per-site sigma dict captured in a closure; fullrank_advi retrieves the avg-sigma fallback via init_scale_fn("").
- Wire ParameterizedScaleAutoNormal into _run_svi_training_loop for advi; fullrank_advi continues to use AutoMultivariateNormal with a scalar average sigma.
- Add test_advi_extracted_params_match_predictive_moments, which verifies that the extracted (mu, sigma) match Predictive sample moments within N-sigma tolerances derived from sigma_init and n_predictive_samples.
- Add test_advi_zero_lr_posterior_equals_prior, which verifies that lr=0 SGD leaves the posterior identical to the prior; uses mu_init > 0 to guarantee init_to_value (avoiding the stochastic init_to_median cold-start path).
- Remove the single-commit PR check; it will be enforced via "Squash & Merge" from now on.
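The second bullet's closure pattern can be sketched as follows. This is an assumed reconstruction for illustration, not the actual _build_svi_guide_init: the site_mus/site_sigmas inputs are hypothetical, and in numpyro an init_loc_fn receives a site dict rather than a bare name, which is why the two callbacks have different signatures.

```python
import numpy as np

# Hypothetical reconstruction of the (init_loc_fn, init_scale_fn) pair the
# refactor is described as returning; names are illustrative.
def build_guide_init(site_mus, site_sigmas):
    all_sigmas = np.concatenate([np.atleast_1d(s) for s in site_sigmas.values()])
    avg_sigma = float(np.mean(all_sigmas))

    def init_loc_fn(site):
        # numpyro-style: receives a site dict, looks up its name.
        return site_mus.get(site["name"], 0.0)

    def init_scale_fn(name):
        # An unknown name -- the PR calls init_scale_fn("") -- yields the
        # average-sigma fallback used by the full-rank guide.
        return site_sigmas.get(name, avg_sigma)

    return init_loc_fn, init_scale_fn
```

Capturing the per-site sigma dict in the closure keeps the guide construction decoupled from how the priors were stored, while init_scale_fn("") gives fullrank_advi a single scalar without a second code path.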