Weights in alpha for FocalLoss #8665
base: dev
Conversation
Signed-off-by: ytl0623 <david89062388@gmail.com>
📝 Walkthrough

The PR extends FocalLoss to accept per-class alpha provided as a Sequence[float] or torch.Tensor, in addition to a scalar or None. The constructor and helper functions (softmax_focal_loss, sigmoid_focal_loss) are updated to accept and validate per-class alpha, broadcast it to match prediction shapes, and raise on length mismatches. Forward now derives an intermediate alpha argument before dispatching to the helpers.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
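For orientation, here is a minimal usage sketch of the extended API described in the walkthrough; the shapes and per-class values are illustrative, not taken from the PR:

```python
import torch
from monai.losses import FocalLoss

# Per-class alpha as a sequence: one weight per class channel (illustrative values).
loss_fn = FocalLoss(to_onehot_y=True, use_softmax=True, gamma=2.0, alpha=[0.25, 0.35, 0.4])

logits = torch.randn(2, 3, 16, 16)            # (batch, classes, spatial...)
target = torch.randint(0, 3, (2, 1, 16, 16))  # class-index labels, one channel
loss = loss_fn(logits, target)
print(loss)  # scalar tensor under the default "mean" reduction
```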
Signed-off-by: ytl0623 <david89062388@gmail.com>
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/focal_loss.py (1)
68-120: Add tests for sequence alpha feature.

The new `alpha` parameter now accepts sequences but has no test coverage. Add tests for:

- Sequence alpha with correct length (both softmax and sigmoid modes)
- Sequence alpha with incorrect length (should raise ValueError)
- Sequence alpha with `include_background=False` (should work)
- Broadcasting behavior across spatial dimensions

A sketch of such a test follows this list.
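A hedged sketch of what such a test might look like, assuming the validation behavior described in this review (a wrong-length sequence raising ValueError at call time):

```python
import unittest

import torch
from monai.losses import FocalLoss


class TestSequenceAlpha(unittest.TestCase):
    def test_sequence_alpha(self):
        logits = torch.randn(2, 3, 8, 8)
        target = torch.randint(0, 3, (2, 1, 8, 8))
        # Correct length: one alpha per class, both softmax and sigmoid modes.
        for use_softmax in (True, False):
            loss_fn = FocalLoss(to_onehot_y=True, alpha=[0.2, 0.3, 0.5], use_softmax=use_softmax)
            self.assertEqual(loss_fn(logits, target).ndim, 0)
        # Incorrect length: should raise when the loss is computed.
        bad_fn = FocalLoss(to_onehot_y=True, alpha=[0.2, 0.3], use_softmax=True)
        with self.assertRaises(ValueError):
            bad_fn(logits, target)
```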
🧹 Nitpick comments (5)
monai/losses/focal_loss.py (5)
81-87: Clarify docstring with example.

The interaction between `include_background`, `use_softmax`, and the alpha type is complex. Consider adding a brief example showing sequence alpha usage, e.g., `alpha=[0.25, 0.35, 0.4]` for a 3-class case.
167-167: Add stacklevel to warning.

Per static analysis and best practice, specify `stacklevel=2` so the warning points to the user's code, not this internal method.

🔎 Proposed fix

```diff
- warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+ warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2)
```
222-237: Sequence length validation deferred to runtime.

The check that the alpha sequence length matches the class count (lines 229-232) occurs inside the loss function, not at initialization or at the start of `forward`. This means the error surfaces during training rather than at model construction. Consider validating the alpha length earlier if the class count can be inferred; one possible shape is sketched below.
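One shape such an earlier check might take, sketched as a hypothetical helper called at the top of `FocalLoss.forward` (the method name and the exact attributes are assumptions, not the file's actual code):

```python
import torch


def _check_alpha_length(self, input: torch.Tensor) -> None:
    """Hypothetical early validation of a per-class alpha tensor."""
    if isinstance(self.alpha, torch.Tensor):
        # After background removal, one alpha value per remaining channel is expected.
        n_expected = input.shape[1] - (0 if self.include_background else 1)
        if self.alpha.numel() != n_expected:
            raise ValueError(f"`alpha` has {self.alpha.numel()} elements, expected {n_expected}.")
```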
230-232: Simplify exception message. Per static analysis (TRY003), extract long messages into a constant or use shorter inline text.

272-274: Simplify exception message. Per static analysis (TRY003), extract long messages into a constant or use shorter inline text.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/focal_loss.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and describe each variable, return value, and raised exception in the appropriate section of the Google-style docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/focal_loss.py
🧬 Code graph analysis (1)
monai/losses/focal_loss.py (1)
monai/utils/enums.py (1)
LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/focal_loss.py
166-166: Local variable alpha_arg is assigned to but never used
Remove assignment to unused variable alpha_arg
(F841)
167-167: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
230-232: Avoid specifying long messages outside the exception class
(TRY003)
272-274: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: packaging
🔇 Additional comments (2)
monai/losses/focal_loss.py (2)
73-73: Type hint correctly extended. The signature now accepts a scalar or a sequence for per-class alpha weighting.

265-281: Sequence alpha logic correct; validation deferred. The per-class alpha handling properly validates length (lines 271-274) and broadcasts (lines 276-277). However, like `softmax_focal_loss`, validation occurs at runtime rather than earlier. The broadcasting and alpha_factor computation are correct. Similar to `softmax_focal_loss`, consider validating the alpha sequence length earlier in the lifecycle.
Actionable comments posted: 0
🧹 Nitpick comments (4)
monai/losses/focal_loss.py (4)
161-170: Critical bug from previous review is fixed. The code now correctly passes the local variable `alpha_arg` instead of the undefined `self.alpha_arg`. The logic properly handles `include_background=False` by nullifying scalar alpha (which assumes background weighting) while preserving sequence alpha.

Add `stacklevel=2` to the warning. Line 167 should include `stacklevel=2` for proper warning attribution.

🔎 Fix for warning stacklevel

```diff
- warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+ warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2)
```
210-239: Softmax focal loss correctly handles scalar and sequence alpha. The scalar path implements standard focal loss weighting (`1-alpha` for background, `alpha` for foreground). Sequence validation ensures length matches the number of classes, and broadcasting is correct; a small illustration follows.

Minor: Long error message inside ValueError. Lines 230-232 define a long message inline. Consider extracting to a constant if this pattern recurs, but acceptable as-is.
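To make the scalar softmax path concrete, a small standalone sketch of the weighting it describes (variable names assumed, not the file's actual code):

```python
import torch

n_classes = 4
alpha = 0.25
# Background (channel 0) is weighted by 1 - alpha, every foreground class by alpha.
alpha_fac = torch.tensor([1 - alpha] + [alpha] * (n_classes - 1))
# Reshape to (1, C, 1, 1) so it broadcasts over batch and spatial dimensions.
alpha_fac = alpha_fac.view(1, -1, 1, 1)
print(alpha_fac.flatten())  # tensor([0.7500, 0.2500, 0.2500, 0.2500])
```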
242-283: Sigmoid focal loss correctly handles scalar and sequence alpha. The scalar path applies `alpha` when `target=1` and `1-alpha` when `target=0`. Sequence validation and per-channel broadcasting are implemented correctly.

Minor: Long error message inside ValueError. Lines 272-274 define a long message inline. Consider extracting to a constant if this pattern recurs, but acceptable as-is.
73-73: Verify test coverage for new sequence alpha functionality. The PR adds per-class alpha support, but test checklist items are unchecked. Ensure tests cover:

- Sequence alpha with `include_background=True` and `False`
- Length validation errors
- Both softmax and sigmoid variants
- Edge cases (single class, empty sequence, etc.)

Do you want me to generate a shell script to check for existing test files that cover `FocalLoss`, or open an issue to track test coverage?

Also applies to: 211-211, 243-243
🔇 Additional comments (1)
monai/losses/focal_loss.py (1)
73-73: Type signature and documentation updated correctly. The alpha parameter now accepts sequences, and the docstring clearly specifies the length requirement relative to `include_background`.

Also applies to: 81-81, 85-87
Signed-off-by: ytl0623 <david89062388@gmail.com>
Actionable comments posted: 1
🧹 Nitpick comments (3)
monai/losses/focal_loss.py (3)
81-87: Clarify softmax scalar alpha convention in docstring. The docstring describes sequence alpha but doesn't explain that scalar alpha in softmax mode applies `1-alpha` to the background class (index 0) and `alpha` to all other classes. This differs from sigmoid mode, where alpha is applied symmetrically based on the target value. Document this convention for user clarity.
223-238: Softmax alpha implementation correct. The implementation properly handles both scalar (asymmetric background weighting) and sequence (explicit per-class) alpha. Length validation and broadcasting are correct.

Optional: The docstring states alpha "should be in [0, 1]" but there's no runtime validation. Consider adding a check if strict enforcement is desired, though this may be left as a user responsibility.
267-281: Sigmoid alpha implementation correct. The implementation properly handles scalar and sequence alpha with the standard focal loss formulation. Length validation and per-class broadcasting are correct; the weighting is sketched below.

Optional: as with softmax, consider adding [0, 1] range validation if strict enforcement is desired.
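For reference, the symmetric sigmoid weighting mentioned above follows the standard focal loss formulation; a compact sketch with assumed names:

```python
import torch

alpha = 0.25
target = torch.tensor([[1.0, 0.0, 1.0]])  # one-hot-style targets per channel
# alpha where target == 1, and 1 - alpha where target == 0.
alpha_factor = target * alpha + (1 - target) * (1 - alpha)
print(alpha_factor)  # tensor([[0.2500, 0.7500, 0.2500]])
```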
🔇 Additional comments (2)
monai/losses/focal_loss.py (2)
73-73: Type hint correctly updated. The signature now accepts scalar, sequence, or None for per-class alpha weighting.

162-170: Forward logic correct. The `alpha_arg` local variable correctly handles the special case where scalar alpha is nulled for softmax with `include_background=False`, while preserving sequence alpha for explicit per-class control.
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/losses/focal_loss.py (1)
167-167: Add `stacklevel=2` to warning. Set explicit `stacklevel=2` so the warning points to the caller.

🔎 Proposed fix

```diff
- warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+ warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2)
```
🧹 Nitpick comments (2)
monai/losses/focal_loss.py (2)
223-238: LGTM. Per-class alpha logic is correct: scalar creates standard background/foreground weighting; sequence validates length and broadcasts properly.

Optional: Consider validating that alpha values are in [0, 1] as documented in the docstring.

267-281: LGTM. Per-class alpha logic is correct: scalar preserves original behavior; sequence validates length and applies per-class weighting with proper broadcasting.

Optional: Consider validating that alpha values are in [0, 1] as documented in the docstring.
🔇 Additional comments (4)
monai/losses/focal_loss.py (4)
16-16: LGTM. Import correctly added for sequence type hints.

73-73: LGTM. Signature correctly extends alpha to accept per-class sequences while preserving backward compatibility.

81-87: LGTM. Docstring accurately describes the new per-class alpha behavior and its interaction with `include_background`.

162-170: LGTM. Logic correctly handles scalar alpha with `include_background=False` by nullifying it and warning. The `alpha_arg` intermediate variable prevents mutation of `self.alpha`.
Actionable comments posted: 1
♻️ Duplicate comments (1)
monai/losses/focal_loss.py (1)
167-167: Addstacklevel=2to warning.Still missing from previous review.
🧹 Nitpick comments (1)
monai/losses/focal_loss.py (1)
210-218: Add docstrings foralphaparameter.Both
softmax_focal_lossandsigmoid_focal_losslack documentation for thealphaparameter, including its new sequence support.Also applies to: 243-250
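A possible docstring fragment for the helpers, with the wording assumed rather than taken from the PR:

```python
def softmax_focal_loss(input, target, gamma=2.0, alpha=None):
    """
    Args:
        alpha: optional weighting factor. A scalar applies ``1 - alpha`` to the
            background channel and ``alpha`` to all other channels; a sequence
            or tensor provides one weight per class channel and must match the
            number of channels in ``target``.
    """
    ...
```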
🪛 Ruff (0.14.8)
monai/losses/focal_loss.py
73-73: Undefined name Sequence
(F821)
74-74: Undefined name Sequence
(F821)
167-167: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
231-233: Avoid specifying long messages outside the exception class
(TRY003)
273-275: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (4)
monai/losses/focal_loss.py (4)
84-87: LGTM! Docstring clearly describes the new per-class alpha behavior and validation requirements.

162-170: LGTM! The `alpha_arg` logic correctly preserves the original alpha while handling the scalar edge case for `include_background=False`.

222-238: LGTM! Scalar vs sequence handling is clean. Validation and broadcasting are correct.

266-282: LGTM! Per-channel alpha handling correctly extends the scalar semantics to sequence form.
Force-pushed from 4211ca5 to 50cc7e9 (compare)
Signed-off-by: ytl0623 <david89062388@gmail.com>
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/losses/focal_loss.py (1)
175-175: Addstacklevel=2to warning.The warning should specify
stacklevel=2so it points to the caller rather than this internal line.🔎 Proposed fix
- warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.") + warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2)
🧹 Nitpick comments (5)
monai/losses/focal_loss.py (5)
81-87: Docstring is accurate but validation is missing. Line 85 states alpha values should be in [0, 1], but the code doesn't enforce this. Consider adding validation in `__init__`. Optionally, adding a brief example of per-class alpha usage (e.g., `alpha=[0.1, 0.3, 0.6]` for 3 classes) would improve clarity.

🔎 Suggested validation

```diff
 if isinstance(alpha, (list, tuple)):
     self.alpha = torch.tensor(alpha)
+    if (self.alpha < 0).any() or (self.alpha > 1).any():
+        raise ValueError("All alpha values must be in the range [0, 1].")
 else:
     self.alpha = alpha
+    if isinstance(alpha, (float, int)) and not (0 <= alpha <= 1):
+        raise ValueError("Alpha must be in the range [0, 1].")
```
167-170: Alpha device handling is correct. Properly transfers tensor alpha to the input device. Minor optimization: the device transfer could be skipped for scalar alpha, but the current implementation is safe and correct.
230-246: Alpha handling logic is correct. Properly distinguishes scalar (background/foreground weighting) from sequence (per-class weighting). Validation on line 238 ensures the sequence length matches the number of classes.

Minor: Static analysis suggests shorter exception messages (TRY003), but this is stylistic and the descriptive message is helpful.
274-289: Per-channel alpha implementation is correct. Scalar alpha applies standard focal loss weighting, while sequence alpha provides per-channel control. Broadcasting on lines 284-286 properly handles multi-dimensional targets.

Minor: Consider a shorter exception message per TRY003, though the current message is clear.
73-73: Verify test coverage for per-class alpha. Ensure tests cover:

- Sequence alpha for both softmax and sigmoid modes
- Validation error when the sequence length mismatches the number of classes
- Behavior with `include_background=False` plus sequence alpha
- Edge cases: empty sequence, values outside [0, 1] if validation is added

Do you want me to generate test cases for these scenarios?
🔇 Additional comments (4)
monai/losses/focal_loss.py (4)
14-16: Imports are correct. Sequence from collections.abc is properly imported and used consistently throughout the file.

73-74: Signature extension looks good. The alpha parameter now correctly accepts per-class weights via Sequence[float].

120-125: Conversion logic is correct. Properly converts sequence types to torch.Tensor while preserving scalar floats.

176-178: Loss function calls are correct. Both softmax and sigmoid paths properly forward the alpha_arg parameter.
Signed-off-by: ytl0623 <david89062388@gmail.com>
Signed-off-by: ytl0623 <david89062388@gmail.com>
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/losses/focal_loss.py (1)
176-176: Add `stacklevel=2` to warning. The warning should include `stacklevel=2` to point to the caller.

🔎 Proposed fix

```diff
- warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+ warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2)
```
🧹 Nitpick comments (3)
monai/losses/focal_loss.py (3)
85-87: Clarify that the range applies to each alpha element. The docstring states "The value should be in [0, 1]" but doesn't explicitly say this applies to each element when alpha is a sequence. Consider: "Each value should be in [0, 1]."

237-246: Document scalar alpha behavior for softmax. When alpha is scalar, lines 239-240 construct `[1-alpha, alpha, alpha, ...]`, treating alpha as the weight for all foreground classes vs. `1-alpha` for background. This design choice isn't documented in the function docstring or the main class docstring. Consider adding a note explaining this convention for multi-class softmax with scalar alpha.
73-73: Verify test coverage for sequence alpha. The PR adds significant new functionality (per-class alpha via sequences). Ensure test coverage includes:

- Valid sequence alpha for both softmax and sigmoid paths
- `include_background=False` with sequence alpha
- Length mismatch errors (lines 242-245, 288-291)
- Device and dtype handling for tensor alpha

Do you want me to help generate test cases for these scenarios?

Also applies to: 220-220, 256-256
🔇 Additional comments (1)
monai/losses/focal_loss.py (1)
279-296: Sigmoid sequence alpha implementation looks correct. The per-class alpha handling properly extends the binary focal loss formula to multi-class, multi-label scenarios. Device/dtype handling and broadcasting are correct.
Hi @ytl0623 thanks for this, we should have the CICD issues sorted now. It looks good to me, though I think we need to add tests for the class with a sequence.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/losses/test_focal_loss.py (1)
377-416: Test adequately covers per-class alpha broadcasting. The test validates both softmax and sigmoid paths with alpha sequences, checks scalar output, and verifies error handling for length mismatches. Consider adding a test case with `include_background=False` to ensure per-class alpha sequences work correctly after background removal, as this is a key use case mentioned in the PR objectives.
Add a test case to verify per-class alpha with background exclusion:
# Case 4: include_background=False with per-class alpha if device == devices[0]: loss_func_no_bg = FocalLoss( to_onehot_y=True, gamma=2.0, alpha=alpha_seq[1:], # Exclude background alpha use_softmax=True, include_background=False, reduction="mean" ) loss_no_bg = loss_func_no_bg(logits, target) self.assertTrue(torch.is_tensor(loss_no_bg)) self.assertEqual(loss_no_bg.ndim, 0)
🔇 Additional comments (4)
monai/losses/focal_loss.py (4)
73-73: Signature and docstring correctly document per-class alpha. Type annotation and documentation clearly explain that alpha accepts sequences and specify length requirements based on `include_background`.

Also applies to: 81-87

117-123: Alpha conversion logic is correct. Properly handles scalar, sequence, and None cases. Converting sequences to tensor enables device transfer in the helper functions.

229-250: Alpha handling correctly implements per-class weighting for softmax. Logic properly distinguishes scalar (asymmetric background weighting) from sequence (explicit per-class weights). Length validation ensures alpha matches classes after background exclusion. Broadcasting is correct.

276-296: Alpha handling correctly implements per-class weighting for sigmoid. Scalar and sequence cases properly apply class-specific alpha weights. Length validation and broadcasting logic are correct.
Signed-off-by: ytl0623 <david89062388@gmail.com>
Force-pushed from 015a894 to 1f37d0d (compare)
Actionable comments posted: 2
🤖 Fix all issues with AI Agents
In @monai/losses/focal_loss.py:
- Line 171: Modify the warnings.warn call that emits
"`include_background=False`, scalar `alpha` ignored when using softmax." to pass
stacklevel=2 (e.g., warnings.warn("...message...", stacklevel=2)) so the warning
points to the caller rather than this internal line; update the specific
warnings.warn invocation in focal_loss.py accordingly.
- Line 73: Update the test suite by adding a case in
test_alpha_sequence_broadcasting that exercises FocalLoss (or the test helper
using focal_loss) with include_background=False and a sequence alpha whose
length equals num_classes-1; construct logits and target that include the
background class, instantiate FocalLoss with include_background=False and alpha
as a list/tuple, compute the loss, and assert it matches the expected value
computed by applying the alpha sequence to the non-background channels (i.e.,
remove channel 0 before broadcasting). Use the existing test's pattern for
softmax/sigmoid branching and length-mismatch checks to validate correct
broadcasting when background is excluded.
🔇 Additional comments (5)
monai/losses/focal_loss.py (5)
73-73: Type hint correctly updated for per-class alpha. The signature now accepts scalar or sequence alpha as intended.

81-81: Docstring accurately reflects new sequence alpha support. The clarifications about `include_background=False` and sequence requirements are clear. Also applies to: 85-87.

117-123: Alpha normalization is clean and correct. Storing alpha as None, float, or tensor simplifies downstream handling.

227-245: Sequence alpha logic is correct; scalar alpha depends on the forward() fix. The sequence path properly validates alpha length against `target.shape[1]` and broadcasts correctly. The scalar path creates `[1-alpha, alpha, alpha, ...]`, which is correct when class 0 is background, but forward() should prevent scalar alpha when `include_background=False` (see earlier comment).

273-293: Sigmoid focal loss alpha handling is correct for both scalar and sequence. The scalar path applies per-pixel weighting (alpha for positive, 1-alpha for negative), and the sequence path applies per-channel weighting with proper length validation and broadcasting.
```diff
 to_onehot_y: bool = False,
 gamma: float = 2.0,
-alpha: float | None = None,
+alpha: float | Sequence[float] | None = None,
```
🧩 Analysis chain

🏁 Scripts executed (in the Project-MONAI/MONAI repository):

```sh
# First, locate the files
fd -n "focal_loss.py" --type f

# Check the FocalLoss implementation
head -100 monai/losses/focal_loss.py | cat -n

# Look at the test file
fd -n "test_focal_loss.py" --type f

# Show test methods related to alpha
rg -n "def test.*alpha" tests/losses/test_focal_loss.py

# Check the alpha parameter definition and usage in focal_loss.py
rg -n -B 2 -A 10 "alpha.*Sequence\[float\]" monai/losses/focal_loss.py

# Check for test assertions on sequence alpha
rg -n -B 3 -A 15 "alpha.*sequence\|test_alpha_sequence" tests/losses/test_focal_loss.py

# View the test_alpha_sequence_broadcasting test
sed -n '377,430p' tests/losses/test_focal_loss.py

# View lines 227-245 in focal_loss.py (error handling section)
sed -n '220,250p' monai/losses/focal_loss.py

# View lines 273-293 in focal_loss.py (more error handling)
sed -n '270,300p' monai/losses/focal_loss.py

# Check the full test file to see all test methods
grep -n "def test_" tests/losses/test_focal_loss.py | head -20

# Check if there are tests for include_background=False with alpha
grep -B5 -A10 "include_background.*False" tests/losses/test_focal_loss.py | head -40

# Check how alpha is handled with include_background in implementation
sed -n '100,180p' monai/losses/focal_loss.py
```
Add a test case for include_background=False with sequence alpha.

The test `test_alpha_sequence_broadcasting` covers softmax and sigmoid with sequence alpha, plus length-mismatch errors. However, it lacks explicit coverage of `include_background=False` with sequence alpha, which per the docstring (line 81) is a valid configuration. Add a test case verifying correct behavior when `include_background=False` removes the background class and sequence alpha is applied to the remaining classes.
🤖 Prompt for AI Agents
In @monai/losses/focal_loss.py around line 73, Update the test suite by adding a
case in test_alpha_sequence_broadcasting that exercises FocalLoss (or the test
helper using focal_loss) with include_background=False and a sequence alpha
whose length equals num_classes-1; construct logits and target that include the
background class, instantiate FocalLoss with include_background=False and alpha
as a list/tuple, compute the loss, and assert it matches the expected value
computed by applying the alpha sequence to the non-background channels (i.e.,
remove channel 0 before broadcasting). Use the existing test's pattern for
softmax/sigmoid branching and length-mismatch checks to validate correct
broadcasting when background is excluded.
```diff
-self.alpha = None
-warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
+if isinstance(self.alpha, (float, int)):
+    warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
```
Add stacklevel=2 to warning.
Without stacklevel=2, the warning points to this internal line rather than the caller's code.
```diff
- warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+ warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2)
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2)
```
🤖 Prompt for AI Agents
In @monai/losses/focal_loss.py around line 171, Modify the warnings.warn call
that emits "`include_background=False`, scalar `alpha` ignored when using
softmax." to pass stacklevel=2 (e.g., warnings.warn("...message...",
stacklevel=2)) so the warning points to the caller rather than this internal
line; update the specific warnings.warn invocation in focal_loss.py accordingly.
Hi @ericspod, thanks for the feedback. I have removed the redundant conversion logic and added a new test case. It should be ready for another look. Thanks in advance!
Fixes #8601
Description
Support alpha as a list, tuple, or tensor of floats, in addition to the existing scalar support.
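For illustration, the three accepted forms might be used as follows (class counts and values are arbitrary):

```python
import torch
from monai.losses import FocalLoss

scalar_loss = FocalLoss(alpha=0.25)                                # existing scalar behavior
list_loss = FocalLoss(alpha=[0.25, 0.5, 0.25], use_softmax=True)   # new: per-class sequence
tensor_loss = FocalLoss(alpha=torch.tensor([0.25, 0.5, 0.25]), use_softmax=True)  # new: tensor
```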
Types of changes
- Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation updated, tested `make html` command in the `docs/` folder.