Conversation
There was a problem hiding this comment.
Pull request overview
This pull request refactors the interferometry linear algebra "w_tilde" formalism to "sparse operator" terminology and extends it to imaging datasets, enabling efficient GPU-based lens modeling via JAX. The changes rename classes, methods, and internal logic throughout the codebase while adding new FFT-based sparse operator support for imaging data.
Changes:
- Renamed all w_tilde references to sparse_operator across classes, methods, and parameters (e.g.,
WTildeImaging→ImagingSparseOperator,apply_w_tilde()→apply_sparse_operator()) - Added JAX-based FFT sparse operator implementation for imaging datasets via new
ImagingSparseOperatorclass - Removed legacy w_tilde modules and deprecated settings (
use_w_tilde_numpy,use_source_loop)
Reviewed changes
Copilot reviewed 35 out of 36 changed files in this pull request and generated 15 comments.
Show a summary per file
| File | Description |
|---|---|
| test_autoarray/inversion/inversion/test_settings_dict.py | Removed deprecated settings from test expectations |
| test_autoarray/inversion/inversion/test_inversion_util.py | Renamed test and updated function calls to use new sparse operator API |
| test_autoarray/inversion/inversion/test_factory.py | Updated test names and assertions from w_tilde to sparse_operator |
| test_autoarray/inversion/inversion/test_abstract.py | Updated test names and method calls to use sparse operator terminology |
| test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py | Refactored interferometer tests to use new sparse operator classes and removed deprecated test |
| test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py | Updated imaging tests with new API but contains critical bugs and debug code |
| test_autoarray/inversion/inversion/imaging/test_imaging.py | Updated imports to use new sparse operator classes |
| test_autoarray/dataset/interferometer/test_dataset.py | Removed test for deprecated curvature preload metadata |
| test_autoarray/dataset/imaging/test_dataset.py | Removed assertions checking w_tilde attributes |
| autoarray/util/init.py | Fixed import path for inversion_imaging_numba module |
| autoarray/mock.py | Removed deprecated MockInversionImagingWTilde import |
| autoarray/mask/mask_2d.py | Added fft_index_for_masked_pixel property for sparse operator support |
| autoarray/inversion/pixelization/mappers/mapper_util.py | Added sparse_triplets_from function for COO triplet construction |
| autoarray/inversion/pixelization/mappers/abstract.py | Added sparse_triplets_data and sparse_triplets_curvature properties |
| autoarray/inversion/pixelization/border_relocator.py | Renamed use_w_tilde to use_sparse_operator parameter |
| autoarray/inversion/mock/mock_inversion_imaging.py | Removed MockInversionImagingWTilde and related mock classes |
| autoarray/inversion/inversion/settings.py | Removed deprecated use_w_tilde_numpy and use_source_loop settings |
| autoarray/inversion/inversion/inversion_util.py | Renamed functions from w_tilde to psf_weighted_noise terminology |
| autoarray/inversion/inversion/interferometer/sparse.py | Renamed class from InversionInterferometerWTilde to InversionInterferometerSparse |
| autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py | Renamed utility functions and dataclass to sparse operator terminology |
| autoarray/inversion/inversion/imaging_numba/sparse.py | Renamed class to InversionImagingSparseNumba with critical bugs in xp checks |
| autoarray/inversion/inversion/imaging_numba/inversion_imaging_numba_util.py | Renamed functions and added SparseLinAlgImagingNumba class |
| autoarray/inversion/inversion/imaging/sparse.py | New JAX-based InversionImagingSparse implementation |
| autoarray/inversion/inversion/imaging/inversion_imaging_util.py | New ImagingSparseOperator dataclass with JAX FFT operations |
| autoarray/inversion/inversion/imaging/abstract.py | Updated docstrings to reference sparse operator formalism |
| autoarray/inversion/inversion/factory.py | Updated factory to create appropriate sparse operator inversion types |
| autoarray/inversion/inversion/dataset_interface.py | Renamed w_tilde parameter to sparse_operator |
| autoarray/inversion/inversion/abstract.py | Updated docstring terminology |
| autoarray/dataset/interferometer/w_tilde.py | Deleted legacy WTildeInterferometer implementation |
| autoarray/dataset/interferometer/dataset.py | Renamed apply_w_tilde to apply_sparse_operator method |
| autoarray/dataset/imaging/w_tilde.py | Deleted legacy WTildeImaging implementation |
| autoarray/dataset/imaging/dataset.py | Added apply_sparse_operator methods for both JAX and CPU implementations |
| autoarray/dataset/grids.py | Renamed use_w_tilde to use_sparse_operator parameter |
| autoarray/dataset/abstract/w_tilde.py | Deleted legacy AbstractWTilde base class |
| autoarray/init.py | Updated exports to use new sparse operator classes |
Comments suppressed due to low confidence (11)
autoarray/inversion/inversion/imaging_numba/sparse.py:218
- The condition
if np is npwill always evaluate to True, making the else branch unreachable. This should likely be checking whetherself._xp is npto determine if NumPy or JAX is being used, consistent with similar checks elsewhere in the codebase.
autoarray/inversion/inversion/imaging_numba/inversion_imaging_numba_util.py:874 - Typo in docstring: "efficienctly" should be "efficiently".
autoarray/inversion/inversion/imaging_numba/sparse.py:496 - The condition
if np is npwill always evaluate to True, making the else branch unreachable. This should likely be checking whetherself._xp is npto determine if NumPy or JAX is being used.
autoarray/inversion/inversion/imaging_numba/inversion_imaging_numba_util.py:920 - The property decorator is defined but the function returns a call to
psf_precision_operator_fromwhich suggests this should be a regular method, not a property. Additionally, the function attempts to accessself.psf_precision_operatorwhich would cause infinite recursion if this were actually called as a property. This appears to be incorrect - either the decorator should be removed or the implementation needs to be fixed.
autoarray/inversion/inversion/imaging_numba/sparse.py:307 - The condition
if np is npwill always evaluate to True, making the else branch unreachable. This should likely be checking whetherself._xp is npto determine if NumPy or JAX is being used.
autoarray/inversion/inversion/imaging_numba/sparse.py:460 - The condition
if np is npwill always evaluate to True, making the else branch unreachable. This should likely be checking whetherself._xp is npto determine if NumPy or JAX is being used.
autoarray/inversion/inversion/imaging_numba/sparse.py:215 - Comparison of identical values; use cmath.isnan() if testing for not-a-number.
autoarray/inversion/inversion/imaging_numba/sparse.py:304 - Comparison of identical values; use cmath.isnan() if testing for not-a-number.
autoarray/inversion/inversion/imaging_numba/sparse.py:449 - Comparison of identical values; use cmath.isnan() if testing for not-a-number.
autoarray/inversion/inversion/imaging_numba/sparse.py:484 - Comparison of identical values; use cmath.isnan() if testing for not-a-number.
autoarray/dataset/imaging/dataset.py:589 - Keyword argument 'curvature_preload' is not a supported parameter name of SparseLinAlgImagingNumba.init.
sparse_operator = inversion_imaging_numba_util.SparseLinAlgImagingNumba(
curvature_preload=curvature_preload,
indexes=indexes.astype("int"),
lengths=lengths.astype("int"),
noise_map=self.noise_map,
psf=self.psf,
mask=self.mask,
)
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| from jax.ops import segment_sum | ||
|
|
||
| w = psf_weighted_data[rows] # (nnz,) | ||
| contrib = vals * w # (nnz,) |
There was a problem hiding this comment.
The function data_vector_via_psf_weighted_data_from imports segment_sum from jax.ops unconditionally, but this function might be called with NumPy arrays based on the xp parameter pattern used elsewhere. This will cause an import error if JAX is not installed. The import should be conditional or the function should be clearly documented as JAX-only.
| from jax.ops import segment_sum | |
| w = psf_weighted_data[rows] # (nnz,) | |
| contrib = vals * w # (nnz,) | |
| w = psf_weighted_data[rows] # (nnz,) | |
| contrib = vals * w # (nnz,) | |
| try: | |
| from jax.ops import segment_sum | |
| except ImportError: | |
| # NumPy fallback implementation of segment_sum(contrib, cols, num_segments=S) | |
| data_vector = np.zeros(S, dtype=contrib.dtype) | |
| np.add.at(data_vector, cols, contrib) | |
| return data_vector |
| The w-tilde formalism bypasses the calculation of the `mapping_matrix` and it therefore cannot be used to map | ||
| the reconstruction's values to the image-plane. Instead, the unique data-to-pixelization mappings are used, |
There was a problem hiding this comment.
Inconsistency in docstring: The docstring still refers to "w-tilde formalism" but the code has been refactored to use "sparse operator" terminology. This should be updated for consistency.
| The w-tilde formalism bypasses the calculation of the `mapping_matrix` and it therefore cannot be used to map | |
| the reconstruction's values to the image-plane. Instead, the unique data-to-pixelization mappings are used, | |
| The sparse-operator formalism bypasses the explicit construction of the `mapping_matrix`. Instead, it uses | |
| the precomputed sparse data-to-pixelization mappings to map the reconstruction's values to the image-plane, |
| The sparse_operator matrix used by the w-tilde formalism to construct the data vector and | ||
| curvature matrix during an inversion efficiently.. |
There was a problem hiding this comment.
Inconsistency in docstring: The docstring still refers to "w-tilde formalism" in the parameter description, but the code has been refactored to use "sparse operator" terminology. This should be updated for consistency.
| The sparse_operator matrix used by the w-tilde formalism to construct the data vector and | |
| curvature matrix during an inversion efficiently.. | |
| A sparse operator used to efficiently construct the data vector and curvature matrix during an | |
| inversion. |
|
|
||
| for sub_size in range(1, 3): | ||
|
|
||
| print(sub_size) |
There was a problem hiding this comment.
Debug print statement left in test code. This should be removed before merging to production.
| from typing import Optional | ||
|
|
||
| from autoconf.fitsable import ndarray_via_fits_from, output_to_fits | ||
| from autoconf import cached_property |
There was a problem hiding this comment.
Import of 'cached_property' is not used.
| from autoconf import cached_property |
| @@ -1,91 +1,37 @@ | |||
| from dataclasses import dataclass | |||
| from functools import partial | |||
There was a problem hiding this comment.
Import of 'partial' is not used.
| from functools import partial |
| from tqdm import tqdm | ||
| import os | ||
| import time | ||
| from pathlib import Path |
There was a problem hiding this comment.
Import of 'Path' is not used.
| from pathlib import Path |
autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py
Outdated
Show resolved
Hide resolved
| from autoarray.inversion.inversion.imaging.sparse import ( | ||
| InversionImagingSparse, | ||
| ) |
There was a problem hiding this comment.
Import of 'InversionImagingSparse' is not used.
| from autoarray.inversion.inversion.imaging.sparse import ( | |
| InversionImagingSparse, | |
| ) |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 35 out of 36 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (9)
autoarray/inversion/inversion/imaging_numba/sparse.py:218
- The variable name
npis being checked against itself in the conditionif np is np:. This condition will always be True and should instead check whether the array module being used is NumPy (as opposed to JAX). The correct check should beif self._xp is np:to properly distinguish between NumPy and JAX backends.
autoarray/inversion/inversion/imaging_numba/sparse.py:307 - The condition
if np is np:will always evaluate to True. This should beif self._xp is np:to correctly check whether the NumPy backend is being used (as opposed to JAX). This same bug appears in multiple places in this file (lines 215-218, 304-307, 449-460, 484-496) where the backend check is intended but incorrectly implemented.
autoarray/inversion/inversion/imaging_numba/sparse.py:460 - The condition
if np is np:will always evaluate to True. This should beif self._xp is np:to correctly check whether the NumPy backend is being used (as opposed to JAX).
autoarray/inversion/inversion/imaging_numba/sparse.py:496 - The condition
if np is np:will always evaluate to True. This should beif self._xp is np:to correctly check whether the NumPy backend is being used (as opposed to JAX).
autoarray/inversion/inversion/imaging_numba/sparse.py:447 - The function call passes extra parameters
curvature_weights,mask, andpsf_kernelthat are not in the function signature. The function expectsdata_linear_func_matrixbut receivescurvature_weightsinstead. The caller should computedata_linear_func_matrixusinginversion_imaging_util.data_linear_func_matrix_from(curvature_weights, self.psf.native, self.mask)before calling this function, similar to how it's done in the abstract imaging class'sdata_linear_func_matrix_dictproperty.
autoarray/inversion/inversion/imaging_numba/sparse.py:215 - Comparison of identical values; use cmath.isnan() if testing for not-a-number.
autoarray/inversion/inversion/imaging_numba/sparse.py:304 - Comparison of identical values; use cmath.isnan() if testing for not-a-number.
autoarray/inversion/inversion/imaging_numba/sparse.py:449 - Comparison of identical values; use cmath.isnan() if testing for not-a-number.
autoarray/inversion/inversion/imaging_numba/sparse.py:484 - Comparison of identical values; use cmath.isnan() if testing for not-a-number.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def from_nufft_precision_operator( | ||
| self, | ||
| nufft_precision_operator: np.ndarray, | ||
| dirty_image: np.ndarray, | ||
| *, | ||
| batch_size: int = 128, | ||
| ): |
There was a problem hiding this comment.
Class methods or methods of a type deriving from type should have 'cls', rather than 'self', as their first parameter.
…rometer_util.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This pull request extends the the "w_tilde" formalism of interferometry linear algebra to imaging datasets, enabling efficiently modeling using sparse operators on GPU via JAX. Initial test runs indicate this enables lens modeling of Hubble Space Telescope data using pixelized sources in under 10 minutes, with the formalism enabling small amounts of VRAM use.
It also changes the codebase API to replace the previous use of the term "w_tilde" with a new "sparse operator" formalism for both imaging and interferometer datasets. The changes update class names, method signatures, and internal logic to use the new approach, resulting in improved clarity and maintainability. The most important changes are grouped below:
Refactoring from w_tilde to sparse operator formalism:
w_tilde(includingWTildeImagingandWTildeInterferometer) withsparse_operator(such asImagingSparseOperatorandInterferometerSparseLinAlg) throughout the codebase, including constructor arguments, attributes, and method names in both imaging and interferometer dataset classes. [1] [2] [3] [4] [5] [6] [7] [8] [9] [10] [11] [12] [13]w_tildemodules and classes, includingautoarray/dataset/imaging/w_tilde.pyandautoarray/dataset/abstract/w_tilde.py. [1] [2]API and interface updates:
GridsDatasetand related classes to use theuse_sparse_operatorflag instead ofuse_w_tilde, ensuring consistent configuration for the new formalism. [1] [2] [3]These changes modernize the linear algebra approach in the codebase, improve naming consistency, and remove outdated code related to the previous w_tilde formalism.