
Add VI training options: num_particles, gradient clipping, and KL temperature scaling#132

Open
shaharbar1 wants to merge 1 commit into develop from feature/vi_improvements

Conversation

@shaharbar1 (Collaborator) commented Apr 5, 2026

Changes:

  • Add num_particles, gradient_clip_norm, and kl_tau to VI update_kwargs
    • num_particles is passed to TraceMeanField_ELBO/Trace_ELBO for multi-particle gradient estimates
    • gradient_clip_norm chains optax.clip_by_global_norm before the base optimizer
    • kl_tau scales prior log-prob by tau * N_data / N_neurons via numpyro.handlers.scale (Huix et al. 2022)
  • Switch all optimizers from numpyro-native to optax via optax_to_numpyro, removing the dual-path logic
  • Remove clipped_adam, momentum, and reduce_on_plateau scheduler
  • Remove batch_size parameter from create_update_model; read from _update_kwargs instead
  • Add optax to pyproject.toml dependencies
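Taken together, these options flow in through `update_kwargs`. As a rough, stdlib-only illustration of the `kl_tau` arithmetic named above (prior log-prob scaled by `tau * N_data / N_neurons`); the helper name `kl_scale` is hypothetical, and in the real code the factor is applied via `numpyro.handlers.scale`:

```python
# Hypothetical helper illustrating the kl_tau factor from the
# description: the prior log-prob is multiplied by
# tau * N_data / N_neurons (Huix et al. 2022).
def kl_scale(kl_tau: float, n_data: int, n_neurons: int) -> float:
    """Scale factor applied to the prior log-probability during VI."""
    return kl_tau * n_data / max(n_neurons, 1)

# e.g. tau = 0.5 with 1000 training rows and 20 hidden neurons
print(kl_scale(0.5, 1000, 20))  # → 25.0
```

Note the `max(..., 1)` guard on the neuron count; a later review comment in this thread discusses its effect for single-layer models.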

Summary by CodeRabbit

  • New Features

    • Optax optimizer support with selectable LR schedulers, optional gradient clipping, KL-temperature scaling, and VI options like num_particles.
  • Breaking Changes

    • Batch size is no longer passed directly to the model-update call — configure batch size via update options.
  • Tests

    • Added/updated tests for LR schedulers, VI training options, and categorical-feature updates.
  • Chores

    • Bumped package version to 6.1.0 and added Optax as a runtime dependency.

@shaharbar1 requested a review from ronshiff1 on April 5, 2026 17:14
@coderabbitai (bot) commented Apr 5, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Replaced NumPyro optimizer construction with Optax (including optional LR schedulers and gradient clipping), refactored update-model construction (create_update_model → _create_update_model) and added _kl_scale_ctx; adjusted VI/MCMC flows and tests; bumped the package version and added the optax dependency.

Changes

Changes by cohort and file:

  • Optimizer & training plumbing (pybandits/model.py): Replaced the NumPyro optimizer mapping with Optax-based builders; added _supported_lr_schedulers, _optax_return_types, and _resolve_optax_fn(...). _get_obj_optimizer() now constructs optax chains, supports LR schedulers and gradient clipping, and converts via noptim.optax_to_numpyro. The VI loss now accepts num_particles.
  • Model structure & KL scaling (pybandits/model.py): Removed the public create_update_model(...); added _create_update_model() and _kl_scale_ctx(n_samples), which returns a scaling context or nullcontext(); applied KL scaling to sampling sites and guide wrapping during VI when kl_tau is set.
  • SVI/MCMC control flow (pybandits/model.py): Refactored _run_svi_training_loop() to always use _create_update_model(), compute effective_batch_size, pass num_particles into the VI loss factory, compute epoch_end_loss once per epoch for observability/early stopping, and return the raw/unwrapped guide. Updated MCMC model construction to call _create_update_model().
  • Tests (tests/test_model.py): Updated tests to use _create_update_model() and the helper _make_random_rewards(...); refactored categorical BNN helpers to accept update_kwargs; added tests for Optax kind/type consistency, valid/invalid LR schedulers, and property-based VI training (gradient clipping, kl_tau, schedulers). Adjusted Hypothesis strategies and removed obsolete decorators.
  • Packaging (pyproject.toml): Bumped the package version to 6.1.0 and added the runtime dependency optax = "^0.1".
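The `_resolve_optax_fn(...)` mentioned above is described later in this review as a getattr-based lookup on the optax module. A stdlib-only sketch of that pattern, using a stand-in namespace rather than the real optax:

```python
from types import SimpleNamespace

# Stand-in for the optax module; the real code would do
# getattr(optax, name, None) instead of looking up this namespace.
fake_optax = SimpleNamespace(
    adam=lambda learning_rate: ("adam", learning_rate),
    sgd=lambda learning_rate: ("sgd", learning_rate),
)

def resolve_optax_fn(namespace, name):
    """Resolve an optimizer/scheduler factory by name, failing early."""
    fn = getattr(namespace, name, None)
    if fn is None:
        raise ValueError(f"Unsupported optax function: {name!r}")
    return fn

print(resolve_optax_fn(fake_optax, "adam")(learning_rate=0.01))  # → ('adam', 0.01)
```

The upside of this design is that any optax factory is reachable by name without maintaining a mapping; the downside, flagged in a later nitpick, is that static "supported" dictionaries can drift into dead code.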

Sequence Diagram(s)

sequenceDiagram
    participant BNN as BayesianNN
    participant Optax as Optax
    participant Scheduler as LR_Scheduler
    participant Adapter as NumPyroAdapter
    participant SVI as SVI_Loop

    BNN->>Optax: build base optimizer (type, step_size)
    Optax-->>BNN: optax optimizer
    alt lr_scheduler provided
        BNN->>Scheduler: create lr schedule (type, kwargs)
        Scheduler-->>BNN: schedule fn / learning_rate
        BNN->>Optax: combine schedule with optimizer
    end
    alt gradient_clip_norm provided
        BNN->>Optax: chain clip_by_global_norm
        Optax-->>BNN: clipped optimizer chain
    end
    BNN->>Adapter: noptim.optax_to_numpyro(optax_chain)
    Adapter-->>BNN: numpyro-compatible optimizer
    BNN->>SVI: start training (num_particles, kl_tau, effective_batch_size)
    SVI->>BNN: call _create_update_model() and use _kl_scale_ctx(n_samples)
    SVI->>SVI: perform VI updates with converted optimizer

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 I braided Optax threads through my burrowed plot,
schedulers hum softly, gradients tidy the lot,
KL warms the night where particles play,
guides stretch their ears and find the way,
epochs nibble carrots, models hop to day.

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 64.52%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: The title accurately summarizes the main changes: adding the VI training options (num_particles, gradient clipping, KL temperature scaling) that are the primary features of this PR.
  • Description check ✅ Passed: The description follows the template structure, with a Changes section detailing all significant modifications, including the new VI parameters, the optimizer switch, removed features, and dependency updates.



@shaharbar1 added the enhancement label on Apr 5, 2026
@coderabbitai (bot) left a comment


Actionable comments posted: 4

🧹 Nitpick comments (1)
tests/test_model.py (1)

782-800: Exercise the optimizer in SVI, not just construction.

This only proves cold_start() can assemble the transform. It won't catch failures once the scheduler/optimizer is handed to optax_to_numpyro and SVI.update(), which is the risky part of this change—especially for the new adan path.

Minimal smoke-path extension
     bnn = BayesianNeuralNetwork.cold_start(
         n_features=n_features,
         update_method="VI",
         update_kwargs={
+            "num_steps": 1,
             "optimizer_type": optimizer_type,
             "optimizer_kwargs": optimizer_kwargs,
             "lr_scheduler_type": lr_scheduler_type,
             "lr_scheduler_kwargs": lr_scheduler_kwargs,
         },
     )
     assert bnn._obj_optimizer is not None
+    context = np.random.rand(4, n_features).astype(np.float32)
+    rewards = _make_random_rewards(4)
+    bnn.update(context=context, rewards=rewards)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_model.py` around lines 782 - 800, The test test_lr_scheduler_valid
only verifies BayesianNeuralNetwork.cold_start assembles an optimizer/scheduler
but doesn't exercise the optimizer through optax_to_numpyro and SVI.update, so
failures in the optimizer integration (e.g., the new "adan" path) are missed;
update the test to run a minimal SVI training step: build a tiny model/data from
BayesianNeuralNetwork.cold_start, convert the optax optimizer via
optax_to_numpyro (or use the existing _obj_optimizer), create an SVI instance,
call SVI.update (or run one step of svi.run) with a small dummy batch, and
assert the update returns a scalar loss and that bnn._obj_optimizer/state
changes—this will surface integration errors during update rather than only
construction.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pybandits/model.py`:
- Around line 1256-1266: The current exception handling in the lr scheduler and
optimizer construction swallows the original traceback by re-raising a new
exception class without chaining; update the two except blocks around
self._supported_lr_schedulers[lr_scheduler_type](...) and
self._supported_optimizers[optimizer_type](...) to re-raise the new exception
with the original exception as the cause (use "raise ... from e") so the
original Optax/underlying error context for invalid lr_scheduler_kwargs or
optimizer_kwargs is preserved; keep the same message text but append "from e"
when raising.

In `@pyproject.toml`:
- Line 3: The release change is breaking: you removed accepted
optimizer/scheduler values and changed the public signature of
create_update_model(), so either restore backward compatibility for one release
(re-add removed optimizer/scheduler aliases and keep the old
create_update_model(...) signature forwarding to the new implementation while
emitting a deprecation warning) or bump the package version in pyproject.toml to
a non-patch semver (e.g., 7.0.0) and update the changelog; locate the changes
around create_update_model and the optimizer/scheduler validation code to
implement the alias/deprecation forwarding, or update the version string in
pyproject.toml and note the breaking change in the release notes.
- Line 44: The pyproject.toml dependency for Optax is too low and allows 0.1.x
where optax.adan doesn't exist; update the optax requirement from "optax =
\"^0.1\"" to at least "optax = \"^0.2.6\"" (or a newer 0.2+ range) so imports of
optax.adan used in pybandits/model.py (see _supported_optimizers referencing
optax.adan) will succeed.

In `@tests/test_model.py`:
- Around line 841-880: Extend test_vi_training_options to assert the
sample_proba contract: after calling result = bnn.sample_proba(context=context),
handle both return shapes—if result is a tuple unpack it as probs, logits =
result else set probs = result and logits = None—and then assert probs.shape[0]
== n_samples, all(probs >= 0) and all(probs <= 1) (use numpy comparisons), and
if logits is not None assert np.isfinite(logits).all() so VI options cannot
produce out-of-range probabilities or non-finite logits; reference
test_vi_training_options, BayesianNeuralNetwork.sample_proba, and the result
variable when adding these assertions.
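The first inline comment above asks for exception chaining. A minimal, self-contained illustration of the `raise ... from e` pattern it requests (the function here is hypothetical, not the pybandits code):

```python
def build_scheduler(kwargs):
    # Re-raise with "from e" so the original traceback is kept
    # as __cause__ instead of being swallowed.
    try:
        return {"transition_steps": kwargs["transition_steps"]}
    except KeyError as e:
        raise ValueError("Invalid lr_scheduler_kwargs") from e

try:
    build_scheduler({})
except ValueError as err:
    assert isinstance(err.__cause__, KeyError)  # original error preserved
```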


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 415dd10b-87a8-44a1-9cc5-899b6bac5d19

📥 Commits

Reviewing files that changed from the base of the PR and between 25d502f and 21460be.

📒 Files selected for processing (3)
  • pybandits/model.py
  • pyproject.toml
  • tests/test_model.py

@shaharbar1 force-pushed the feature/vi_improvements branch from 21460be to 20379d7 on April 5, 2026 20:09
@coderabbitai (bot) left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
pybandits/model.py (1)

1529-1532: Edge case: n_neurons calculation for single-layer networks.

For single-layer networks (e.g., BayesianLogisticRegression), bnn_layer_params[:-1] is empty, resulting in n_neurons = 1 (from the max(..., 1) guard). This means the KL scale factor becomes kl_tau * n_samples / 1 = kl_tau * n_samples, which may not be the intended behavior for logistic regression models.

Consider documenting this behavior or adding a check to warn/skip KL scaling for single-layer models where the concept of "hidden neurons" doesn't apply.
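A stdlib-only sketch of the guard this nitpick proposes; `bnn_layer_params` is a stand-in list of per-layer sizes, and how `n_neurons` is derived here is illustrative rather than the actual pybandits implementation:

```python
import warnings

def kl_scale_or_none(kl_tau, n_samples, bnn_layer_params):
    """Return the KL scale factor, or None for single-layer models."""
    if len(bnn_layer_params) <= 1:
        # Skip scaling where "hidden neurons" is not a meaningful concept.
        warnings.warn("KL scaling skipped for single-layer models")
        return None
    n_neurons = max(sum(bnn_layer_params[:-1]), 1)
    return kl_tau * n_samples / n_neurons

# Two hidden layers (8 and 4 neurons) plus output: scale uses 12 neurons.
assert kl_scale_or_none(0.1, 100, [8, 4, 1]) == 0.1 * 100 / 12
```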

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pybandits/model.py` around lines 1529 - 1532, The current KL scaling uses
n_neurons computed from self.model_params.bnn_layer_params[:-1], which is empty
for single-layer models (e.g., BayesianLogisticRegression), causing unintended
large scaling; add a guard that checks the number of layers (e.g., if
len(self.model_params.bnn_layer_params) <= 1) and in that case skip KL scaling
(return None / do not return a numpyro.handlers.scale) and emit a warning (use
warnings.warn or self.logger.warning) explaining that KL scaling is skipped for
single-layer models; keep the existing behavior for multi-layer networks (retain
kl_tau, n_samples calculation and return numpyro.handlers.scale) and reference
kl_tau, n_neurons, and model_params.bnn_layer_params in the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pybandits/model.py`:
- Around line 1097-1103: The parameter "restore_best_weights" is declared in
_vi_update_params but lacks a default/implementation, which can cause KeyError
when accessed (e.g., via self._update_kwargs.get("restore_best_weights"));
either add a sensible default entry for "restore_best_weights" in
_default_vi_kwargs (e.g., False or None) and document its behavior, or remove
"restore_best_weights" from the _vi_update_params list if the feature is not
implemented—update any code that reads self._update_kwargs to reflect the chosen
approach.
- Around line 1111-1116: The LR scheduler mapping (_supported_lr_schedulers)
currently only supplies init_value=learning_rate when the scheduler is
instantiated, but each scheduler (optax.exponential_decay,
optax.cosine_decay_schedule, optax.linear_schedule,
optax.warmup_cosine_decay_schedule) requires additional mandatory kwargs; add
explicit validation where the scheduler is constructed (the code that passes
init_value=learning_rate) to check for the required keys per scheduler type
(exponential_decay: transition_steps, decay_rate; cosine_decay_schedule:
decay_steps; linear_schedule: end_value, transition_steps;
warmup_cosine_decay_schedule: peak_value, warmup_steps, decay_steps) and raise a
clear TypeError/ValueError if any are missing (or provide sensible defaults), so
missing lr_scheduler_kwargs fail early with a helpful message.
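A stdlib-only sketch of the early validation this comment suggests; the required-key table below mirrors the Optax schedule signatures enumerated above, and the function name is hypothetical:

```python
# Required kwargs per scheduler type, as enumerated in the review comment.
REQUIRED_SCHEDULER_KWARGS = {
    "exponential_decay": {"transition_steps", "decay_rate"},
    "cosine_decay_schedule": {"decay_steps"},
    "linear_schedule": {"end_value", "transition_steps"},
    "warmup_cosine_decay_schedule": {"peak_value", "warmup_steps", "decay_steps"},
}

def validate_scheduler_kwargs(scheduler_type, lr_scheduler_kwargs):
    """Fail early with a helpful message if mandatory kwargs are missing."""
    missing = REQUIRED_SCHEDULER_KWARGS[scheduler_type] - set(lr_scheduler_kwargs)
    if missing:
        raise ValueError(
            f"lr_scheduler_kwargs for {scheduler_type!r} is missing: {sorted(missing)}"
        )

validate_scheduler_kwargs("cosine_decay_schedule", {"decay_steps": 100})  # ok
```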


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b1bf4df9-9e47-4111-a481-3d8134d3350d

📥 Commits

Reviewing files that changed from the base of the PR and between 21460be and 20379d7.

📒 Files selected for processing (3)
  • pybandits/model.py
  • pyproject.toml
  • tests/test_model.py
✅ Files skipped from review due to trivial changes (2)
  • pyproject.toml
  • tests/test_model.py

Comment thread pybandits/model.py
Comment on lines +1097 to 1103
"lr_scheduler_type",
"lr_scheduler_kwargs",
"restore_best_weights",
"num_particles",
"gradient_clip_norm",
"kl_tau",
]

⚠️ Potential issue | 🟡 Minor

restore_best_weights is declared in _vi_update_params but has no default value or implementation.

The parameter restore_best_weights is listed in _vi_update_params (line 1099) but is not present in _default_vi_kwargs (lines 1139-1151). This could cause KeyError if code later attempts to access it via self._update_kwargs.get("restore_best_weights") without a default, or it may indicate an incomplete feature.

Either add a default value in _default_vi_kwargs, or remove it from _vi_update_params if it's not yet implemented.

Suggested fix
     _default_vi_kwargs: ClassVar[dict] = dict(
         num_steps=1000,
         method="advi",
         optimizer_type="sgd",
         optimizer_kwargs={"step_size": 0.01},
         batch_size=None,
         early_stopping_kwargs=None,
         lr_scheduler_type=None,
         lr_scheduler_kwargs=None,
         num_particles=1,
         gradient_clip_norm=None,
         kl_tau=None,
+        restore_best_weights=False,
     )
@shaharbar1 force-pushed the feature/vi_improvements branch from 20379d7 to 55b230c on April 5, 2026 20:14
@coderabbitai (bot) left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/test_model.py`:
- Around line 879-882: The assertions iterate over result from
bnn.sample_proba(context) but sample_proba returns List[Tuple[float, float]]
(probability, weighted_sum), so the tests compare tuples to scalars; fix by
unpacking each tuple (e.g., for p, w in result) and assert len(result) ==
n_samples, assert 0 <= p <= 1 for each p, and assert np.isfinite(p) and
np.isfinite(w) for both elements (or at least ensure p is finite), using the
existing symbols result, n_samples and sample_proba to locate and update the
test.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 021df0ee-364e-4092-b09a-a065022e8c8d

📥 Commits

Reviewing files that changed from the base of the PR and between 20379d7 and 55b230c.

📒 Files selected for processing (3)
  • pybandits/model.py
  • pyproject.toml
  • tests/test_model.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • pyproject.toml

@shaharbar1 force-pushed the feature/vi_improvements branch from 55b230c to 9274343 on April 6, 2026 06:56
@coderabbitai (bot) left a comment


♻️ Duplicate comments (2)
pybandits/model.py (1)

1099-1105: ⚠️ Potential issue | 🟡 Minor

restore_best_weights is declared in _vi_update_params but has no default value in _default_vi_kwargs.

The parameter restore_best_weights is listed in _vi_update_params (line 1101) but is missing from _default_vi_kwargs (lines 1184-1196). This inconsistency could cause issues if code attempts to access this key without a default, or indicates an incomplete feature.

Either add a default value in _default_vi_kwargs or remove it from _vi_update_params if not yet implemented.

Suggested fix
     _default_vi_kwargs: ClassVar[dict] = dict(
         num_steps=1000,
         method="advi",
         optimizer_type="sgd",
         optimizer_kwargs={"step_size": 0.01},
         batch_size=None,
         early_stopping_kwargs=None,
         lr_scheduler_type=None,
         lr_scheduler_kwargs=None,
         num_particles=1,
         gradient_clip_norm=None,
         kl_tau=None,
+        restore_best_weights=False,
     )
tests/test_model.py (1)

891-894: ⚠️ Potential issue | 🔴 Critical

Fix incorrect tuple unpacking in assertions — sample_proba returns List[Tuple[float, float]].

The assertions at lines 893-894 iterate over result directly, but sample_proba returns List[Tuple[probability, weighted_sum]]. The current code compares tuples to scalars, which will raise TypeError: '<=' not supported between instances of 'int' and 'tuple'.

Proposed fix
     result = bnn.sample_proba(context=context)
     assert len(result) == n_samples
-    assert all(0 <= p <= 1 for p in result), f"Probabilities out of [0,1]: {result}"
-    assert all(np.isfinite(p) for p in result), f"Non-finite probabilities: {result}"
+    probs, logits = zip(*result)
+    assert all(0.0 <= p <= 1.0 for p in probs), f"Probabilities out of [0,1]: {probs}"
+    assert all(np.isfinite(logit) for logit in logits), f"Non-finite logits: {logits}"

As per coding guidelines, **/*test*.py: Always apply callable cost and probability logic testing.

🧹 Nitpick comments (1)
pybandits/model.py (1)

1108-1112: _supported_optimizers and _supported_lr_schedulers class variables are dead code and should be removed.

Both _supported_optimizers (line 1108) and _supported_lr_schedulers (line 1113) are defined but never referenced in the codebase. The _resolve_optax_fn method uses getattr(optax, name, None) for dynamic validation instead of consulting these dictionaries. Remove these unused class variables to avoid misleading future maintainers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pybandits/model.py` around lines 1108 - 1112, Remove the unused class
variables _supported_optimizers and _supported_lr_schedulers from the class in
pybandits/model.py: locate the dictionary definitions named
_supported_optimizers and _supported_lr_schedulers and delete them, leaving the
dynamic resolution logic in _resolve_optax_fn intact (it uses getattr(optax,
name, None)); also search for any remaining references to these variables and
remove or update them if found, then run tests/linters to ensure no stray
references remain.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 480dd6ed-8e0b-40a2-bbad-49b09817b4ee

📥 Commits

Reviewing files that changed from the base of the PR and between 55b230c and 9274343.

📒 Files selected for processing (3)
  • pybandits/model.py
  • pyproject.toml
  • tests/test_model.py
✅ Files skipped from review due to trivial changes (1)
  • pyproject.toml

@shaharbar1 force-pushed the feature/vi_improvements branch 4 times, most recently from 79095e4 to 8e8d065 on April 6, 2026 09:17
Add VI training options: num_particles, gradient clipping, and KL temperature scaling

### Changes:
* Add `num_particles`, `gradient_clip_norm`, and `kl_tau` to VI `update_kwargs`
  - `num_particles` is passed to `TraceMeanField_ELBO`/`Trace_ELBO` for multi-particle gradient estimates
  - `gradient_clip_norm` chains `optax.clip_by_global_norm` before the base optimizer
  - `kl_tau` scales prior log-prob by `tau * N_data / N_neurons` via `numpyro.handlers.scale` (Huix et al. 2022)
* Switch all optimizers from numpyro-native to optax via `optax_to_numpyro`, removing the dual-path logic
* Remove `clipped_adam`, `momentum`, and `reduce_on_plateau` scheduler
* Remove `batch_size` parameter from `create_update_model`; read from `_update_kwargs` instead
* Add `optax` to `pyproject.toml` dependencies
@shaharbar1 force-pushed the feature/vi_improvements branch from 8e8d065 to 4cf1589 on April 6, 2026 11:22

Labels

enhancement New feature or request
