Conversation
There was a problem hiding this comment.
Pull request overview
This pull request attempts to improve numerical stability in ray-tracing calculations by sanitizing NaN values in deflection arrays. The change is motivated by compatibility with JAX-based computations.
Changes:
- Added NaN sanitization logic to replace non-finite values (NaN and infinity) with zeros in the deflection arrays during multi-plane ray-tracing
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| deflections_yx_2d = xp.where(xp.isfinite(deflections_yx_2d.array), deflections_yx_2d.array, 0.0) | ||
|
|
There was a problem hiding this comment.
The sanitized deflections_yx_2d is being converted to a raw array, but it needs to remain a grid-like object with the same type as the original. On line 159, traced_deflection_list[previous_plane_index] is used in arithmetic operations (scaling_factor * traced_deflection_list[previous_plane_index]), which expects a grid object, not a raw array.
The fix should wrap the sanitized array back into the same type as the original object. For example, following the pattern in autolens/point/solver/point_solver.py:75, you could reconstruct the grid object after sanitization. Without this, multi-plane ray tracing will fail when accessing deflections from previous planes.
| deflections_yx_2d = xp.where(xp.isfinite(deflections_yx_2d.array), deflections_yx_2d.array, 0.0) | |
| deflections_yx_2d_array = xp.where( | |
| xp.isfinite(deflections_yx_2d.array), deflections_yx_2d.array, 0.0 | |
| ) | |
| # Reconstruct a grid-like deflection object of the same type as the original. | |
| deflections_yx_2d = deflections_yx_2d.__class__( | |
| values=deflections_yx_2d_array, | |
| grid=scaled_grid, | |
| ) |
| map(lambda g: g.deflections_yx_2d_from(grid=scaled_grid, xp=xp), galaxies) | ||
| ) | ||
|
|
||
| # Remove NaN deflection values to sanitize the ray-tracing calculation for JAX. |
There was a problem hiding this comment.
The comment states "Remove NaN deflection values" but the code uses xp.isfinite which replaces both NaN and infinity values with 0.0. If the intent is only to handle NaN values, use xp.isnan instead. If both NaN and infinity should be sanitized, update the comment to accurately reflect this behavior (e.g., "Remove NaN and infinite deflection values").
| # Remove NaN deflection values to sanitize the ray-tracing calculation for JAX. | |
| # Remove NaN and infinite deflection values to sanitize the ray-tracing calculation for JAX. |
This pull request introduces a data sanitization step to the ray-tracing calculation in the
traced_grid_2d_list_fromfunction. The main change ensures that any NaN (Not a Number) values in the deflection arrays are replaced with zeros, which helps maintain numerical stability, especially when using JAX for computation.Numerical stability improvement:
autolens/lens/tracer_util.py, within thetraced_grid_2d_list_fromfunction, added logic to replace any NaN values in thedeflections_yx_2darray with zeros usingxp.where. This sanitizes the ray-tracing calculation for compatibility with JAX and prevents propagation of invalid values.