Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
f84526a
small edits
Jan 15, 2026
c43b0f0
merge with many main updates
Jan 31, 2026
f1961a5
switch to inversion_imaging_util.w_tilde_data_imaging_from
Jan 31, 2026
6fcbd08
add pixel_triplets_from_subpixel_arrays_from
Jan 31, 2026
f9fc511
data vector stuff all jaxd
Jan 31, 2026
3c173ea
curvature calculationsnow use new code
Feb 1, 2026
ddce765
curvature_matrix_with_added_to_diag_from
Feb 1, 2026
2a112f7
end to end new curvaturw works
Feb 1, 2026
0e7a537
update and fix test__data_vector_via_w_tilde_data_two_methods_agree
Feb 1, 2026
a9c9e35
fix another unit test with an adaptiv e batch size
Feb 1, 2026
6a2af28
fix _data_vector_multi_mapper
Feb 1, 2026
5acea80
remove SPF operator
Feb 1, 2026
77c86d3
x2 mapper stuff now works
Feb 1, 2026
9e7b1de
fix now pointless unit ests
Feb 1, 2026
09d81cf
all inverison unitt ests pass meaning JAx conversion works
Feb 1, 2026
a84e98c
black
Feb 1, 2026
b991807
remove build_inv_noise_map
Feb 2, 2026
ccd4b44
w tilde removing throughout most of imaging
Feb 2, 2026
222ee04
interferometer refactor
Feb 2, 2026
7dc8446
w_tilde_data renamed to unot use w_tilde
Feb 2, 2026
fdc510f
weight map also precomputed
Feb 2, 2026
8ecdec0
more updates
Feb 2, 2026
ff1f04d
invversion module remove w_tilde naming
Feb 3, 2026
2af6279
remove sparse_linalg_numpy
Feb 3, 2026
5bdc680
remove use_source_loop
Feb 3, 2026
96378bb
refactor sparse triplets
Feb 3, 2026
4c1cc71
lots of renaming and refactoring for psarse operator API
Feb 3, 2026
9a5cce3
build server tests
Feb 3, 2026
6642e07
fix numba stuff
Feb 3, 2026
01524d0
remove test which fails due to numrical stability
Feb 3, 2026
ad89496
InterferometerSparseLinAlg -> InterferometerSparseOperator
Feb 3, 2026
1823c7c
Update autoarray/inversion/inversion/imaging/sparse.py
Jammy2211 Feb 3, 2026
6361503
Update autoarray/inversion/inversion/interferometer/inversion_interfe…
Jammy2211 Feb 3, 2026
1f487a3
urgh
Feb 3, 2026
c77944f
urgh
Feb 3, 2026
b8efd47
erm
Feb 3, 2026
d30cb17
Fix numba code
Feb 3, 2026
bc2f131
git push
Feb 3, 2026
5e8d05f
JAX 64 bit
Feb 3, 2026
8996212
stuff urgh
Feb 3, 2026
eb98911
finish
Feb 3, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ jobs:
- name: Run tests
run: |
export ROOT_DIR=`pwd`
export JAX_ENABLE_X64=True
export PYTHONPATH=$PYTHONPATH:$ROOT_DIR/PyAutoConf
export PYTHONPATH=$PYTHONPATH:$ROOT_DIR/PyAutoArray
pushd PyAutoArray
Expand Down
15 changes: 9 additions & 6 deletions autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,14 @@
from . import util
from . import fixtures
from . import mock as m
from .dataset.interferometer.w_tilde import load_curvature_preload_if_compatible

from .dataset import preprocess
from .dataset.abstract.dataset import AbstractDataset
from .dataset.abstract.w_tilde import AbstractWTilde
from .dataset.grids import GridsInterface
from .dataset.imaging.dataset import Imaging
from .dataset.imaging.simulator import SimulatorImaging
from .dataset.imaging.w_tilde import WTildeImaging
from .dataset.interferometer.dataset import Interferometer
from .dataset.interferometer.simulator import SimulatorInterferometer
from .dataset.interferometer.w_tilde import WTildeInterferometer
from .dataset.dataset_model import DatasetModel
from .fit.fit_dataset import AbstractFit
from .fit.fit_dataset import FitDataset
Expand All @@ -46,9 +43,15 @@
from .inversion.pixelization.image_mesh.abstract import AbstractImageMesh
from .inversion.pixelization.mesh.abstract import AbstractMesh
from .inversion.inversion.imaging.mapping import InversionImagingMapping
from .inversion.inversion.imaging.w_tilde import InversionImagingWTilde
from .inversion.inversion.interferometer.w_tilde import InversionInterferometerWTilde
from .inversion.inversion.imaging.sparse import InversionImagingSparse
from .inversion.inversion.imaging.inversion_imaging_util import ImagingSparseOperator
from .inversion.inversion.interferometer.sparse import (
InversionInterferometerSparse,
)
from .inversion.inversion.interferometer.mapping import InversionInterferometerMapping
from .inversion.inversion.interferometer.inversion_interferometer_util import (
InterferometerSparseOperator,
)
from .inversion.linear_obj.linear_obj import LinearObj
from .inversion.linear_obj.func_list import AbstractLinearObjFuncList
from .mask.derive.indexes_2d import DeriveIndexes2D
Expand Down
20 changes: 0 additions & 20 deletions autoarray/dataset/abstract/w_tilde.py

This file was deleted.

6 changes: 3 additions & 3 deletions autoarray/dataset/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(
over_sample_size_lp: Union[int, Array2D],
over_sample_size_pixelization: Union[int, Array2D],
psf: Optional[Kernel2D] = None,
use_w_tilde: bool = False,
use_sparse_operator: bool = False,
):
"""
Contains grids of (y,x) Cartesian coordinates at the centre of every pixel in the dataset's image and
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
self._blurring = None
self._border_relocator = None

self.use_w_tilde = use_w_tilde
self.use_sparse_operator = use_sparse_operator

@property
def lp(self):
Expand Down Expand Up @@ -120,7 +120,7 @@ def border_relocator(self) -> BorderRelocator:
self._border_relocator = BorderRelocator(
mask=self.mask,
sub_size=self.over_sample_size_pixelization,
use_w_tilde=self.use_w_tilde,
use_sparse_operator=self.use_sparse_operator,
)

return self._border_relocator
Expand Down
105 changes: 84 additions & 21 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
from pathlib import Path
from typing import Optional, Union

from autoconf import cached_property

from autoarray.dataset.abstract.dataset import AbstractDataset
from autoarray.dataset.grids import GridsDataset
from autoarray.dataset.imaging.w_tilde import WTildeImaging
from autoarray.inversion.inversion.imaging.inversion_imaging_util import (
ImagingSparseOperator,
)
from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.structures.arrays.kernel_2d import Kernel2D
from autoarray.mask.mask_2d import Mask2D
from autoarray import type as ty

from autoarray import exc
from autoarray.operators.over_sampling import over_sample_util
from autoarray.inversion.inversion.imaging import inversion_imaging_numba_util

from autoarray.inversion.inversion.imaging import inversion_imaging_util

logger = logging.getLogger(__name__)

Expand All @@ -32,7 +33,7 @@ def __init__(
disable_fft_pad: bool = True,
use_normalized_psf: Optional[bool] = True,
check_noise_map: bool = True,
w_tilde: Optional[WTildeImaging] = None,
sparse_operator: Optional[ImagingSparseOperator] = None,
):
"""
An imaging dataset, containing the image data, noise-map, PSF and associated quantities
Expand Down Expand Up @@ -86,9 +87,9 @@ def __init__(
the PSF kernel does not change the overall normalization of the image when it is convolved with it.
check_noise_map
If True, the noise-map is checked to ensure all values are above zero.
w_tilde
The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked
noise-map values given the PSF (see `inversion.inversion_util`). Pass the `WTildeImaging` object here to
sparse_operator
The sparse linear algebra formalism of the linear algebra equations precomputes the convolution of every pair of masked
noise-map values given the PSF (see `inversion.inversion_util`). Pass the `ImagingSparseOperator` object here to
enable this linear algebra formalism for pixelized reconstructions.
"""

Expand Down Expand Up @@ -191,17 +192,17 @@ def __init__(
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D Kernel2D must be odd")

use_w_tilde = True if w_tilde is not None else False
use_sparse_operator = True if sparse_operator is not None else False

self.grids = GridsDataset(
mask=self.data.mask,
over_sample_size_lp=self.over_sample_size_lp,
over_sample_size_pixelization=self.over_sample_size_pixelization,
psf=self.psf,
use_w_tilde=use_w_tilde,
use_sparse_operator=use_sparse_operator,
)

self.w_tilde = w_tilde
self.sparse_operator = sparse_operator

@classmethod
def from_fits(
Expand Down Expand Up @@ -474,9 +475,13 @@ def apply_over_sampling(

return dataset

def apply_w_tilde(self, disable_fft_pad: bool = False):
def apply_sparse_operator(
self,
batch_size: int = 128,
disable_fft_pad: bool = False,
):
"""
The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked
The sparse linear algebra formalism precomputes the convolution of every pair of masked
noise-map values given the PSF (see `inversion.inversion_util`).

The `WTilde` object stores these precomputed values in the imaging dataset ensuring they are only computed once
Expand All @@ -487,12 +492,66 @@ def apply_w_tilde(self, disable_fft_pad: bool = False):

Returns
-------
WTildeImaging
Precomputed values used for the w tilde formalism of linear algebra calculations.
batch_size
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
which can be reduced to produce lower memory usage at the cost of speed
disable_fft_pad
The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming,
which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not
performed and the image is used as-is. This is normally used to avoid repadding data that has already been
padded.
use_jax
Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
"""

logger.info("IMAGING - Computing W-Tilde... May take a moment.")
sparse_operator = (
inversion_imaging_util.ImagingSparseOperator.from_noise_map_and_psf(
data=self.data,
noise_map=self.noise_map,
psf=self.psf.native,
batch_size=batch_size,
)
)

return Imaging(
data=self.data,
noise_map=self.noise_map,
psf=self.psf,
noise_covariance_matrix=self.noise_covariance_matrix,
over_sample_size_lp=self.over_sample_size_lp,
over_sample_size_pixelization=self.over_sample_size_pixelization,
disable_fft_pad=disable_fft_pad,
check_noise_map=False,
sparse_operator=sparse_operator,
)

def apply_sparse_operator_cpu(
self,
disable_fft_pad: bool = False,
):
"""
The sparse linear algebra formalism precomputes the convolution of every pair of masked
noise-map values given the PSF (see `inversion.inversion_util`).

The `WTilde` object stores these precomputed values in the imaging dataset ensuring they are only computed once
per analysis.

This uses lazy allocation such that the calculation is only performed when the wtilde matrices are used,
ensuring efficient set up of the `Imaging` class.

Returns
-------
batch_size
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
which can be reduced to produce lower memory usage at the cost of speed
disable_fft_pad
The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming,
which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not
performed and the image is used as-is. This is normally used to avoid repadding data that has already been
padded.
use_jax
Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
"""
try:
import numba
except ModuleNotFoundError:
Expand All @@ -504,20 +563,24 @@ def apply_w_tilde(self, disable_fft_pad: bool = False):
"https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
)

from autoarray.inversion.inversion.imaging_numba import (
inversion_imaging_numba_util,
)

(
curvature_preload,
psf_precision_operator_sparse,
indexes,
lengths,
) = inversion_imaging_numba_util.w_tilde_curvature_preload_imaging_from(
) = inversion_imaging_numba_util.psf_precision_operator_sparse_from(
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
kernel_native=np.array(self.psf.native.array).astype("float64"),
native_index_for_slim_index=np.array(
self.mask.derive_indexes.native_for_slim
).astype("int"),
)

w_tilde = WTildeImaging(
curvature_preload=curvature_preload,
sparse_operator = inversion_imaging_numba_util.SparseLinAlgImagingNumba(
psf_precision_operator_sparse=psf_precision_operator_sparse,
indexes=indexes.astype("int"),
lengths=lengths.astype("int"),
noise_map=self.noise_map,
Expand All @@ -534,7 +597,7 @@ def apply_w_tilde(self, disable_fft_pad: bool = False):
over_sample_size_pixelization=self.over_sample_size_pixelization,
disable_fft_pad=disable_fft_pad,
check_noise_map=False,
w_tilde=w_tilde,
sparse_operator=sparse_operator,
)

def output_to_fits(
Expand Down
100 changes: 0 additions & 100 deletions autoarray/dataset/imaging/w_tilde.py

This file was deleted.

Loading
Loading