Skip to content

Add ParameterizedScaleAutoNormal ADVI guide with per-site scale init#139

Open
shaharbar1 wants to merge 1 commit intodevelopfrom
feature/improved_guide
Open

Add ParameterizedScaleAutoNormal ADVI guide with per-site scale init#139
shaharbar1 wants to merge 1 commit intodevelopfrom
feature/improved_guide

Conversation

@shaharbar1
Copy link
Copy Markdown
Collaborator

@shaharbar1 shaharbar1 commented May 7, 2026

Changes:

  • Add ParameterizedScaleAutoNormal — subclass of AutoNormal that accepts init_scale_fn(site_name) -> array, enabling per-parameter scale seeding instead of a single global mean
  • Refactor _build_svi_guide_init to return (init_loc_fn, init_scale_fn) with per-site sigma dict captured in closure; fullrank_advi retrieves the avg-sigma fallback via init_scale_fn("")
  • Wire ParameterizedScaleAutoNormal into _run_svi_training_loop for advi; fullrank_advi continues using AutoMultivariateNormal with scalar avg sigma
  • Add test_advi_extracted_params_match_predictive_moments — verifies extracted (mu, sigma) match Predictive sample moments within N-sigma tolerances derived from sigma_init and n_predictive_samples
  • Add test_advi_zero_lr_posterior_equals_prior — verifies that lr=0 SGD leaves posterior identical to prior; uses mu_init > 0 to guarantee init_to_value (avoiding the stochastic init_to_median cold-start path)
  • Removed single commit PR check. Shall be enforced via "Squash & Merge" from now on.

Summary by CodeRabbit

  • Chores

    • Package version bumped to 6.1.0
    • CI style check updated to validate conventional-commit labels; single-commit enforcement removed
    • Removed an obsolete rule entry
  • Refactor

    • Variational guide initialization enhanced to support per-parameter scale initialization and improved numeric handling
  • Tests

    • New JAX/NumPyro-backed tests for ADVI guide parameter consistency, NaN-loss handling, and VI stability

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 7, 2026

Review Change Stack
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cbd61bee-caf7-46a2-b4ac-c6ef89b354ad

📥 Commits

Reviewing files that changed from the base of the PR and between 5f01aa7 and f9ea7c3.

📒 Files selected for processing (5)
  • .cursor/rules/test-cost-co.mdc
  • .github/workflows/pull_request_style_check.yml
  • pybandits/model.py
  • pyproject.toml
  • tests/test_model.py
💤 Files with no reviewable changes (2)
  • .github/workflows/pull_request_style_check.yml
  • .cursor/rules/test-cost-co.mdc
✅ Files skipped from review due to trivial changes (1)
  • pyproject.toml
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/test_model.py
  • pybandits/model.py

📝 Walkthrough

Walkthrough

Adds ParameterizedScaleAutoNormal with per-site init_scale_fn, refactors SVI init to return init_loc_fn and init_scale_fn, wires the new guide into the SVI training loop for "advi", adds Hypothesis tests validating ADVI parameter extraction and zero-LR behavior, and includes minor CI/version/rule edits.

Changes

Per-Site Scale ADVI Guide

Layer / File(s) Summary
Dependencies and Imports
pybandits/model.py, tests/test_model.py
Added context/NumPyro transform/constraint utilities to model imports; expanded test imports to include jax, jax.numpy, numpy, numpyro, Hypothesis, and Predictive.
New Guide Class
pybandits/model.py
Introduces ParameterizedScaleAutoNormal extending AutoNormal, accepting init_loc_fn and init_scale_fn(name) to create per-site loc/scale parameters; updates sigma annotation/doc to require positive scales.
SVI Guide Initialization
pybandits/model.py
_build_svi_guide_init refactored to return init_loc_fn and a per-site init_scale_fn(name) built from prior per-site sigmas with an average fallback.
Guide Configuration and Training Integration
pybandits/model.py
VI mapping updated so "advi" uses ParameterizedScaleAutoNormal; _run_svi_training_loop constructs the guide with per-site init functions for "advi" and supplies a scalar fallback for other ADVI variants; MCMC extraction clamps std to float32 tiny.
ADVI Parameter and Posterior Validation Tests
tests/test_model.py
Adds tests: test_advi_extracted_params_match_predictive_moments and test_advi_zero_lr_posterior_equals_prior, and moves NaN-loss injection to epoch-loss computation.
Maintenance
pyproject.toml, .github/workflows/pull_request_style_check.yml, .cursor/rules/test-cost-co.mdc
Bumped version to 6.1.0; replaced single-commit check with check_cc_labels step configured with hasSome and githubToken; removed a Cursor rule entry and its alwaysApply flag.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • AnastasiiaKabeshova
  • fsimond

"🐰
I tuned each mu and sigma tight,
Per-site scales hopping into sight,
Tests nibble moments, CI hums on cue,
Version bumped, small rules bid adieu."

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately and concisely summarizes the primary change: adding a ParameterizedScaleAutoNormal guide class with per-site scale initialization capability.
Description check ✅ Passed The description comprehensively covers all required template sections with specific details about changes, tests added, and workflow modifications without being vague.
Docstring Coverage ✅ Passed Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feature/improved_guide

Tip

💬 Introducing Slack Agent: The best way for teams to turn conversations into code.

Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.

  • Generate code and open pull requests
  • Plan features and break down work
  • Investigate incidents and troubleshoot customer tickets together
  • Automate recurring tasks and respond to alerts with triggers
  • Summarize progress and report instantly

Built for teams:

  • Shared memory across your entire org—no repeating context
  • Per-thread sandboxes to safely plan and execute work
  • Governance built-in—scoped access, auditability, and budget controls

One agent for your entire SDLC. Right inside Slack.

👉 Get started


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@pybandits/model.py`:
- Around line 2002-2004: The init_scale_fn currently returns exact sigmas (from
site_sigmas or avg_sigma) which can be zero and cause a zero-std initialization
for ParameterizedScaleAutoNormal / the full-rank scalar fallback; change
init_scale_fn to clamp returned scales to a small positive floor (e.g. epsilon =
1e-6 or 1e-3) so both avg_sigma and site_sigmas values are max(value, epsilon)
before being returned; update references to avg_sigma, site_sigmas, and the
lambda init_scale_fn used by BaseLocationScaleArray /
ParameterizedScaleAutoNormal to use this clamped value.

In `@pyproject.toml`:
- Line 3: The version string in pyproject.toml was bumped as a patch (6.0.9) but
this change adds features and must be a MINOR bump; update the version value
from "6.0.9" to "6.1.0" in pyproject.toml so the project semantic version
reflects added functionality (e.g., after introducing
ParameterizedScaleAutoNormal and refactors).

In `@tests/test_model.py`:
- Around line 1774-1782: The mean-tolerance uses sigma_init but should use the
learned posterior scale; change the mean absolute tolerance calculation so
atol_mu is computed from site_sigma[name] (e.g. atol_mu = n_sigma_tolerance *
site_sigma[name] / np.sqrt(n_predictive_samples)) and use that per-site inside
the loop before asserting np.testing.assert_allclose(site_mu[name],
np.mean(draw, axis=0), atol=atol_mu); keep the rtol_sigma calculation as-is and
reference symbols: atol_mu, site_sigma[name], n_sigma_tolerance,
n_predictive_samples, samples, site_mu.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 99a354a2-903b-42fa-9617-f2504fc8f224

📥 Commits

Reviewing files that changed from the base of the PR and between dc5f3f3 and be1d2d2.

📒 Files selected for processing (5)
  • .cursor/rules/test-cost-co.mdc
  • .github/workflows/pull_request_style_check.yml
  • pybandits/model.py
  • pyproject.toml
  • tests/test_model.py
💤 Files with no reviewable changes (2)
  • .cursor/rules/test-cost-co.mdc
  • .github/workflows/pull_request_style_check.yml

Comment thread pybandits/model.py
Comment on lines +2002 to +2004
avg_sigma = float(np.mean(np.concatenate(all_sigmas)))
init_scale_fn = lambda name: site_sigmas.get(name, avg_sigma) # noqa: E731
return init_loc_fn, init_scale_fn
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Clamp guide init scales away from zero.

BaseLocationScaleArray accepts sigma=0, and this closure now forwards those exact values into both ParameterizedScaleAutoNormal and the full-rank scalar fallback. That makes otherwise valid deterministic priors initialize the variational Normal with a zero std, which can blow up SVI before the first update.

Suggested fix
-        avg_sigma = float(np.mean(np.concatenate(all_sigmas)))
-        init_scale_fn = lambda name: site_sigmas.get(name, avg_sigma)  # noqa: E731
+        min_sigma = self._min_vi_sigma
+        avg_sigma = float(np.mean(np.concatenate([np.maximum(s, min_sigma) for s in all_sigmas])))
+        init_scale_fn = lambda name: np.maximum(site_sigmas.get(name, avg_sigma), min_sigma)  # noqa: E731
         return init_loc_fn, init_scale_fn
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@pybandits/model.py` around lines 2002 - 2004, The init_scale_fn currently
returns exact sigmas (from site_sigmas or avg_sigma) which can be zero and cause
a zero-std initialization for ParameterizedScaleAutoNormal / the full-rank
scalar fallback; change init_scale_fn to clamp returned scales to a small
positive floor (e.g. epsilon = 1e-6 or 1e-3) so both avg_sigma and site_sigmas
values are max(value, epsilon) before being returned; update references to
avg_sigma, site_sigmas, and the lambda init_scale_fn used by
BaseLocationScaleArray / ParameterizedScaleAutoNormal to use this clamped value.

Comment thread pyproject.toml Outdated
Comment thread tests/test_model.py Outdated
@shaharbar1 shaharbar1 force-pushed the feature/improved_guide branch 2 times, most recently from 8ac6a08 to 4633000 Compare May 7, 2026 06:14
@shaharbar1 shaharbar1 added enhancement New feature or request labels May 7, 2026
@shaharbar1 shaharbar1 force-pushed the feature/improved_guide branch from 4633000 to 5f01aa7 Compare May 7, 2026 08:25
### Changes:
- Add ParameterizedScaleAutoNormal — subclass of AutoNormal that accepts init_scale_fn(site_name) -> array, enabling per-parameter scale seeding instead
of a single global mean
- Refactor _build_svi_guide_init to return (init_loc_fn, init_scale_fn) with per-site sigma dict captured in closure; fullrank_advi retrieves the
avg-sigma fallback via init_scale_fn("")
- Wire ParameterizedScaleAutoNormal into _run_svi_training_loop for advi; fullrank_advi continues using AutoMultivariateNormal with scalar avg sigma
- Add test_advi_extracted_params_match_predictive_moments — verifies extracted (mu, sigma) match Predictive sample moments within N-sigma tolerances derived from sigma_init and n_predictive_samples
- Add test_advi_zero_lr_posterior_equals_prior — verifies that lr=0 SGD leaves posterior identical to prior; uses mu_init > 0 to guarantee init_to_value
  (avoiding the stochastic init_to_median cold-start path)
- Removed single commit PR check. Shall be enforced via "Squash & Merge" from now on.
@shaharbar1 shaharbar1 force-pushed the feature/improved_guide branch from 5f01aa7 to f9ea7c3 Compare May 7, 2026 09:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant