fix(pt): pairtab #5119
base: devel
Conversation
Pull request overview
This PR fixes a potential NaN gradient issue in the PyTorch implementation of the pairtab atomic model by replacing the standard torch.linalg.norm computation with a safe norm that uses epsilon clamping.
Key Changes:
- Modified the `_get_pairwise_dist` method to use `torch.sqrt(torch.sum(diff * diff, dim=-1, keepdim=True).clamp(min=1e-14))` instead of `torch.linalg.norm`
- Added comprehensive documentation in the Notes section explaining when and why this safe norm is needed
- Added inline comments explaining the epsilon value choice and its purpose
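The root cause of the NaN is that the derivative of `sqrt` is unbounded as its argument approaches zero. A minimal torch-free sketch of the analytic gradients of the two variants (helper names are illustrative, not from the PR):

```python
import math

EPS = 1e-14  # same epsilon as the PR; applied to the squared sum, not the norm

def safe_norm_grad(s):
    """d/ds sqrt(max(s, EPS)) for the squared sum s.

    Clamping before the sqrt caps the derivative at 1/(2*sqrt(EPS)),
    so a zero difference vector no longer produces inf/NaN.
    """
    return 1.0 / (2.0 * math.sqrt(max(s, EPS)))

def plain_norm_grad(s):
    """d/ds sqrt(s) -- unbounded as s -> 0, the source of the NaN."""
    return 1.0 / (2.0 * math.sqrt(s)) if s > 0 else float("inf")

# A coincident atom pair gives a squared sum of exactly 0:
print(plain_norm_grad(0.0))  # inf -- poisons the backward pass
print(safe_norm_grad(0.0))   # finite (~5e6 = 1/(2*sqrt(1e-14)))
```

Because atomic distances are well above `sqrt(1e-14) = 1e-7` Å, the clamp is inactive for all physical pairs and only engages on degenerate (e.g. padded) entries.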
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1)
409-416: LGTM! Safe norm implementation correctly prevents NaN gradients. The implementation correctly computes the Euclidean norm with numerical stability:
- The epsilon value (1e-14) is well-chosen: small enough to not affect physical distances (atomic distances typically > 0.1 Å, squared > 0.01) yet large enough to prevent gradient issues
- The clamp on the squared sum (before sqrt) is the right approach to prevent unbounded gradients
- Inline comments clearly explain the rationale
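The point that the clamp must sit on the squared sum, before the `sqrt`, can be seen from a simplified model of reverse-mode autodiff: even if a later operation masks the output, the chain rule still multiplies by the local derivative of `sqrt` at zero, and `0 * inf` is NaN. A torch-free sketch (function names are illustrative):

```python
import math

EPS = 1e-14

def grad_clamp_then_sqrt(s):
    # d/ds sqrt(max(s, EPS)): exactly 0 in the clamped region, else finite
    return 0.0 if s < EPS else 1.0 / (2.0 * math.sqrt(s))

def grad_sqrt_then_clamp(s):
    # Simplified backward for max(sqrt(s), sqrt(EPS)): the chain rule
    # still evaluates d(sqrt)/ds at s == 0, which is inf.
    d_sqrt = float("inf") if s == 0.0 else 1.0 / (2.0 * math.sqrt(s))
    clamp_mask = 0.0 if math.sqrt(s) < math.sqrt(EPS) else 1.0
    return clamp_mask * d_sqrt  # 0.0 * inf == nan at s == 0

print(grad_clamp_then_sqrt(0.0))  # 0.0 -- well-defined gradient
print(grad_sqrt_then_clamp(0.0))  # nan -- clamping after sqrt does not help
```

This mirrors the familiar pattern where masking a NaN-producing branch after the fact still contaminates gradients, which is why the PR clamps the squared sum itself.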
Optional: Consider defining epsilon as a named constant
For improved maintainability, you could define the epsilon as a class-level constant:
```python
class PairTabAtomicModel(BaseAtomicModel):
    # Epsilon for safe norm computation to prevent NaN gradients
    _SAFE_NORM_EPSILON = 1e-14
    ...
```

Then use it in the computation:

```diff
 pairwise_rr = torch.sqrt(
-    torch.sum(diff * diff, dim=-1, keepdim=True).clamp(min=1e-14)
+    torch.sum(diff * diff, dim=-1, keepdim=True).clamp(min=self._SAFE_NORM_EPSILON)
 ).squeeze(-1)
```

This makes it easier to adjust the epsilon value consistently if needed in the future. However, this is a minor improvement and can be deferred.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pt/model/atomic_model/pairtab_atomic_model.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (30)
- GitHub Check: Agent
- GitHub Check: CodeQL analysis (python)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
🔇 Additional comments (1)
deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1)
395-402: LGTM! Clear documentation of the numerical stability fix. The Notes section clearly explains the rationale for the safe norm computation and when zero difference vectors can occur. This will help future maintainers understand why the epsilon clamp is necessary.
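To make "zero difference vectors" concrete: one way they can arise (an illustrative assumption, not necessarily the exact deepmd-kit padding scheme) is a padded neighbor-list slot that reuses the central atom's own coordinate, giving a difference of exactly zero:

```python
import math

EPS = 1e-14

coord_i = (1.0, 2.0, 3.0)
coord_j = (1.0, 2.0, 3.0)  # padded slot pointing back at atom i (assumed scheme)

diff = tuple(a - b for a, b in zip(coord_i, coord_j))
sq_sum = sum(d * d for d in diff)      # 0.0 for the padded pair
safe_rr = math.sqrt(max(sq_sum, EPS))  # finite (1e-7) instead of a hard 0
print(sq_sum, safe_rr)
```

Such padded pairs contribute nothing physically, but without the clamp their zero distance flows through `sqrt` and corrupts the backward pass.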
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##            devel    #5119      +/-   ##
==========================================
- Coverage   82.15%   82.15%   -0.01%
==========================================
  Files         709      709
  Lines       72468    72470       +2
  Branches     3616     3616
==========================================
+ Hits        59535    59536       +1
- Misses      11769    11771       +2
+ Partials     1164     1163       -1
```

☔ View full report in Codecov by Sentry.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: OutisLi <137472077+OutisLi@users.noreply.github.com>