Add VI training options: num_particles, gradient clipping, and KL temperature scaling #132
shaharbar1 wants to merge 1 commit into develop from …
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings.
📝 Walkthrough
Replaced NumPyro optimizer construction with Optax (including optional LR schedulers and gradient clipping), refactored update-model construction (`create_update_model` now reads `batch_size` from `_update_kwargs` instead of taking it as a parameter), and added VI training options (`num_particles`, `gradient_clip_norm`, `kl_tau`).
Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant BNN as BayesianNN
    participant Optax as Optax
    participant Scheduler as LR_Scheduler
    participant Adapter as NumPyroAdapter
    participant SVI as SVI_Loop
    BNN->>Optax: build base optimizer (type, step_size)
    Optax-->>BNN: optax optimizer
    alt lr_scheduler provided
        BNN->>Scheduler: create lr schedule (type, kwargs)
        Scheduler-->>BNN: schedule fn / learning_rate
        BNN->>Optax: combine schedule with optimizer
    end
    alt gradient_clip_norm provided
        BNN->>Optax: chain clip_by_global_norm
        Optax-->>BNN: clipped optimizer chain
    end
    BNN->>Adapter: optim.optax_to_numpyro(optax_chain)
    Adapter-->>BNN: numpyro-compatible optimizer
    BNN->>SVI: start training (num_particles, kl_tau, effective_batch_size)
    SVI->>BNN: call _create_update_model() and use _kl_scale_ctx(n_samples)
    SVI->>SVI: perform VI updates with converted optimizer
```
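The optimizer assembly in the diagram can be sketched directly in Optax. This is a minimal illustration assuming an Adam base optimizer and an exponential-decay schedule, not the package's exact code:

```python
import optax
from numpyro.optim import optax_to_numpyro

# Optional LR schedule in place of a constant step size.
schedule = optax.exponential_decay(init_value=0.01, transition_steps=100, decay_rate=0.9)
base = optax.adam(learning_rate=schedule)

# If gradient_clip_norm is provided, clip gradients before the update step.
chain = optax.chain(optax.clip_by_global_norm(1.0), base)

# Convert the Optax chain into a NumPyro-compatible optimizer for SVI.
optimizer = optax_to_numpyro(chain)
```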
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
🧹 Nitpick comments (1)
tests/test_model.py (1)
782-800: Exercise the optimizer in SVI, not just construction. This only proves `cold_start()` can assemble the transform. It won't catch failures once the scheduler/optimizer is handed to `optax_to_numpyro` and `SVI.update()`, which is the risky part of this change—especially for the new `adan` path.

Minimal smoke-path extension:

```diff
 bnn = BayesianNeuralNetwork.cold_start(
     n_features=n_features,
     update_method="VI",
     update_kwargs={
+        "num_steps": 1,
         "optimizer_type": optimizer_type,
         "optimizer_kwargs": optimizer_kwargs,
         "lr_scheduler_type": lr_scheduler_type,
         "lr_scheduler_kwargs": lr_scheduler_kwargs,
     },
 )
 assert bnn._obj_optimizer is not None
+context = np.random.rand(4, n_features).astype(np.float32)
+rewards = _make_random_rewards(4)
+bnn.update(context=context, rewards=rewards)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_model.py` around lines 782 - 800, The test test_lr_scheduler_valid only verifies BayesianNeuralNetwork.cold_start assembles an optimizer/scheduler but doesn't exercise the optimizer through optax_to_numpyro and SVI.update, so failures in the optimizer integration (e.g., the new "adan" path) are missed; update the test to run a minimal SVI training step: build a tiny model/data from BayesianNeuralNetwork.cold_start, convert the optax optimizer via optax_to_numpyro (or use the existing _obj_optimizer), create an SVI instance, call SVI.update (or run one step of svi.run) with a small dummy batch, and assert the update returns a scalar loss and that bnn._obj_optimizer/state changes—this will surface integration errors during update rather than only construction.
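For context, the smoke step this prompt asks for looks roughly like the following in plain NumPyro; the model and guide here are stand-ins, not PyBandits code:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import optax
from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import optax_to_numpyro

def model(x, y=None):
    # Stand-in Bayesian logistic model, not the PyBandits BNN.
    w = numpyro.sample("w", dist.Normal(jnp.zeros(x.shape[1]), 1.0).to_event(1))
    numpyro.sample("obs", dist.Bernoulli(logits=x @ w), obs=y)

guide = AutoNormal(model)
optim = optax_to_numpyro(optax.chain(optax.clip_by_global_norm(1.0), optax.adam(0.01)))
svi = SVI(model, guide, optim, loss=Trace_ELBO(num_particles=2))

x = random.normal(random.PRNGKey(0), (8, 3))
y = jnp.array([0, 1, 0, 1, 1, 0, 1, 0])
state = svi.init(random.PRNGKey(1), x, y)
state, loss = svi.update(state, x, y)  # one optimizer step through the full stack
assert jnp.isfinite(loss)
```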
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pybandits/model.py`:
- Around line 1256-1266: The current exception handling in the lr scheduler and
optimizer construction swallows the original traceback by re-raising a new
exception class without chaining; update the two except blocks around
self._supported_lr_schedulers[lr_scheduler_type](...) and
self._supported_optimizers[optimizer_type](...) to re-raise the new exception
with the original exception as the cause (use "raise ... from e") so the
original Optax/underlying error context for invalid lr_scheduler_kwargs or
optimizer_kwargs is preserved; keep the same message text but append "from e"
when raising.
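As a standalone illustration of the chaining this asks for (the helper and mapping are hypothetical; the `raise ... from e` pattern is the point):

```python
import optax

_SUPPORTED_LR_SCHEDULERS = {"exponential_decay": optax.exponential_decay}  # illustrative

def build_lr_schedule(lr_scheduler_type, learning_rate, **lr_scheduler_kwargs):
    try:
        return _SUPPORTED_LR_SCHEDULERS[lr_scheduler_type](
            init_value=learning_rate, **lr_scheduler_kwargs
        )
    except TypeError as e:
        # "from e" chains the original Optax error, so the traceback still
        # shows which kwarg was invalid or missing.
        raise ValueError(
            f"Invalid lr_scheduler_kwargs for '{lr_scheduler_type}': {e}"
        ) from e
```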
In `@pyproject.toml`:
- Line 3: The release change is breaking: you removed accepted
optimizer/scheduler values and changed the public signature of
create_update_model(), so either restore backward compatibility for one release
(re-add removed optimizer/scheduler aliases and keep the old
create_update_model(...) signature forwarding to the new implementation while
emitting a deprecation warning) or bump the package version in pyproject.toml to
a non-patch semver (e.g., 7.0.0) and update the changelog; locate the changes
around create_update_model and the optimizer/scheduler validation code to
implement the alias/deprecation forwarding, or update the version string in
pyproject.toml and note the breaking change in the release notes.
- Line 44: The pyproject.toml dependency for Optax is too low and allows 0.1.x
where optax.adan doesn't exist; update the optax requirement from "optax =
\"^0.1\"" to at least "optax = \"^0.2.6\"" (or a newer 0.2+ range) so imports of
optax.adan used in pybandits/model.py (see _supported_optimizers referencing
optax.adan) will succeed.
In `@tests/test_model.py`:
- Around line 841-880: Extend test_vi_training_options to assert the
sample_proba contract: after calling result = bnn.sample_proba(context=context),
handle both return shapes—if result is a tuple unpack it as probs, logits =
result else set probs = result and logits = None—and then assert probs.shape[0]
== n_samples, all(probs >= 0) and all(probs <= 1) (use numpy comparisons), and
if logits is not None assert np.isfinite(logits).all() so VI options cannot
produce out-of-range probabilities or non-finite logits; reference
test_vi_training_options, BayesianNeuralNetwork.sample_proba, and the result
variable when adding these assertions.
---
Nitpick comments:
In `@tests/test_model.py`:
- Around line 782-800: The test test_lr_scheduler_valid only verifies
BayesianNeuralNetwork.cold_start assembles an optimizer/scheduler but doesn't
exercise the optimizer through optax_to_numpyro and SVI.update, so failures in
the optimizer integration (e.g., the new "adan" path) are missed; update the
test to run a minimal SVI training step: build a tiny model/data from
BayesianNeuralNetwork.cold_start, convert the optax optimizer via
optax_to_numpyro (or use the existing _obj_optimizer), create an SVI instance,
call SVI.update (or run one step of svi.run) with a small dummy batch, and
assert the update returns a scalar loss and that bnn._obj_optimizer/state
changes—this will surface integration errors during update rather than only
construction.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 415dd10b-87a8-44a1-9cc5-899b6bac5d19
📒 Files selected for processing (3)
- pybandits/model.py
- pyproject.toml
- tests/test_model.py
Force-pushed from 21460be to 20379d7 (Compare)
Actionable comments posted: 2
🧹 Nitpick comments (1)
pybandits/model.py (1)
1529-1532: Edge case: `n_neurons` calculation for single-layer networks. For single-layer networks (e.g., `BayesianLogisticRegression`), `bnn_layer_params[:-1]` is empty, resulting in `n_neurons = 1` (from the `max(..., 1)` guard). This means the KL scale factor becomes `kl_tau * n_samples / 1 = kl_tau * n_samples`, which may not be the intended behavior for logistic regression models. Consider documenting this behavior or adding a check to warn/skip KL scaling for single-layer models where the concept of "hidden neurons" doesn't apply.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pybandits/model.py` around lines 1529 - 1532, The current KL scaling uses n_neurons computed from self.model_params.bnn_layer_params[:-1], which is empty for single-layer models (e.g., BayesianLogisticRegression), causing unintended large scaling; add a guard that checks the number of layers (e.g., if len(self.model_params.bnn_layer_params) <= 1) and in that case skip KL scaling (return None / do not return a numpyro.handlers.scale) and emit a warning (use warnings.warn or self.logger.warning) explaining that KL scaling is skipped for single-layer models; keep the existing behavior for multi-layer networks (retain kl_tau, n_samples calculation and return numpyro.handlers.scale) and reference kl_tau, n_neurons, and model_params.bnn_layer_params in the change.
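A sketch of the guarded scaling context this prompt outlines; `bnn_layer_params` and the scale formula follow the review text, while the helper itself is illustrative:

```python
import warnings
from contextlib import nullcontext

import numpyro

def kl_scale_ctx(bnn_layer_params, kl_tau, n_samples):
    # No temperature requested: leave the prior log-prob untouched.
    if kl_tau is None:
        return nullcontext()
    if len(bnn_layer_params) <= 1:
        # Single-layer model (e.g. logistic regression): "hidden neurons"
        # is not meaningful, so skip KL scaling as the review suggests.
        warnings.warn("KL scaling skipped for single-layer models.")
        return nullcontext()
    n_neurons = max(sum(bnn_layer_params[:-1]), 1)
    # Scale the prior log-prob by tau * N_data / N_neurons (Huix et al. 2022).
    return numpyro.handlers.scale(scale=kl_tau * n_samples / n_neurons)
```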
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pybandits/model.py`:
- Around line 1097-1103: The parameter "restore_best_weights" is declared in
_vi_update_params but lacks a default/implementation, which can cause KeyError
when accessed (e.g., via self._update_kwargs.get("restore_best_weights"));
either add a sensible default entry for "restore_best_weights" in
_default_vi_kwargs (e.g., False or None) and document its behavior, or remove
"restore_best_weights" from the _vi_update_params list if the feature is not
implemented—update any code that reads self._update_kwargs to reflect the chosen
approach.
- Around line 1111-1116: The LR scheduler mapping (_supported_lr_schedulers)
currently only supplies init_value=learning_rate when the scheduler is
instantiated, but each scheduler (optax.exponential_decay,
optax.cosine_decay_schedule, optax.linear_schedule,
optax.warmup_cosine_decay_schedule) requires additional mandatory kwargs; add
explicit validation where the scheduler is constructed (the code that passes
init_value=learning_rate) to check for the required keys per scheduler type
(exponential_decay: transition_steps, decay_rate; cosine_decay_schedule:
decay_steps; linear_schedule: end_value, transition_steps;
warmup_cosine_decay_schedule: peak_value, warmup_steps, decay_steps) and raise a
clear TypeError/ValueError if any are missing (or provide sensible defaults), so
missing lr_scheduler_kwargs fail early with a helpful message.
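One way to express that validation, with the required-key mapping taken from the prompt (the helper name is illustrative; the key names match Optax's schedule signatures):

```python
REQUIRED_SCHEDULER_KWARGS = {
    "exponential_decay": {"transition_steps", "decay_rate"},
    "cosine_decay_schedule": {"decay_steps"},
    "linear_schedule": {"end_value", "transition_steps"},
    "warmup_cosine_decay_schedule": {"peak_value", "warmup_steps", "decay_steps"},
}

def validate_scheduler_kwargs(scheduler_type: str, kwargs: dict) -> None:
    # Fail early with a helpful message instead of a cryptic TypeError
    # raised from inside Optax when a mandatory kwarg is missing.
    missing = REQUIRED_SCHEDULER_KWARGS.get(scheduler_type, set()) - set(kwargs)
    if missing:
        raise ValueError(
            f"lr_scheduler_kwargs for '{scheduler_type}' is missing: {sorted(missing)}"
        )
```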
---
Nitpick comments:
In `@pybandits/model.py`:
- Around line 1529-1532: The current KL scaling uses n_neurons computed from
self.model_params.bnn_layer_params[:-1], which is empty for single-layer models
(e.g., BayesianLogisticRegression), causing unintended large scaling; add a
guard that checks the number of layers (e.g., if
len(self.model_params.bnn_layer_params) <= 1) and in that case skip KL scaling
(return None / do not return a numpyro.handlers.scale) and emit a warning (use
warnings.warn or self.logger.warning) explaining that KL scaling is skipped for
single-layer models; keep the existing behavior for multi-layer networks (retain
kl_tau, n_samples calculation and return numpyro.handlers.scale) and reference
kl_tau, n_neurons, and model_params.bnn_layer_params in the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b1bf4df9-9e47-4111-a481-3d8134d3350d
📒 Files selected for processing (3)
- pybandits/model.py
- pyproject.toml
- tests/test_model.py
✅ Files skipped from review due to trivial changes (2)
- pyproject.toml
- tests/test_model.py
| "lr_scheduler_type", | ||
| "lr_scheduler_kwargs", | ||
| "restore_best_weights", | ||
| "num_particles", | ||
| "gradient_clip_norm", | ||
| "kl_tau", | ||
| ] |
restore_best_weights is declared in _vi_update_params but has no default value or implementation.
The parameter restore_best_weights is listed in _vi_update_params (line 1099) but is not present in _default_vi_kwargs (lines 1139-1151). This could cause KeyError if code later attempts to access it via self._update_kwargs.get("restore_best_weights") without a default, or it may indicate an incomplete feature.
Either add a default value in _default_vi_kwargs, or remove it from _vi_update_params if it's not yet implemented.
Suggested fix
```diff
 _default_vi_kwargs: ClassVar[dict] = dict(
     num_steps=1000,
     method="advi",
     optimizer_type="sgd",
     optimizer_kwargs={"step_size": 0.01},
     batch_size=None,
     early_stopping_kwargs=None,
     lr_scheduler_type=None,
     lr_scheduler_kwargs=None,
     num_particles=1,
     gradient_clip_norm=None,
     kl_tau=None,
+    restore_best_weights=False,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pybandits/model.py` around lines 1097 - 1103, The parameter
"restore_best_weights" is declared in _vi_update_params but lacks a
default/implementation, which can cause KeyError when accessed (e.g., via
self._update_kwargs.get("restore_best_weights")); either add a sensible default
entry for "restore_best_weights" in _default_vi_kwargs (e.g., False or None) and
document its behavior, or remove "restore_best_weights" from the
_vi_update_params list if the feature is not implemented—update any code that
reads self._update_kwargs to reflect the chosen approach.
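If the default-value route is taken, a defaults merge keeps every declared VI parameter populated. A tiny sketch with an illustrative subset of the defaults:

```python
_DEFAULT_VI_KWARGS = {"num_steps": 1000, "restore_best_weights": False}  # subset

def resolve_vi_kwargs(user_kwargs: dict) -> dict:
    # User-supplied kwargs override defaults; every declared key,
    # including restore_best_weights, is guaranteed to exist afterwards.
    return {**_DEFAULT_VI_KWARGS, **(user_kwargs or {})}
```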
Force-pushed from 20379d7 to 55b230c (Compare)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/test_model.py`:
- Around line 879-882: The assertions iterate over result from
bnn.sample_proba(context) but sample_proba returns List[Tuple[float, float]]
(probability, weighted_sum), so the tests compare tuples to scalars; fix by
unpacking each tuple (e.g., for p, w in result) and assert len(result) ==
n_samples, assert 0 <= p <= 1 for each p, and assert np.isfinite(p) and
np.isfinite(w) for both elements (or at least ensure p is finite), using the
existing symbols result, n_samples and sample_proba to locate and update the
test.
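Spelled out as a self-contained check (assuming the `List[Tuple[probability, weighted_sum]]` return type the prompt states):

```python
import numpy as np

def assert_sample_proba_contract(result, n_samples):
    assert len(result) == n_samples
    for p, w in result:  # each entry is (probability, weighted_sum)
        assert 0.0 <= p <= 1.0, f"Probability out of [0, 1]: {p}"
        assert np.isfinite(p) and np.isfinite(w)
```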
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 021df0ee-364e-4092-b09a-a065022e8c8d
📒 Files selected for processing (3)
- pybandits/model.py
- pyproject.toml
- tests/test_model.py
🚧 Files skipped from review as they are similar to previous changes (1)
- pyproject.toml
Force-pushed from 55b230c to 9274343 (Compare)
♻️ Duplicate comments (2)
pybandits/model.py (1)
1099-1105: ⚠️ Potential issue | 🟡 Minor

`restore_best_weights` is declared in `_vi_update_params` but has no default value in `_default_vi_kwargs`. The parameter `restore_best_weights` is listed in `_vi_update_params` (line 1101) but is missing from `_default_vi_kwargs` (lines 1184-1196). This inconsistency could cause issues if code attempts to access this key without a default, or indicates an incomplete feature. Either add a default value in `_default_vi_kwargs` or remove it from `_vi_update_params` if not yet implemented.

Suggested fix

```diff
 _default_vi_kwargs: ClassVar[dict] = dict(
     num_steps=1000,
     method="advi",
     optimizer_type="sgd",
     optimizer_kwargs={"step_size": 0.01},
     batch_size=None,
     early_stopping_kwargs=None,
     lr_scheduler_type=None,
     lr_scheduler_kwargs=None,
     num_particles=1,
     gradient_clip_norm=None,
     kl_tau=None,
+    restore_best_weights=False,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pybandits/model.py` around lines 1099 - 1105, _vi_update_params currently includes "restore_best_weights" but _default_vi_kwargs lacks a default for that key; either remove "restore_best_weights" from _vi_update_params if the feature isn't implemented, or add a sensible default entry (e.g., False or None) to _default_vi_kwargs so code that reads that key won't KeyError—update the dictionary named _default_vi_kwargs to include "restore_best_weights": <default_value> (or delete the symbol from the list in _vi_update_params) and ensure the chosen default aligns with how restore_best_weights is handled elsewhere in the class.

tests/test_model.py (1)
891-894: ⚠️ Potential issue | 🔴 Critical

Fix incorrect tuple unpacking in assertions — `sample_proba` returns `List[Tuple[float, float]]`. The assertions at lines 893-894 iterate over `result` directly, but `sample_proba` returns `List[Tuple[probability, weighted_sum]]`. The current code compares tuples to scalars, which will raise `TypeError: '<=' not supported between instances of 'int' and 'tuple'`.

Proposed fix

```diff
 result = bnn.sample_proba(context=context)
 assert len(result) == n_samples
-assert all(0 <= p <= 1 for p in result), f"Probabilities out of [0,1]: {result}"
-assert all(np.isfinite(p) for p in result), f"Non-finite probabilities: {result}"
+probs, logits = zip(*result)
+assert all(0.0 <= p <= 1.0 for p in probs), f"Probabilities out of [0,1]: {probs}"
+assert all(np.isfinite(logit) for logit in logits), f"Non-finite logits: {logits}"
```

As per coding guidelines, `**/*test*.py`: Always apply callable cost and probability logic testing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_model.py` around lines 891 - 894, The test is iterating over result tuples returned by sample_proba (List[Tuple[probability, weighted_sum]) causing type errors; update the assertions to extract the probability values first (e.g., probs = [p for p, _ in result] or using tuple-unpacking) then assert len(result) == n_samples and run the scalar checks on probs: all(0 <= p <= 1 for p in probs) and all(np.isfinite(p) for p in probs). Ensure you reference the result from sample_proba and use the probs list in the two assertions instead of iterating over result tuples.
🧹 Nitpick comments (1)
pybandits/model.py (1)
1108-1112: `_supported_optimizers` and `_supported_lr_schedulers` class variables are dead code and should be removed. Both `_supported_optimizers` (line 1108) and `_supported_lr_schedulers` (line 1113) are defined but never referenced in the codebase. The `_resolve_optax_fn` method uses `getattr(optax, name, None)` for dynamic validation instead of consulting these dictionaries. Remove these unused class variables to avoid misleading future maintainers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pybandits/model.py` around lines 1108 - 1112, Remove the unused class variables _supported_optimizers and _supported_lr_schedulers from the class in pybandits/model.py: locate the dictionary definitions named _supported_optimizers and _supported_lr_schedulers and delete them, leaving the dynamic resolution logic in _resolve_optax_fn intact (it uses getattr(optax, name, None)); also search for any remaining references to these variables and remove or update them if found, then run tests/linters to ensure no stray references remain.
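For context, the dynamic resolution the comment refers to can be sketched as follows; the helper name comes from the review, and its exact body is an assumption:

```python
import optax

def resolve_optax_fn(name: str):
    # Look the name up on the optax module instead of a hand-kept dict,
    # which is why the static _supported_* mappings went unused.
    fn = getattr(optax, name, None)
    if fn is None:
        raise ValueError(f"Unknown optax optimizer/scheduler: '{name}'")
    return fn
```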
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@pybandits/model.py`:
- Around line 1099-1105: _vi_update_params currently includes
"restore_best_weights" but _default_vi_kwargs lacks a default for that key;
either remove "restore_best_weights" from _vi_update_params if the feature isn't
implemented, or add a sensible default entry (e.g., False or None) to
_default_vi_kwargs so code that reads that key won't KeyError—update the
dictionary named _default_vi_kwargs to include "restore_best_weights":
<default_value> (or delete the symbol from the list in _vi_update_params) and
ensure the chosen default aligns with how restore_best_weights is handled
elsewhere in the class.
In `@tests/test_model.py`:
- Around line 891-894: The test is iterating over result tuples returned by
sample_proba (List[Tuple[probability, weighted_sum]) causing type errors; update
the assertions to extract the probability values first (e.g., probs = [p for p,
_ in result] or using tuple-unpacking) then assert len(result) == n_samples and
run the scalar checks on probs: all(0 <= p <= 1 for p in probs) and
all(np.isfinite(p) for p in probs). Ensure you reference the result from
sample_proba and use the probs list in the two assertions instead of iterating
over result tuples.
---
Nitpick comments:
In `@pybandits/model.py`:
- Around line 1108-1112: Remove the unused class variables _supported_optimizers
and _supported_lr_schedulers from the class in pybandits/model.py: locate the
dictionary definitions named _supported_optimizers and _supported_lr_schedulers
and delete them, leaving the dynamic resolution logic in _resolve_optax_fn
intact (it uses getattr(optax, name, None)); also search for any remaining
references to these variables and remove or update them if found, then run
tests/linters to ensure no stray references remain.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 480dd6ed-8e0b-40a2-bbad-49b09817b4ee
📒 Files selected for processing (3)
- pybandits/model.py
- pyproject.toml
- tests/test_model.py
✅ Files skipped from review due to trivial changes (1)
- pyproject.toml
Force-pushed from 79095e4 to 8e8d065 (Compare)
Add VI training options: num_particles, gradient clipping, and KL temperature scaling

### Changes:
* Add `num_particles`, `gradient_clip_norm`, and `kl_tau` to VI `update_kwargs`
  - `num_particles` is passed to `TraceMeanField_ELBO`/`Trace_ELBO` for multi-particle gradient estimates
  - `gradient_clip_norm` chains `optax.clip_by_global_norm` before the base optimizer
  - `kl_tau` scales prior log-prob by `tau * N_data / N_neurons` via `numpyro.handlers.scale` (Huix et al. 2022)
* Switch all optimizers from numpyro-native to optax via `optax_to_numpyro`, removing the dual-path logic
* Remove `clipped_adam`, `momentum`, and `reduce_on_plateau` scheduler
* Remove `batch_size` parameter from `create_update_model`; read from `_update_kwargs` instead
* Add `optax` to `pyproject.toml` dependencies
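Taken together, the new knobs would be passed roughly like this; a hedged sketch based on the test shown earlier, with an assumed import path and illustrative values:

```python
from pybandits.model import BayesianNeuralNetwork

bnn = BayesianNeuralNetwork.cold_start(
    n_features=3,
    update_method="VI",
    update_kwargs={
        "num_steps": 500,
        "optimizer_type": "adam",
        "optimizer_kwargs": {"step_size": 0.01},
        "num_particles": 4,         # forwarded to Trace_ELBO(num_particles=...)
        "gradient_clip_norm": 1.0,  # optax.clip_by_global_norm before the optimizer
        "kl_tau": 0.1,              # prior log-prob scaled by tau * N_data / N_neurons
    },
)
```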
Force-pushed from 8e8d065 to 4cf1589 (Compare)
Changes:
- Add `num_particles`, `gradient_clip_norm`, and `kl_tau` to VI `update_kwargs`:
  - `num_particles` is passed to `TraceMeanField_ELBO`/`Trace_ELBO` for multi-particle gradient estimates
  - `gradient_clip_norm` chains `optax.clip_by_global_norm` before the base optimizer
  - `kl_tau` scales prior log-prob by `tau * N_data / N_neurons` via `numpyro.handlers.scale` (Huix et al. 2022)
- Switch all optimizers from numpyro-native to optax via `optax_to_numpyro`, removing the dual-path logic
- Remove `clipped_adam`, `momentum`, and `reduce_on_plateau` scheduler
- Remove `batch_size` parameter from `create_update_model`; read from `_update_kwargs` instead
- Add `optax` to `pyproject.toml` dependencies

Summary by CodeRabbit
New Features
Breaking Changes
Tests
Chores