Skip to content

Feature/fft jax imaging#204

Merged
Jammy2211 merged 41 commits intomainfrom
feature/fft_jax_imaging
Feb 3, 2026
Merged

Feature/fft jax imaging#204
Jammy2211 merged 41 commits intomainfrom
feature/fft_jax_imaging

Conversation

@Jammy2211
Copy link
Owner

This pull request extends the the "w_tilde" formalism of interferometry linear algebra to imaging datasets, enabling efficiently modeling using sparse operators on GPU via JAX. Initial test runs indicate this enables lens modeling of Hubble Space Telescope data using pixelized sources in under 10 minutes, with the formalism enabling small amounts of VRAM use.

It also changes the codebase API to replace the previous use of the term "w_tilde" with a new "sparse operator" formalism for both imaging and interferometer datasets. The changes update class names, method signatures, and internal logic to use the new approach, resulting in improved clarity and maintainability. The most important changes are grouped below:

Refactoring from w_tilde to sparse operator formalism:

  • Replaced all references to w_tilde (including WTildeImaging and WTildeInterferometer) with sparse_operator (such as ImagingSparseOperator and InterferometerSparseLinAlg) throughout the codebase, including constructor arguments, attributes, and method names in both imaging and interferometer dataset classes. [1] [2] [3] [4] [5] [6] [7] [8] [9] [10] [11] [12] [13]
  • Removed the legacy w_tilde modules and classes, including autoarray/dataset/imaging/w_tilde.py and autoarray/dataset/abstract/w_tilde.py. [1] [2]

API and interface updates:

  • Updated the GridsDataset and related classes to use the use_sparse_operator flag instead of use_w_tilde, ensuring consistent configuration for the new formalism. [1] [2] [3]
  • Updated import statements and module references throughout the codebase to point to the new sparse operator modules and classes. [1] [2]

These changes modernize the linear algebra approach in the codebase, improve naming consistency, and remove outdated code related to the previous w_tilde formalism.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request refactors the interferometry linear algebra "w_tilde" formalism to "sparse operator" terminology and extends it to imaging datasets, enabling efficient GPU-based lens modeling via JAX. The changes rename classes, methods, and internal logic throughout the codebase while adding new FFT-based sparse operator support for imaging data.

Changes:

  • Renamed all w_tilde references to sparse_operator across classes, methods, and parameters (e.g., WTildeImagingImagingSparseOperator, apply_w_tilde()apply_sparse_operator())
  • Added JAX-based FFT sparse operator implementation for imaging datasets via new ImagingSparseOperator class
  • Removed legacy w_tilde modules and deprecated settings (use_w_tilde_numpy, use_source_loop)

Reviewed changes

Copilot reviewed 35 out of 36 changed files in this pull request and generated 15 comments.

Show a summary per file
File Description
test_autoarray/inversion/inversion/test_settings_dict.py Removed deprecated settings from test expectations
test_autoarray/inversion/inversion/test_inversion_util.py Renamed test and updated function calls to use new sparse operator API
test_autoarray/inversion/inversion/test_factory.py Updated test names and assertions from w_tilde to sparse_operator
test_autoarray/inversion/inversion/test_abstract.py Updated test names and method calls to use sparse operator terminology
test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py Refactored interferometer tests to use new sparse operator classes and removed deprecated test
test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py Updated imaging tests with new API but contains critical bugs and debug code
test_autoarray/inversion/inversion/imaging/test_imaging.py Updated imports to use new sparse operator classes
test_autoarray/dataset/interferometer/test_dataset.py Removed test for deprecated curvature preload metadata
test_autoarray/dataset/imaging/test_dataset.py Removed assertions checking w_tilde attributes
autoarray/util/init.py Fixed import path for inversion_imaging_numba module
autoarray/mock.py Removed deprecated MockInversionImagingWTilde import
autoarray/mask/mask_2d.py Added fft_index_for_masked_pixel property for sparse operator support
autoarray/inversion/pixelization/mappers/mapper_util.py Added sparse_triplets_from function for COO triplet construction
autoarray/inversion/pixelization/mappers/abstract.py Added sparse_triplets_data and sparse_triplets_curvature properties
autoarray/inversion/pixelization/border_relocator.py Renamed use_w_tilde to use_sparse_operator parameter
autoarray/inversion/mock/mock_inversion_imaging.py Removed MockInversionImagingWTilde and related mock classes
autoarray/inversion/inversion/settings.py Removed deprecated use_w_tilde_numpy and use_source_loop settings
autoarray/inversion/inversion/inversion_util.py Renamed functions from w_tilde to psf_weighted_noise terminology
autoarray/inversion/inversion/interferometer/sparse.py Renamed class from InversionInterferometerWTilde to InversionInterferometerSparse
autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py Renamed utility functions and dataclass to sparse operator terminology
autoarray/inversion/inversion/imaging_numba/sparse.py Renamed class to InversionImagingSparseNumba with critical bugs in xp checks
autoarray/inversion/inversion/imaging_numba/inversion_imaging_numba_util.py Renamed functions and added SparseLinAlgImagingNumba class
autoarray/inversion/inversion/imaging/sparse.py New JAX-based InversionImagingSparse implementation
autoarray/inversion/inversion/imaging/inversion_imaging_util.py New ImagingSparseOperator dataclass with JAX FFT operations
autoarray/inversion/inversion/imaging/abstract.py Updated docstrings to reference sparse operator formalism
autoarray/inversion/inversion/factory.py Updated factory to create appropriate sparse operator inversion types
autoarray/inversion/inversion/dataset_interface.py Renamed w_tilde parameter to sparse_operator
autoarray/inversion/inversion/abstract.py Updated docstring terminology
autoarray/dataset/interferometer/w_tilde.py Deleted legacy WTildeInterferometer implementation
autoarray/dataset/interferometer/dataset.py Renamed apply_w_tilde to apply_sparse_operator method
autoarray/dataset/imaging/w_tilde.py Deleted legacy WTildeImaging implementation
autoarray/dataset/imaging/dataset.py Added apply_sparse_operator methods for both JAX and CPU implementations
autoarray/dataset/grids.py Renamed use_w_tilde to use_sparse_operator parameter
autoarray/dataset/abstract/w_tilde.py Deleted legacy AbstractWTilde base class
autoarray/init.py Updated exports to use new sparse operator classes
Comments suppressed due to low confidence (11)

autoarray/inversion/inversion/imaging_numba/sparse.py:218

  • The condition if np is np will always evaluate to True, making the else branch unreachable. This should likely be checking whether self._xp is np to determine if NumPy or JAX is being used, consistent with similar checks elsewhere in the codebase.
    autoarray/inversion/inversion/imaging_numba/inversion_imaging_numba_util.py:874
  • Typo in docstring: "efficienctly" should be "efficiently".
    autoarray/inversion/inversion/imaging_numba/sparse.py:496
  • The condition if np is np will always evaluate to True, making the else branch unreachable. This should likely be checking whether self._xp is np to determine if NumPy or JAX is being used.
    autoarray/inversion/inversion/imaging_numba/inversion_imaging_numba_util.py:920
  • The property decorator is defined but the function returns a call to psf_precision_operator_from which suggests this should be a regular method, not a property. Additionally, the function attempts to access self.psf_precision_operator which would cause infinite recursion if this were actually called as a property. This appears to be incorrect - either the decorator should be removed or the implementation needs to be fixed.
    autoarray/inversion/inversion/imaging_numba/sparse.py:307
  • The condition if np is np will always evaluate to True, making the else branch unreachable. This should likely be checking whether self._xp is np to determine if NumPy or JAX is being used.
    autoarray/inversion/inversion/imaging_numba/sparse.py:460
  • The condition if np is np will always evaluate to True, making the else branch unreachable. This should likely be checking whether self._xp is np to determine if NumPy or JAX is being used.
    autoarray/inversion/inversion/imaging_numba/sparse.py:215
  • Comparison of identical values; use cmath.isnan() if testing for not-a-number.
    autoarray/inversion/inversion/imaging_numba/sparse.py:304
  • Comparison of identical values; use cmath.isnan() if testing for not-a-number.
    autoarray/inversion/inversion/imaging_numba/sparse.py:449
  • Comparison of identical values; use cmath.isnan() if testing for not-a-number.
    autoarray/inversion/inversion/imaging_numba/sparse.py:484
  • Comparison of identical values; use cmath.isnan() if testing for not-a-number.
    autoarray/dataset/imaging/dataset.py:589
  • Keyword argument 'curvature_preload' is not a supported parameter name of SparseLinAlgImagingNumba.init.
        sparse_operator = inversion_imaging_numba_util.SparseLinAlgImagingNumba(
            curvature_preload=curvature_preload,
            indexes=indexes.astype("int"),
            lengths=lengths.astype("int"),
            noise_map=self.noise_map,
            psf=self.psf,
            mask=self.mask,
        )

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +100 to +103
from jax.ops import segment_sum

w = psf_weighted_data[rows] # (nnz,)
contrib = vals * w # (nnz,)
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function data_vector_via_psf_weighted_data_from imports segment_sum from jax.ops unconditionally, but this function might be called with NumPy arrays based on the xp parameter pattern used elsewhere. This will cause an import error if JAX is not installed. The import should be conditional or the function should be clearly documented as JAX-only.

Suggested change
from jax.ops import segment_sum
w = psf_weighted_data[rows] # (nnz,)
contrib = vals * w # (nnz,)
w = psf_weighted_data[rows] # (nnz,)
contrib = vals * w # (nnz,)
try:
from jax.ops import segment_sum
except ImportError:
# NumPy fallback implementation of segment_sum(contrib, cols, num_segments=S)
data_vector = np.zeros(S, dtype=contrib.dtype)
np.add.at(data_vector, cols, contrib)
return data_vector

Copilot uses AI. Check for mistakes.
Comment on lines +501 to +502
The w-tilde formalism bypasses the calculation of the `mapping_matrix` and it therefore cannot be used to map
the reconstruction's values to the image-plane. Instead, the unique data-to-pixelization mappings are used,
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistency in docstring: The docstring still refers to "w-tilde formalism" but the code has been refactored to use "sparse operator" terminology. This should be updated for consistency.

Suggested change
The w-tilde formalism bypasses the calculation of the `mapping_matrix` and it therefore cannot be used to map
the reconstruction's values to the image-plane. Instead, the unique data-to-pixelization mappings are used,
The sparse-operator formalism bypasses the explicit construction of the `mapping_matrix`. Instead, it uses
the precomputed sparse data-to-pixelization mappings to map the reconstruction's values to the image-plane,

Copilot uses AI. Check for mistakes.
Comment on lines +53 to 54
The sparse_operator matrix used by the w-tilde formalism to construct the data vector and
curvature matrix during an inversion efficiently..
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistency in docstring: The docstring still refers to "w-tilde formalism" in the parameter description, but the code has been refactored to use "sparse operator" terminology. This should be updated for consistency.

Suggested change
The sparse_operator matrix used by the w-tilde formalism to construct the data vector and
curvature matrix during an inversion efficiently..
A sparse operator used to efficiently construct the data vector and curvature matrix during an
inversion.

Copilot uses AI. Check for mistakes.

for sub_size in range(1, 3):

print(sub_size)
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug print statement left in test code. This should be removed before merging to production.

Copilot uses AI. Check for mistakes.
from typing import Optional

from autoconf.fitsable import ndarray_via_fits_from, output_to_fits
from autoconf import cached_property
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'cached_property' is not used.

Suggested change
from autoconf import cached_property

Copilot uses AI. Check for mistakes.
@@ -1,91 +1,37 @@
from dataclasses import dataclass
from functools import partial
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'partial' is not used.

Suggested change
from functools import partial

Copilot uses AI. Check for mistakes.
from tqdm import tqdm
import os
import time
from pathlib import Path
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'Path' is not used.

Suggested change
from pathlib import Path

Copilot uses AI. Check for mistakes.
Comment on lines +6 to +8
from autoarray.inversion.inversion.imaging.sparse import (
InversionImagingSparse,
)
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'InversionImagingSparse' is not used.

Suggested change
from autoarray.inversion.inversion.imaging.sparse import (
InversionImagingSparse,
)

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 35 out of 36 changed files in this pull request and generated 1 comment.

Comments suppressed due to low confidence (9)

autoarray/inversion/inversion/imaging_numba/sparse.py:218

  • The variable name np is being checked against itself in the condition if np is np:. This condition will always be True and should instead check whether the array module being used is NumPy (as opposed to JAX). The correct check should be if self._xp is np: to properly distinguish between NumPy and JAX backends.
    autoarray/inversion/inversion/imaging_numba/sparse.py:307
  • The condition if np is np: will always evaluate to True. This should be if self._xp is np: to correctly check whether the NumPy backend is being used (as opposed to JAX). This same bug appears in multiple places in this file (lines 215-218, 304-307, 449-460, 484-496) where the backend check is intended but incorrectly implemented.
    autoarray/inversion/inversion/imaging_numba/sparse.py:460
  • The condition if np is np: will always evaluate to True. This should be if self._xp is np: to correctly check whether the NumPy backend is being used (as opposed to JAX).
    autoarray/inversion/inversion/imaging_numba/sparse.py:496
  • The condition if np is np: will always evaluate to True. This should be if self._xp is np: to correctly check whether the NumPy backend is being used (as opposed to JAX).
    autoarray/inversion/inversion/imaging_numba/sparse.py:447
  • The function call passes extra parameters curvature_weights, mask, and psf_kernel that are not in the function signature. The function expects data_linear_func_matrix but receives curvature_weights instead. The caller should compute data_linear_func_matrix using inversion_imaging_util.data_linear_func_matrix_from(curvature_weights, self.psf.native, self.mask) before calling this function, similar to how it's done in the abstract imaging class's data_linear_func_matrix_dict property.
    autoarray/inversion/inversion/imaging_numba/sparse.py:215
  • Comparison of identical values; use cmath.isnan() if testing for not-a-number.
    autoarray/inversion/inversion/imaging_numba/sparse.py:304
  • Comparison of identical values; use cmath.isnan() if testing for not-a-number.
    autoarray/inversion/inversion/imaging_numba/sparse.py:449
  • Comparison of identical values; use cmath.isnan() if testing for not-a-number.
    autoarray/inversion/inversion/imaging_numba/sparse.py:484
  • Comparison of identical values; use cmath.isnan() if testing for not-a-number.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 612 to 618
def from_nufft_precision_operator(
self,
nufft_precision_operator: np.ndarray,
dirty_image: np.ndarray,
*,
batch_size: int = 128,
):
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Class methods or methods of a type deriving from type should have 'cls', rather than 'self', as their first parameter.

Copilot uses AI. Check for mistakes.
Jammy2211 and others added 9 commits February 3, 2026 17:32
…rometer_util.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
0;10;1cMerge branch 'feature/fft_jax_imaging' of github.com:Jammy2211/PyAutoArray into feature/fft_jax_imaging
@Jammy2211 Jammy2211 merged commit a655bb9 into main Feb 3, 2026
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant