diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cfc286039..0f3928e49 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -66,6 +66,7 @@ jobs: - name: Run tests run: | export ROOT_DIR=`pwd` + export JAX_ENABLE_X64=True export PYTHONPATH=$PYTHONPATH:$ROOT_DIR/PyAutoConf export PYTHONPATH=$PYTHONPATH:$ROOT_DIR/PyAutoArray pushd PyAutoArray diff --git a/autoarray/__init__.py b/autoarray/__init__.py index e5740115b..4a81a3b2a 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -9,17 +9,14 @@ from . import util from . import fixtures from . import mock as m -from .dataset.interferometer.w_tilde import load_curvature_preload_if_compatible + from .dataset import preprocess from .dataset.abstract.dataset import AbstractDataset -from .dataset.abstract.w_tilde import AbstractWTilde from .dataset.grids import GridsInterface from .dataset.imaging.dataset import Imaging from .dataset.imaging.simulator import SimulatorImaging -from .dataset.imaging.w_tilde import WTildeImaging from .dataset.interferometer.dataset import Interferometer from .dataset.interferometer.simulator import SimulatorInterferometer -from .dataset.interferometer.w_tilde import WTildeInterferometer from .dataset.dataset_model import DatasetModel from .fit.fit_dataset import AbstractFit from .fit.fit_dataset import FitDataset @@ -46,9 +43,15 @@ from .inversion.pixelization.image_mesh.abstract import AbstractImageMesh from .inversion.pixelization.mesh.abstract import AbstractMesh from .inversion.inversion.imaging.mapping import InversionImagingMapping -from .inversion.inversion.imaging.w_tilde import InversionImagingWTilde -from .inversion.inversion.interferometer.w_tilde import InversionInterferometerWTilde +from .inversion.inversion.imaging.sparse import InversionImagingSparse +from .inversion.inversion.imaging.inversion_imaging_util import ImagingSparseOperator +from .inversion.inversion.interferometer.sparse import ( + InversionInterferometerSparse, +) from .inversion.inversion.interferometer.mapping import InversionInterferometerMapping +from .inversion.inversion.interferometer.inversion_interferometer_util import ( + InterferometerSparseOperator, +) from .inversion.linear_obj.linear_obj import LinearObj from .inversion.linear_obj.func_list import AbstractLinearObjFuncList from .mask.derive.indexes_2d import DeriveIndexes2D diff --git a/autoarray/dataset/abstract/w_tilde.py b/autoarray/dataset/abstract/w_tilde.py deleted file mode 100644 index 59f72147d..000000000 --- a/autoarray/dataset/abstract/w_tilde.py +++ /dev/null @@ -1,20 +0,0 @@ -from autoarray import exc - - -class AbstractWTilde: - def __init__(self, curvature_preload): - """ - Packages together all derived data quantities necessary to fit `data (e.g. `Imaging`, Interferometer`) using - an ` Inversion` via the w_tilde formalism. - - The w_tilde formalism performs linear algebra formalism in a way that speeds up the construction of the - simultaneous linear equations by bypassing the construction of a `mapping_matrix` and precomputing - operations like blurring or a Fourier transform. - - Parameters - ---------- - curvature_preload - A matrix which uses the imaging's noise-map and PSF to preload as much of the computation of the - curvature matrix as possible. 
- """ - self.curvature_preload = curvature_preload diff --git a/autoarray/dataset/grids.py b/autoarray/dataset/grids.py index 80d7bc451..666e47fe7 100644 --- a/autoarray/dataset/grids.py +++ b/autoarray/dataset/grids.py @@ -17,7 +17,7 @@ def __init__( over_sample_size_lp: Union[int, Array2D], over_sample_size_pixelization: Union[int, Array2D], psf: Optional[Kernel2D] = None, - use_w_tilde: bool = False, + use_sparse_operator: bool = False, ): """ Contains grids of (y,x) Cartesian coordinates at the centre of every pixel in the dataset's image and @@ -66,7 +66,7 @@ def __init__( self._blurring = None self._border_relocator = None - self.use_w_tilde = use_w_tilde + self.use_sparse_operator = use_sparse_operator @property def lp(self): @@ -120,7 +120,7 @@ def border_relocator(self) -> BorderRelocator: self._border_relocator = BorderRelocator( mask=self.mask, sub_size=self.over_sample_size_pixelization, - use_w_tilde=self.use_w_tilde, + use_sparse_operator=self.use_sparse_operator, ) return self._border_relocator diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index a98a443ac..401990c90 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -3,11 +3,11 @@ from pathlib import Path from typing import Optional, Union -from autoconf import cached_property - from autoarray.dataset.abstract.dataset import AbstractDataset from autoarray.dataset.grids import GridsDataset -from autoarray.dataset.imaging.w_tilde import WTildeImaging +from autoarray.inversion.inversion.imaging.inversion_imaging_util import ( + ImagingSparseOperator, +) from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.arrays.kernel_2d import Kernel2D from autoarray.mask.mask_2d import Mask2D @@ -15,7 +15,8 @@ from autoarray import exc from autoarray.operators.over_sampling import over_sample_util -from autoarray.inversion.inversion.imaging import inversion_imaging_numba_util + +from autoarray.inversion.inversion.imaging import inversion_imaging_util logger = logging.getLogger(__name__) @@ -32,7 +33,7 @@ def __init__( disable_fft_pad: bool = True, use_normalized_psf: Optional[bool] = True, check_noise_map: bool = True, - w_tilde: Optional[WTildeImaging] = None, + sparse_operator: Optional[ImagingSparseOperator] = None, ): """ An imaging dataset, containing the image data, noise-map, PSF and associated quantities @@ -86,9 +87,9 @@ def __init__( the PSF kernel does not change the overall normalization of the image when it is convolved with it. check_noise_map If True, the noise-map is checked to ensure all values are above zero. - w_tilde - The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked - noise-map values given the PSF (see `inversion.inversion_util`). Pass the `WTildeImaging` object here to + sparse_operator + The sparse linear algebra formalism of the linear algebra equations precomputes the convolution of every pair of masked + noise-map values given the PSF (see `inversion.inversion_util`). Pass the `ImagingSparseOperator` object here to enable this linear algebra formalism for pixelized reconstructions. 
""" @@ -191,17 +192,17 @@ def __init__( if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0: raise exc.KernelException("Kernel2D Kernel2D must be odd") - use_w_tilde = True if w_tilde is not None else False + use_sparse_operator = True if sparse_operator is not None else False self.grids = GridsDataset( mask=self.data.mask, over_sample_size_lp=self.over_sample_size_lp, over_sample_size_pixelization=self.over_sample_size_pixelization, psf=self.psf, - use_w_tilde=use_w_tilde, + use_sparse_operator=use_sparse_operator, ) - self.w_tilde = w_tilde + self.sparse_operator = sparse_operator @classmethod def from_fits( @@ -474,9 +475,13 @@ def apply_over_sampling( return dataset - def apply_w_tilde(self, disable_fft_pad: bool = False): + def apply_sparse_operator( + self, + batch_size: int = 128, + disable_fft_pad: bool = False, + ): """ - The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked + The sparse linear algebra formalism precomputes the convolution of every pair of masked noise-map values given the PSF (see `inversion.inversion_util`). The `WTilde` object stores these precomputed values in the imaging dataset ensuring they are only computed once @@ -487,12 +492,66 @@ def apply_w_tilde(self, disable_fft_pad: bool = False): Returns ------- - WTildeImaging - Precomputed values used for the w tilde formalism of linear algebra calculations. + batch_size + The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution, + which can be reduced to produce lower memory usage at the cost of speed + disable_fft_pad + The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming, + which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not + performed and the image is used as-is. This is normally used to avoid repadding data that has already been + padded. + use_jax + Whether to use JAX to compute W-Tilde. This requires JAX to be installed. """ - logger.info("IMAGING - Computing W-Tilde... May take a moment.") + sparse_operator = ( + inversion_imaging_util.ImagingSparseOperator.from_noise_map_and_psf( + data=self.data, + noise_map=self.noise_map, + psf=self.psf.native, + batch_size=batch_size, + ) + ) + + return Imaging( + data=self.data, + noise_map=self.noise_map, + psf=self.psf, + noise_covariance_matrix=self.noise_covariance_matrix, + over_sample_size_lp=self.over_sample_size_lp, + over_sample_size_pixelization=self.over_sample_size_pixelization, + disable_fft_pad=disable_fft_pad, + check_noise_map=False, + sparse_operator=sparse_operator, + ) + def apply_sparse_operator_cpu( + self, + disable_fft_pad: bool = False, + ): + """ + The sparse linear algebra formalism precomputes the convolution of every pair of masked + noise-map values given the PSF (see `inversion.inversion_util`). + + The `WTilde` object stores these precomputed values in the imaging dataset ensuring they are only computed once + per analysis. + + This uses lazy allocation such that the calculation is only performed when the wtilde matrices are used, + ensuring efficient set up of the `Imaging` class. + + Returns + ------- + batch_size + The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution, + which can be reduced to produce lower memory usage at the cost of speed + disable_fft_pad + The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming, + which places the fewest zeros around the image. 
If this is set to `True`, this optimal padding is not + performed and the image is used as-is. This is normally used to avoid repadding data that has already been + padded. + use_jax + Whether to use JAX to compute W-Tilde. This requires JAX to be installed. + """ try: import numba except ModuleNotFoundError: @@ -504,11 +563,15 @@ def apply_w_tilde(self, disable_fft_pad: bool = False): "https://pyautolens.readthedocs.io/en/latest/installation/overview.html" ) + from autoarray.inversion.inversion.imaging_numba import ( + inversion_imaging_numba_util, + ) + ( - curvature_preload, + psf_precision_operator_sparse, indexes, lengths, - ) = inversion_imaging_numba_util.w_tilde_curvature_preload_imaging_from( + ) = inversion_imaging_numba_util.psf_precision_operator_sparse_from( noise_map_native=np.array(self.noise_map.native.array).astype("float64"), kernel_native=np.array(self.psf.native.array).astype("float64"), native_index_for_slim_index=np.array( @@ -516,8 +579,8 @@ def apply_w_tilde(self, disable_fft_pad: bool = False): ).astype("int"), ) - w_tilde = WTildeImaging( - curvature_preload=curvature_preload, + sparse_operator = inversion_imaging_numba_util.SparseLinAlgImagingNumba( + psf_precision_operator_sparse=psf_precision_operator_sparse, indexes=indexes.astype("int"), lengths=lengths.astype("int"), noise_map=self.noise_map, @@ -534,7 +597,7 @@ def apply_w_tilde(self, disable_fft_pad: bool = False): over_sample_size_pixelization=self.over_sample_size_pixelization, disable_fft_pad=disable_fft_pad, check_noise_map=False, - w_tilde=w_tilde, + sparse_operator=sparse_operator, ) def output_to_fits( diff --git a/autoarray/dataset/imaging/w_tilde.py b/autoarray/dataset/imaging/w_tilde.py deleted file mode 100644 index 0dc181f1f..000000000 --- a/autoarray/dataset/imaging/w_tilde.py +++ /dev/null @@ -1,100 +0,0 @@ -import logging -import numpy as np - -from autoarray.dataset.abstract.w_tilde import AbstractWTilde - -from autoarray.inversion.inversion.imaging import inversion_imaging_util -from autoarray.inversion.inversion.imaging import inversion_imaging_numba_util - -logger = logging.getLogger(__name__) - - -class WTildeImaging(AbstractWTilde): - def __init__( - self, - curvature_preload: np.ndarray, - indexes: np.ndim, - lengths: np.ndarray, - noise_map: np.ndarray, - psf: np.ndarray, - mask: np.ndarray, - ): - """ - Packages together all derived data quantities necessary to fit `Imaging` data using an ` Inversion` via the - w_tilde formalism. - - The w_tilde formalism performs linear algebra formalism in a way that speeds up the construction of the - simultaneous linear equations by bypassing the construction of a `mapping_matrix` and precomputing the - blurring operations performed using the imaging's PSF. - - Parameters - ---------- - curvature_preload - A matrix which uses the imaging's noise-map and PSF to preload as much of the computation of the - curvature matrix as possible. - indexes - The image-pixel indexes of the curvature preload matrix, which are used to compute the curvature matrix - efficiently when performing an inversion. - lengths - The lengths of how many indexes each curvature preload contains, again used to compute the curvature - matrix efficienctly. 
- """ - super().__init__( - curvature_preload=curvature_preload, - ) - - self.indexes = indexes - self.lengths = lengths - self.noise_map = noise_map - self.psf = psf - self.mask = mask - - @property - def w_matrix(self): - """ - The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF - convolution of every pair of image pixels given the noise map. This can be used to efficiently compute the - curvature matrix via the mappings between image and source pixels, in a way that omits having to perform the - PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging - datasets. - - The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, - making it impossible to store in memory and its use in linear algebra calculations extremely. The method - `w_tilde_curvature_preload_imaging_from` describes a compressed representation that overcomes this hurdles. It is - advised `w_tilde` and this method are only used for testing. - - Parameters - ---------- - noise_map_native - The two dimensional masked noise-map of values which w_tilde is computed from. - kernel_native - The two dimensional PSF kernel that w_tilde encodes the convolution of. - native_index_for_slim_index - An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. - - Returns - ------- - ndarray - A matrix that encodes the PSF convolution values between the noise map that enables efficient calculation of - the curvature matrix. - """ - - return inversion_imaging_numba_util.w_tilde_curvature_imaging_from( - noise_map_native=self.noise_map.native.array, - kernel_native=self.psf.native.array, - native_index_for_slim_index=np.array( - self.mask.derive_indexes.native_for_slim - ).astype("int"), - ) - - @property - def psf_operator_matrix_dense(self): - - return inversion_imaging_util.psf_operator_matrix_dense_from( - kernel_native=self.psf.native.array, - native_index_for_slim_index=np.array( - self.mask.derive_indexes.native_for_slim - ).astype("int"), - native_shape=self.noise_map.shape_native, - correlate=False, - ) diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index c8224f6a6..922aff968 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -3,10 +3,13 @@ from typing import Optional from autoconf.fitsable import ndarray_via_fits_from, output_to_fits +from autoconf import cached_property from autoarray.dataset.abstract.dataset import AbstractDataset -from autoarray.dataset.interferometer.w_tilde import WTildeInterferometer from autoarray.dataset.grids import GridsDataset +from autoarray.inversion.inversion.interferometer.inversion_interferometer_util import ( + InterferometerSparseOperator, +) from autoarray.operators.transformer import TransformerDFT from autoarray.operators.transformer import TransformerNUFFT from autoarray.mask.mask_2d import Mask2D @@ -29,7 +32,7 @@ def __init__( uv_wavelengths: np.ndarray, real_space_mask: Mask2D, transformer_class=TransformerNUFFT, - w_tilde: Optional[WTildeInterferometer] = None, + sparse_operator: Optional[InterferometerSparseOperator] = None, raise_error_dft_visibilities_limit: bool = True, ): """ @@ -93,16 +96,16 @@ def __init__( real_space_mask=real_space_mask, ) - use_w_tilde = True if w_tilde is not None else False + use_sparse_operator = True if sparse_operator is not None else 
False self.grids = GridsDataset( mask=self.real_space_mask, over_sample_size_lp=self.over_sample_size_lp, over_sample_size_pixelization=self.over_sample_size_pixelization, - use_w_tilde=use_w_tilde, + use_sparse_operator=use_sparse_operator, ) - self.w_tilde = w_tilde + self.sparse_operator = sparse_operator if raise_error_dft_visibilities_limit: if ( @@ -158,9 +161,9 @@ def from_fits( transformer_class=transformer_class, ) - def apply_w_tilde( + def apply_sparse_operator( self, - curvature_preload=None, + nufft_precision_operator=None, batch_size: int = 128, chunk_k: int = 2048, show_progress: bool = False, @@ -168,7 +171,7 @@ def apply_w_tilde( use_jax: bool = False, ): """ - The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities + The sparse linear algebra equations precomputes the Fourier Transform of all the visibilities given the `uv_wavelengths` (see `inversion.inversion_util`). The `WTilde` object stores these precomputed values in the interferometer dataset ensuring they are only @@ -179,7 +182,7 @@ def apply_w_tilde( Parameters ---------- - curvature_preload + nufft_precision_operator An already computed curvature preload matrix for this dataset (e.g. loaded from hard-disk), to prevent long recalculations of this matrix for large datasets. batch_size @@ -192,20 +195,16 @@ def apply_w_tilde( Precomputed values used for the w tilde formalism of linear algebra calculations. """ - if curvature_preload is None: + if nufft_precision_operator is None: logger.info( - "INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours." + "INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, CPU run times may exceed hours." 
) - curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from( - noise_map_real=self.noise_map.array.real, - uv_wavelengths=self.uv_wavelengths, - shape_masked_pixels_2d=self.transformer.grid.mask.shape_native_masked_pixels, - grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array, + nufft_precision_operator = self.psf_precision_operator_from( chunk_k=chunk_k, - show_memory=show_memory, show_progress=show_progress, + show_memory=show_memory, use_jax=use_jax, ) @@ -215,10 +214,9 @@ def apply_w_tilde( use_adjoint_scaling=True, ) - w_tilde = WTildeInterferometer( - curvature_preload=curvature_preload, + sparse_operator = inversion_interferometer_util.InterferometerSparseOperator.from_nufft_precision_operator( + nufft_precision_operator=nufft_precision_operator, dirty_image=dirty_image.array, - real_space_mask=self.real_space_mask, batch_size=batch_size, ) @@ -228,7 +226,25 @@ def apply_w_tilde( noise_map=self.noise_map, uv_wavelengths=self.uv_wavelengths, transformer_class=lambda uv_wavelengths, real_space_mask: self.transformer, - w_tilde=w_tilde, + sparse_operator=sparse_operator, + ) + + def psf_precision_operator_from( + self, + chunk_k: int = 2048, + show_progress: bool = False, + show_memory: bool = False, + use_jax: bool = False, + ): + return inversion_interferometer_util.nufft_precision_operator_from( + noise_map_real=self.noise_map.array.real, + uv_wavelengths=self.uv_wavelengths, + shape_masked_pixels_2d=self.transformer.grid.mask.shape_native_masked_pixels, + grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array, + chunk_k=chunk_k, + show_memory=show_memory, + show_progress=show_progress, + use_jax=use_jax, ) @property diff --git a/autoarray/dataset/interferometer/w_tilde.py b/autoarray/dataset/interferometer/w_tilde.py deleted file mode 100644 index 694c8a73f..000000000 --- a/autoarray/dataset/interferometer/w_tilde.py +++ /dev/null @@ -1,368 +0,0 @@ -import json -import hashlib -import numpy as np -from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union - -from autoarray.dataset.abstract.w_tilde import AbstractWTilde -from autoarray.mask.mask_2d import Mask2D - - -def _bbox_from_mask(mask_bool: np.ndarray) -> Tuple[int, int, int, int]: - """ - Return bbox (y_min, y_max, x_min, x_max) of the unmasked region. - mask_bool: True=masked, False=unmasked - """ - ys, xs = np.where(~mask_bool) - if ys.size == 0: - raise ValueError("Mask has no unmasked pixels; cannot compute bbox.") - return int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max()) - - -def _mask_sha256(mask_bool: np.ndarray) -> str: - """ - Stable hash of the full boolean mask content (not just bbox). - """ - # Ensure contiguous, stable dtype - arr = np.ascontiguousarray(mask_bool.astype(np.uint8)) - return hashlib.sha256(arr.tobytes()).hexdigest() - - -def _as_pixel_scales_tuple(pixel_scales) -> Tuple[float, float]: - """ - Normalize pixel_scales to a stable 2-tuple of float. - Works with AutoArray pixel_scales objects or raw tuples. 
- """ - try: - # autoarray typically stores pixel_scales as tuple-like - return (float(pixel_scales[0]), float(pixel_scales[1])) - except Exception: - # fallback: treat as scalar - s = float(pixel_scales) - return (s, s) - - -def _np_float_tuple(x) -> Tuple[float, float]: - return (float(x[0]), float(x[1])) - - -def curvature_preload_metadata_from(real_space_mask) -> Dict[str, Any]: - """ - Build the minimal metadata required to decide whether a stored curvature_preload - can be reused for the current WTildeInterferometer instance. - - The preload depends on: - - the *rectangular FFT grid extent* used for offset evaluation (bbox / extent) - - pixel scales (radians per pixel) - - (usually) the exact mask shape and content (recommended to hash) - - Returns - ------- - dict - JSON-serializable metadata. - """ - mask_bool = np.asarray(real_space_mask, dtype=bool) - y_min, y_max, x_min, x_max = _bbox_from_mask(mask_bool) - y_extent = y_max - y_min + 1 - x_extent = x_max - x_min + 1 - - pixel_scales = _as_pixel_scales_tuple(real_space_mask.pixel_scales) - - meta = { - "format": "autoarray.w_tilde.curvature_preload.v1", - "mask_shape": tuple(mask_bool.shape), - "pixel_scales": pixel_scales, - "bbox_unmasked": (y_min, y_max, x_min, x_max), - "rect_shape": (y_extent, x_extent), - # full-content hash: safest way to prevent accidental reuse - "mask_sha256": _mask_sha256(mask_bool), - } - return meta - - -def is_preload_metadata_compatible( - real_space_mask, - meta: Dict[str, Any], - *, - require_mask_hash: bool = True, - atol: float = 0.0, -) -> Tuple[bool, str]: - """ - Compare loaded metadata against current instance. - - Parameters - ---------- - meta - Metadata dict loaded from disk. - require_mask_hash - If True, require the full mask sha256 to match (safest). - If False, only check bbox + shape + pixel scales. - atol - Tolerances for pixel scale comparisons (normally exact is fine - because these are configuration constants, but tolerances allow - for tiny float repr differences). - - Returns - ------- - (ok, reason) - ok: bool, True if compatible - reason: str, human-readable mismatch reason if not ok. 
- """ - current = curvature_preload_metadata_from(real_space_mask=real_space_mask) - - # 1) format version - if meta.get("format") != current["format"]: - return False, f"format mismatch: {meta.get('format')} != {current['format']}" - - # 2) mask shape - if tuple(meta.get("mask_shape", ())) != tuple(current["mask_shape"]): - return ( - False, - f"mask_shape mismatch: {meta.get('mask_shape')} != {current['mask_shape']}", - ) - - # 3) pixel scales - ps_saved = _np_float_tuple(meta.get("pixel_scales", (np.nan, np.nan))) - ps_curr = _np_float_tuple(current["pixel_scales"]) - - if not ( - np.isclose(ps_saved[0], ps_curr[0], atol=atol) - and np.isclose(ps_saved[1], ps_curr[1], atol=atol) - ): - return False, f"pixel_scales mismatch: {ps_saved} != {ps_curr}" - - # 4) bbox / rect shape - if tuple(meta.get("bbox_unmasked", ())) != tuple(current["bbox_unmasked"]): - return ( - False, - f"bbox_unmasked mismatch: {meta.get('bbox_unmasked')} != {current['bbox_unmasked']}", - ) - - if tuple(meta.get("rect_shape", ())) != tuple(current["rect_shape"]): - return ( - False, - f"rect_shape mismatch: {meta.get('rect_shape')} != {current['rect_shape']}", - ) - - # 5) full mask hash (optional but recommended) - if require_mask_hash: - if meta.get("mask_sha256") != current["mask_sha256"]: - return False, "mask_sha256 mismatch (mask content differs)" - - return True, "ok" - - -def load_curvature_preload_if_compatible( - file: Union[str, Path], - real_space_mask, - *, - require_mask_hash: bool = True, -) -> Optional[np.ndarray]: - """ - Load a saved curvature_preload if (and only if) it is compatible with the current mask geometry. - - Parameters - ---------- - file - Path to a previously saved NPZ. - require_mask_hash - If True, require the full mask content hash to match (safest). - If False, only bbox + shape + pixel scales are checked. - - Returns - ------- - np.ndarray - The loaded curvature_preload if compatible, otherwise raises ValueError. - """ - file = Path(file) - if file.suffix.lower() != ".npz": - file = file.with_suffix(".npz") - - if not file.exists(): - raise FileNotFoundError(str(file)) - - with np.load(file, allow_pickle=False) as npz: - if "curvature_preload" not in npz or "meta_json" not in npz: - msg = f"File does not contain required fields: {file}" - raise ValueError(msg) - - meta_json = str(npz["meta_json"].item()) - meta = json.loads(meta_json) - - ok, reason = is_preload_metadata_compatible( - meta=meta, - real_space_mask=real_space_mask, - require_mask_hash=require_mask_hash, - atol=1.0e-8, - ) - - if not ok: - raise ValueError(f"curvature_preload incompatible: {reason}") - - return np.asarray(npz["curvature_preload"]) - - -class WTildeInterferometer(AbstractWTilde): - def __init__( - self, - curvature_preload: np.ndarray, - dirty_image: np.ndarray, - real_space_mask: Mask2D, - batch_size: int = 128, - ): - """ - Packages together all derived data quantities necessary to fit `Interferometer` data using an ` Inversion` via - the w_tilde formalism. - - The w_tilde formalism performs linear algebra formalism in a way that speeds up the construction of the - simultaneous linear equations by bypassing the construction of a `mapping_matrix` and precomputing the - Fourier transform operations performed using the interferometer's `uv_wavelengths`. - - Parameters - ---------- - w_matrix - The w_tilde matrix used by the w-tilde formalism to construct the data vector and - curvature matrix during an inversion efficiently.. 
- curvature_preload - A matrix which uses the interferometer `uv_wavelengths` to preload as much of the computation of the - curvature matrix as possible. - dirty_image - The real-space image of the visibilities computed via the transform, which is used to construct the - curvature matrix. - real_space_mask - The 2D mask in real-space defining the area where the interferometer data's visibilities are observing - a signal. - batch_size - The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution, - which can be reduced to produce lower memory usage at the cost of speed. - """ - super().__init__( - curvature_preload=curvature_preload, - ) - - self.dirty_image = dirty_image - self.real_space_mask = real_space_mask - - from autoarray.inversion.inversion.interferometer import ( - inversion_interferometer_util, - ) - - self.fft_state = inversion_interferometer_util.w_tilde_fft_state_from( - curvature_preload=self.curvature_preload, batch_size=batch_size - ) - - @property - def mask_rectangular_w_tilde(self) -> np.ndarray: - """ - Returns a rectangular boolean mask that tightly bounds the unmasked region - of the interferometer mask. - - This rectangular mask is used for computing the W-tilde curvature matrix - via FFT-based convolution, which requires a full rectangular grid. - - Pixels outside the bounding box of the original mask are set to True - (masked), and pixels inside are False (unmasked). - - Returns - ------- - np.ndarray - Boolean mask of shape (Ny, Nx), where False denotes unmasked pixels. - """ - mask = self.real_space_mask - - ys, xs = np.where(~mask) - - y_min, y_max = ys.min(), ys.max() - x_min, x_max = xs.min(), xs.max() - - rect_mask = np.ones(mask.shape, dtype=bool) - rect_mask[y_min : y_max + 1, x_min : x_max + 1] = False - - return rect_mask - - @property - def rect_index_for_mask_index(self) -> np.ndarray: - """ - Mapping from masked-grid pixel indices to rectangular-grid pixel indices. - - This array enables extraction of a curvature matrix computed on a full - rectangular grid back to the original masked grid. - - If: - - C_rect is the curvature matrix computed on the rectangular grid - - idx = rect_index_for_mask_index - - then the masked curvature matrix is: - C_mask = C_rect[idx[:, None], idx[None, :]] - - Returns - ------- - np.ndarray - Array of shape (N_masked_pixels,), where each entry gives the - corresponding index in the rectangular grid (row-major order). - """ - mask = self.real_space_mask - rect_mask = self.mask_rectangular_w_tilde - - # Bounding box of the rectangular region - ys, xs = np.where(~rect_mask) - y_min, y_max = ys.min(), ys.max() - x_min, x_max = xs.min(), xs.max() - - rect_width = x_max - x_min + 1 - - # Coordinates of unmasked pixels in the original mask (slim order) - mask_ys, mask_xs = np.where(~mask) - - # Convert (y, x) → rectangular flat index - rect_indices = ((mask_ys - y_min) * rect_width + (mask_xs - x_min)).astype( - np.int32 - ) - - return rect_indices - - def save_curvature_preload( - self, - file: Union[str, Path], - *, - overwrite: bool = False, - ) -> Path: - """ - Save curvature_preload plus enough metadata to ensure it is only reused when safe. - - Uses NPZ so we can store: - - curvature_preload (array) - - meta_json (string) - - Parameters - ---------- - file - Path to save to. Recommended suffix: ".npz". - If you pass ".npy", we will still save an ".npz" next to it. - overwrite - If False and the file exists, raise FileExistsError. 
- - Returns - ------- - Path - The path actually written (will end with ".npz"). - """ - file = Path(file) - - # Force .npz (storing metadata safely) - if file.suffix.lower() != ".npz": - file = file.with_suffix(".npz") - - if file.exists() and not overwrite: - raise FileExistsError(f"File already exists: {file}") - - meta = curvature_preload_metadata_from(self.real_space_mask) - - meta_json = json.dumps(meta, sort_keys=True) - - np.savez_compressed( - file, - curvature_preload=np.asarray(self.curvature_preload), - meta_json=np.asarray(meta_json), - ) - return file diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index 9bc487868..54604a366 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -46,7 +46,7 @@ def __init__( The linear algebra required to perform an `Inversion` depends on the type of dataset being fitted (e.g. `Imaging`, `Interferometer) and the formalism chosen (e.g. a using a `mapping_matrix` or the - w_tilde formalism). The children of this class overwrite certain methods in order to be appropriate for + sparse linear algebra formalism). The children of this class overwrite certain methods in order to be appropriate for certain datasets or use a specific formalism. Inversions use the formalism's outlined in the following Astronomy papers: diff --git a/autoarray/inversion/inversion/dataset_interface.py b/autoarray/inversion/inversion/dataset_interface.py index cf5960bf8..4356b7aff 100644 --- a/autoarray/inversion/inversion/dataset_interface.py +++ b/autoarray/inversion/inversion/dataset_interface.py @@ -6,7 +6,7 @@ def __init__( grids=None, psf=None, transformer=None, - w_tilde=None, + sparse_operator=None, noise_covariance_matrix=None, ): """ @@ -49,8 +49,8 @@ def __init__( transformer Performs a Fourier transform of the image-data from real-space to visibilities when computing the operated mapping matrix. - w_tilde - The w_tilde matrix used by the w-tilde formalism to construct the data vector and + sparse_operator + The sparse_operator matrix used by the w-tilde formalism to construct the data vector and curvature matrix during an inversion efficiently.. 
noise_covariance_matrix A noise-map covariance matrix representing the covariance between noise in every `data` value, which @@ -61,7 +61,7 @@ def __init__( self.grids = grids self.psf = psf self.transformer = transformer - self.w_tilde = w_tilde + self.sparse_operator = sparse_operator self.noise_covariance_matrix = noise_covariance_matrix @property diff --git a/autoarray/inversion/inversion/factory.py b/autoarray/inversion/inversion/factory.py index ce04ca5db..b1e9edaf8 100644 --- a/autoarray/inversion/inversion/factory.py +++ b/autoarray/inversion/inversion/factory.py @@ -4,16 +4,25 @@ from autoarray.dataset.imaging.dataset import Imaging from autoarray.dataset.interferometer.dataset import Interferometer from autoarray.inversion.inversion.imaging.mapping import InversionImagingMapping + from autoarray.inversion.inversion.interferometer.mapping import ( InversionInterferometerMapping, ) -from autoarray.inversion.inversion.interferometer.w_tilde import ( - InversionInterferometerWTilde, +from autoarray.inversion.inversion.interferometer.sparse import ( + InversionInterferometerSparse, ) from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList -from autoarray.inversion.inversion.imaging.w_tilde import InversionImagingWTilde +from autoarray.inversion.inversion.imaging_numba.inversion_imaging_numba_util import ( + SparseLinAlgImagingNumba, +) +from autoarray.inversion.inversion.imaging_numba.sparse import ( + InversionImagingSparseNumba, +) +from autoarray.inversion.inversion.imaging.sparse import ( + InversionImagingSparse, +) from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.preloads import Preloads from autoarray.structures.arrays.uniform_2d import Array2D @@ -78,7 +87,7 @@ def inversion_imaging_from( """ Factory which given an input `Imaging` dataset and list of linear objects, creates an `InversionImaging`. - Unlike the `inversion_from` factory this function takes the `data`, `noise_map` and `w_tilde` objects as separate + Unlike the `inversion_from` factory this function takes the `data` and `noise_map` objects as separate inputs, which facilitates certain computations where the `dataset` object is unpacked before the `Inversion` is performed (for example if the noise-map is scaled before the inversion to downweight certain regions of the data). @@ -108,21 +117,31 @@ def inversion_imaging_from( An `Inversion` whose type is determined by the input `dataset` and `settings`. """ - use_w_tilde = True + use_sparse_operator = True if all( isinstance(linear_obj, AbstractLinearObjFuncList) for linear_obj in linear_obj_list ): - use_w_tilde = False + use_sparse_operator = False - if dataset.w_tilde is not None and use_w_tilde: + if dataset.sparse_operator is not None and use_sparse_operator: - return InversionImagingWTilde( + if isinstance(dataset.sparse_operator, SparseLinAlgImagingNumba): + + return InversionImagingSparseNumba( + dataset=dataset, + linear_obj_list=linear_obj_list, + settings=settings, + preloads=preloads, + xp=xp, + ) + + return InversionImagingSparse( dataset=dataset, - w_tilde=dataset.w_tilde, linear_obj_list=linear_obj_list, settings=settings, + preloads=preloads, xp=xp, ) @@ -145,7 +164,7 @@ def inversion_interferometer_from( Factory which given an input `Interferometer` dataset and list of linear objects, creates an `InversionInterferometer`. 
- Unlike the `inversion_from` factory this function takes the `data`, `noise_map` and `w_tilde` objects as separate + Unlike the `inversion_from` factory this function takes the `data` and `noise_map` objects as separate inputs, which facilitates certain computations where the `dataset` object is unpacked before the `Inversion` is performed (for example if the noise-map is scaled before the inversion to downweight certain regions of the data). @@ -164,8 +183,6 @@ def inversion_interferometer_from( ---------- dataset The dataset (e.g. `Interferometer`) whose data is reconstructed via the `Inversion`. - w_tilde - Object which uses the `Imaging` dataset's PSF to perform the `Inversion` using the w-tilde formalism. linear_obj_list The list of linear objects (e.g. analytic functions, a mapper with a pixelized grid) which reconstruct the input dataset's data and whose values are solved for via the inversion. @@ -176,19 +193,18 @@ def inversion_interferometer_from( ------- An `Inversion` whose type is determined by the input `dataset` and `settings`. """ - use_w_tilde = True + use_sparse_operator = True if all( isinstance(linear_obj, AbstractLinearObjFuncList) for linear_obj in linear_obj_list ): - use_w_tilde = False + use_sparse_operator = False - if dataset.w_tilde is not None and use_w_tilde: + if dataset.sparse_operator is not None and use_sparse_operator: - return InversionInterferometerWTilde( + return InversionInterferometerSparse( dataset=dataset, - w_tilde=dataset.w_tilde, linear_obj_list=linear_obj_list, settings=settings, xp=xp, diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index ed001c158..b863c1c68 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -41,7 +41,7 @@ def __init__( The linear algebra required to perform an `Inversion` depends on the type of dataset being fitted (e.g. `Imaging`, `Interferometer) and the formalism chosen (e.g. a using a `mapping_matrix` or the - w_tilde formalism). The children of this class overwrite certain methods in order to be appropriate for + sparse linear algebra formalism). The children of this class overwrite certain methods in order to be appropriate for certain datasets or use a specific formalism. Inversions use the formalism's outlined in the following Astronomy papers: @@ -123,7 +123,7 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict: This property returns a dictionary mapping every linear func object to its corresponded operated mapping matrix, which is used for constructing the matrices that perform the linear inversion in an efficent way - for the w_tilde calculation. + for the psf precision operator calculation. Returns ------- @@ -209,7 +209,7 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict: This property returns a dictionary mapping every mapper object to its corresponded operated mapping matrix, which is used for constructing the matrices that perform the linear inversion in an efficent way - for the w_tilde calculation. + for the psf precision operator calculation. 
Returns ------- diff --git a/autoarray/inversion/inversion/imaging/inversion_imaging_util.py b/autoarray/inversion/inversion/imaging/inversion_imaging_util.py index a9298835a..6001b5a89 100644 --- a/autoarray/inversion/inversion/imaging/inversion_imaging_util.py +++ b/autoarray/inversion/inversion/imaging/inversion_imaging_util.py @@ -1,91 +1,37 @@ +from dataclasses import dataclass +from functools import partial import numpy as np +from typing import Optional, List, Tuple -def psf_operator_matrix_dense_from( - kernel_native: np.ndarray, - native_index_for_slim_index: np.ndarray, # shape (N_pix, 2), native (y,x) coords of masked pixels - native_shape: tuple[int, int], - correlate: bool = True, -) -> np.ndarray: - """ - Construct a dense PSF operator W (N_pix x N_pix) that maps masked image pixels to masked image pixels. - - Parameters - ---------- - kernel_native : (Ky, Kx) PSF kernel. - native_index_for_slim_index : (N_pix, 2) array of int - Native (y, x) coords for each masked pixel. - native_shape : (Ny, Nx) - Native 2D image shape. - correlate : bool, default True - If True, use correlation convention (no kernel flip). - If False, use convolution convention (flip kernel). - - Returns - ------- - W : ndarray, shape (N_pix, N_pix) - Dense PSF operator. - """ - Ky, Kx = kernel_native.shape - ph, pw = Ky // 2, Kx // 2 - Ny, Nx = native_shape - N_pix = native_index_for_slim_index.shape[0] - - ker = kernel_native if correlate else kernel_native[::-1, ::-1] - - # Padded index grid: -1 everywhere, slim index where masked - index_padded = -np.ones((Ny + 2 * ph, Nx + 2 * pw), dtype=np.int64) - for p, (y, x) in enumerate(native_index_for_slim_index): - index_padded[y + ph, x + pw] = p - - # Neighborhood offsets - dy = np.arange(Ky) - ph - dx = np.arange(Kx) - pw - - W = np.zeros((N_pix, N_pix), dtype=float) - - for i, (y, x) in enumerate(native_index_for_slim_index): - yp = y + ph - xp = x + pw - for j, dy_ in enumerate(dy): - for k, dx_ in enumerate(dx): - neigh = index_padded[yp + dy_, xp + dx_] - if neigh >= 0: - W[i, neigh] += ker[j, k] - - return W - - -def w_tilde_data_imaging_from( - image_native: np.ndarray, - noise_map_native: np.ndarray, +def psf_weighted_data_from( + weight_map_native: np.ndarray, kernel_native: np.ndarray, native_index_for_slim_index, xp=np, ) -> np.ndarray: """ - The matrix w_tilde is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF convolution of + The sparse linear algebra uses a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF convolution of every pair of image pixels given the noise map. This can be used to efficiently compute the curvature matrix via the mappings between image and source pixels, in a way that omits having to perform the PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging datasets. - When w_tilde is used to perform an inversion, the mapping matrices are not computed, meaning that they cannot be - used to compute the data vector. This method creates the vector `w_tilde_data` which allows for the data + When it is used to perform an inversion, the mapping matrices are not computed, meaning that they cannot be + used to compute the data vector. This method creates the vector `psf_weighted_data` which allows for the data vector to be computed efficiently without the mapping matrix. 
- The matrix w_tilde_data is dimensions [image_pixels] and encodes the PSF convolution with the `weight_map`, + The matrix psf_weighted_data is dimensions [image_pixels] and encodes the PSF convolution with the `weight_map`, where the weights are the image-pixel values divided by the noise-map values squared: weight = image / noise**2.0 Parameters ---------- - image_native - The two dimensional masked image of values which `w_tilde_data` is computed from. - noise_map_native - The two dimensional masked noise-map of values which `w_tilde_data` is computed from. + weight_map_native + The two dimensional masked weight-map of values the PSF convolution is computed from, which is the data + divided by the noise-map squared. kernel_native - The two dimensional PSF kernel that `w_tilde_data` encodes the convolution of. + The two dimensional PSF kernel that `psf_weighted_data` encodes the convolution of. native_index_for_slim_index An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. @@ -95,18 +41,12 @@ def w_tilde_data_imaging_from( A matrix that encodes the PSF convolution values between the imaging divided by the noise map**2 that enables efficient calculation of the data vector. """ - - # 1) weight map = image / noise^2 (safe where noise==0) - weight_map = xp.where( - noise_map_native > 0.0, image_native / (noise_map_native**2), 0.0 - ) - Ky, Kx = kernel_native.shape ph, pw = Ky // 2, Kx // 2 # 2) pad so neighbourhood gathers never go OOB padded = xp.pad( - weight_map, ((ph, ph), (pw, pw)), mode="constant", constant_values=0.0 + weight_map_native, ((ph, ph), (pw, pw)), mode="constant", constant_values=0.0 ) # 3) build broadcasted neighbourhood indices for all requested pixels @@ -127,6 +67,43 @@ def w_tilde_data_imaging_from( return xp.sum(patches * kernel_native[None, :, :], axis=(1, 2)) # (N,) +def data_vector_via_psf_weighted_data_from( + psf_weighted_data: np.ndarray, # (M_pix,) float64 + rows: np.ndarray, # (nnz,) int32 each triplet's data pixel (slim index) + cols: np.ndarray, # (nnz,) int32 source pixel index + vals: np.ndarray, # (nnz,) float64 mapping weights incl sub_fraction + S: int, # number of source pixels +) -> np.ndarray: + """ + Returns the data vector `D` from the `psf_weighted_data` matrix (see `psf_weighted_data_from`), which encodes the + the 1D image `d` and 1D noise-map values `\sigma` (see Warren & Dye 2003). + + This uses the sparse matrix triplet representation of the mapping matrix to efficiently compute the data vector + without having to compute the full mapping matrix. + + Computes: + D[p] = sum_{triplets t with col_t=p} vals[t] * weighted_data_slim[slim_rows[t]] + + Parameters + ---------- + psf_weighted_data + The matrix representing the PSF convolution of the imaging data divided by the noise-map squared. + rows + The row indices of the sparse mapping matrix triplet representation, which map to data pixels. + cols + The column indices of the sparse mapping matrix triplet representation, which map to source pixels. + vals + The values of the sparse mapping matrix triplet representation, which map the image pixels to source pixels. + S + The number of source pixels. 
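+
+    Examples
+    --------
+    A minimal sketch with illustrative COO triplets (values are not from a real
+    dataset)::
+
+        import jax.numpy as jnp
+
+        psf_weighted_data = jnp.array([0.5, 1.0, 2.0])
+        rows = jnp.array([0, 1, 2, 2], dtype=jnp.int32)
+        cols = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
+        vals = jnp.array([1.0, 0.5, 1.0, 1.0])
+
+        data_vector = data_vector_via_psf_weighted_data_from(
+            psf_weighted_data=psf_weighted_data, rows=rows, cols=cols, vals=vals, S=2
+        )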
+ """ + from jax.ops import segment_sum + + w = psf_weighted_data[rows] # (nnz,) + contrib = vals * w # (nnz,) + return segment_sum(contrib, cols, num_segments=S) # (S,) + + def data_vector_via_blurred_mapping_matrix_from( blurred_mapping_matrix: np.ndarray, image: np.ndarray, noise_map: np.ndarray ) -> np.ndarray: @@ -203,3 +180,597 @@ def data_linear_func_matrix_from( blurred_list.append(blurred[~mask]) return np.stack(blurred_list, axis=1) # shape (n_unmasked, n_funcs) + + +def curvature_matrix_mirrored_from( + curvature_matrix, + *, + xp=np, +): + """ + Mirror a curvature matrix so that any non-zero entry C[i,j] + is copied to C[j,i]. + + Supports: + - NumPy (xp=np) + - JAX (xp=jax.numpy) + + Parameters + ---------- + curvature_matrix : (N, N) array + Possibly triangular / partially-filled curvature matrix. + + xp : module + np or jax.numpy + + Returns + ------- + curvature_matrix_mirrored : (N, N) array + Symmetric curvature matrix. + """ + + # Ensure array type + C = curvature_matrix + + # Boolean mask of where entries exist + mask = C != 0 + + # Copy entries symmetrically + C_T = xp.swapaxes(C, 0, 1) + + # Prefer non-zero values from either side + curvature_matrix_mirrored = xp.where(mask, C, C_T) + + return curvature_matrix_mirrored + + +def curvature_matrix_with_added_to_diag_from( + curvature_matrix, + value: float, + no_regularization_index_list: Optional[List[int]] = None, + *, + xp=np, +): + """ + Add a small stabilizing value to the diagonal entries of the curvature matrix. + + Supports: + - NumPy (xp=np): in-place update + - JAX (xp=jax.numpy): functional `.at[].add()` + + Parameters + ---------- + curvature_matrix : (N, N) array + Curvature matrix to modify. + + value : float + Value added to selected diagonal entries. + + no_regularization_index_list : list of int + Indices where diagonal should be boosted. + + xp : module + np or jax.numpy + + Returns + ------- + curvature_matrix : array + Updated matrix (new array in JAX, modified in NumPy). 
+ """ + + if no_regularization_index_list is None: + return curvature_matrix + + inds = xp.asarray(no_regularization_index_list, dtype=xp.int32) + + if xp is np: + # ----------------------- + # NumPy: in-place update + # ----------------------- + curvature_matrix[inds, inds] += value + return curvature_matrix + + else: + # ----------------------- + # JAX: functional update + # ----------------------- + return curvature_matrix.at[inds, inds].add(value) + + +def mapped_reconstructed_image_via_sparse_operator_from( + reconstruction, # (S,) + rows, + cols, + vals, # (nnz,) + fft_index_for_masked_pixel, + data_shape: int, # y_shape * x_shape +): + import jax.numpy as jnp + from jax.ops import segment_sum + + reconstruction = jnp.asarray(reconstruction, dtype=jnp.float64) + rows = jnp.asarray(rows, dtype=jnp.int32) + cols = jnp.asarray(cols, dtype=jnp.int32) + vals = jnp.asarray(vals, dtype=jnp.float64) + + contrib = vals * reconstruction[cols] # (nnz,) + image_rect = segment_sum( + contrib, rows, num_segments=data_shape[0] * data_shape[1] + ) # (M_rect,) + + image_slim = image_rect[fft_index_for_masked_pixel] # (M_pix,) + return image_slim + + +@dataclass(frozen=True) +class ImagingSparseOperator: + + data_native: np.ndarray + noise_map_native: np.ndarray + weight_map: np.ndarray + inverse_variances_native: "jax.Array" # (y, x) float64 + y_shape: int + x_shape: int + Ky: int + Kx: int + fft_shape: Tuple[int, int] + batch_size: int + col_offsets: "jax.Array" # (batch_size,) int32 + Khat_r: "jax.Array" # (Fy, Fx//2+1) complex + Khat_flip_r: "jax.Array" # (Fy, Fx//2+1) complex + + @classmethod + def from_noise_map_and_psf( + cls, + data, + noise_map, + psf, + *, + batch_size: int = 128, + dtype=None, + ) -> "ImagingSparseOperator": + """ + Construct an `ImagingSparseOperator` from imaging arrays and a PSF. + + This factory method builds all static FFT state required to apply + + W = Hᵀ N⁻¹ H + + repeatedly during curvature matrix construction: + + - `inverse_variances_native` is computed from `noise_map` as 1/noise^2 (masked-safe). + - `weight_map` is computed from `data` as data/noise^2 (masked-safe). + - `Khat_r` and `Khat_flip_r` are precomputed as rFFT transforms of the padded PSF + and its flipped version on the required FFT shape. + + Parameters + ---------- + data + Imaging data object (e.g. `Array2D`) providing `.array` (slim) and `.native`. + noise_map + Imaging noise-map object (e.g. `Array2D`) providing `.array` (slim), + `.native`, `.mask`, and `.shape_native`. + psf + PSF kernel in native 2D form. Can be NumPy or JAX convertible. + batch_size + Number of curvature columns processed per FFT batch. Larger values can improve + GPU utilization but increase VRAM usage. + dtype + JAX dtype for PSF / FFT precompute. Defaults to `jax.numpy.float64`. + + Returns + ------- + ImagingSparseOperator + Fully initialized operator containing all precomputed FFT state. + + Notes + ----- + - This method assumes the FFT shape is `(Hy+Ky-1, Hx+Kx-1)` which is sufficient + for linear convolution followed by cropping to “same”. + - Mask handling: the `Array2D(..., mask=...)` wrapper ensures masked pixels are + correctly represented in native arrays; you may still explicitly zero masked + pixels if your downstream code expects it. 
+ """ + import jax.numpy as jnp + + from autoarray.structures.arrays.uniform_2d import Array2D + + if dtype is None: + dtype = jnp.float64 + + # ---------------------------- + # Shapes + # ---------------------------- + y_shape = int(noise_map.shape_native[0]) + x_shape = int(noise_map.shape_native[1]) + + # ---------------------------- + # inverse_variances_native (native 2D) + # Make safe (0 where invalid) + # ---------------------------- + # Try to get a plain native ndarray from your Array2D-like object: + inverse_variances_native = 1.0 / noise_map**2 + inverse_variances_native = Array2D( + values=inverse_variances_native, mask=noise_map.mask + ) + inverse_variances_native = inverse_variances_native.native + + weight_map = data.array / (noise_map.array**2) + weight_map = Array2D(values=weight_map, mask=noise_map.mask) + + # If you *also* want to zero masked pixels explicitly: + # mask_native = noise_map.mask (depends on your API; might be bool native) + # inverse_variances_native = inverse_variances_native.at[mask_native].set(0.0) + + # ---------------------------- + # PSF + FFT precompute + # ---------------------------- + psf = jnp.asarray(psf, dtype=dtype) + Ky, Kx = map(int, psf.shape) + + fft_shape = (y_shape + Ky - 1, x_shape + Kx - 1) + Fy, Fx = fft_shape + + def precompute(psf2d): + psf_pad = jnp.pad(psf2d, ((0, Fy - Ky), (0, Fx - Kx))) + return jnp.fft.rfft2(psf_pad, s=(Fy, Fx)) + + Khat_r = precompute(psf) + Khat_flip_r = precompute(jnp.flip(psf, axis=(0, 1))) + + return cls( + data_native=data.native, + noise_map_native=noise_map.native, + weight_map=weight_map.native, + inverse_variances_native=inverse_variances_native, + y_shape=y_shape, + x_shape=x_shape, + Ky=Ky, + Kx=Kx, + fft_shape=(int(Fy), int(Fx)), + batch_size=int(batch_size), + col_offsets=jnp.arange(batch_size, dtype=jnp.int32), + Khat_r=Khat_r, + Khat_flip_r=Khat_flip_r, + ) + + def apply_operator(self, Fbatch_flat): + """ + Apply the imaging precision operator W = Hᵀ N⁻¹ H to a batch of vectors. + + This is the fundamental linear operator application used throughout curvature + assembly. Given an input matrix `Fbatch_flat` with shape: + + (M_rect, B) + + where: + - `M_rect = y_shape * x_shape` is the number of pixels on the rectangular FFT grid, + - `B` is the number of right-hand-side vectors (typically `batch_size`), + + this method computes: + + G = W Fbatch_flat + + using two FFT-based convolutions: + 1) Forward blur: H (correlation with PSF) + 2) Weighting: multiply by inverse variances (N⁻¹) + 3) Backproject: Hᵀ (correlation with flipped PSF) + + Parameters + ---------- + Fbatch_flat + Array of shape (M_rect, B) representing B vectors on the rectangular grid. + + Returns + ------- + ndarray + Array of shape (M_rect, B) equal to W applied to the batch. + + Notes + ----- + - The PSF is applied via rFFT on a padded grid of shape `fft_shape`. + - The output is cropped back to the native `(y_shape, x_shape)` “same” region. + - This method expects rectangular-grid indexing (flat). If you have slim masked + indexing you must scatter/gather appropriately outside this function. 
+ """ + import jax.numpy as jnp + + y_shape, x_shape = self.y_shape, self.x_shape + Ky, Kx = self.Ky, self.Kx + Fy, Fx = self.fft_shape + M = y_shape * x_shape + + B = Fbatch_flat.shape[1] + Fimg = Fbatch_flat.T.reshape((B, y_shape, x_shape)) + + # forward blur + Fpad = jnp.pad(Fimg, ((0, 0), (0, Fy - y_shape), (0, Fx - x_shape))) + Fhat = jnp.fft.rfft2(Fpad, s=(Fy, Fx)) + blurred_pad = jnp.fft.irfft2(Fhat * self.Khat_r[None, :, :], s=(Fy, Fx)) + + cy, cx = Ky // 2, Kx // 2 + blurred = blurred_pad[:, cy : cy + y_shape, cx : cx + x_shape] + + weighted = blurred * self.inverse_variances_native[None, :, :] + + # backprojection + Wpad = jnp.pad(weighted, ((0, 0), (0, Fy - y_shape), (0, Fx - x_shape))) + What = jnp.fft.rfft2(Wpad, s=(Fy, Fx)) + back_pad = jnp.fft.irfft2(What * self.Khat_flip_r[None, :, :], s=(Fy, Fx)) + back = back_pad[:, cy : cy + y_shape, cx : cx + x_shape] + + return back.reshape((B, M)).T # (M, B) + + def curvature_matrix_diag_from(self, rows, cols, vals, *, S: int): + """ + Compute the diagonal (mapper–mapper) curvature matrix block F = Aᵀ W A. + + This method computes the curvature matrix for a single mapper/operator `A` + represented in COO triplets (rows, cols, vals), using the FFT-backed W operator. + + Conceptually, if A is the (M_rect × S) mapping matrix, then: + + F = Aᵀ W A + + where: + - `S` is the number of source pixels / parameters for this mapper. + - `M_rect = y_shape * x_shape` is the number of rectangular grid pixels. + + The computation proceeds in column blocks of width `batch_size`: + 1) Assemble Fbatch = A[:, start:start+B] on the rectangular grid via scatter-add. + 2) Apply W to the block: Gbatch = W(Fbatch). + 3) Project back with Aᵀ via a segment_sum over `cols`. + + Parameters + ---------- + rows, cols, vals + COO triplets encoding the sparse mapping operator A. + Expected conventions: + - `rows`: rectangular-grid pixel indices (flat), shape (nnz,) + - `cols`: source pixel indices in [0, S), shape (nnz,) + - `vals`: mapping weights (incl sub-fraction / interpolation), shape (nnz,) + S + Number of source pixels / parameters for this mapper. + + Returns + ------- + ndarray + Curvature matrix of shape (S, S), symmetric. + + Notes + ----- + - The output is symmetrized as 0.5*(F + Fᵀ) to mitigate numerical asymmetry + introduced by floating-point reductions. + - Padding to `S_pad = ceil(S/B)*B` ensures `dynamic_update_slice` is always legal + even when S is not a multiple of batch_size. 
+ """ + import jax.numpy as jnp + from jax import lax + from jax.ops import segment_sum + + rows = jnp.asarray(rows, dtype=jnp.int32) + cols = jnp.asarray(cols, dtype=jnp.int32) + vals = jnp.asarray(vals, dtype=jnp.float64) + + y_shape, x_shape = self.y_shape, self.x_shape + M = y_shape * x_shape + B = self.batch_size + + n_blocks = (S + B - 1) // B + S_pad = n_blocks * B + + C0 = jnp.zeros((S, S_pad), dtype=jnp.float64) + + def body(block_i, C): + start = block_i * B + + in_block = (cols >= start) & (cols < (start + B)) + bc = jnp.where(in_block, cols - start, 0).astype(jnp.int32) + v = jnp.where(in_block, vals, 0.0) + + F = jnp.zeros((M, B), dtype=jnp.float64) + F = F.at[rows, bc].add(v) + + G = self.apply_operator(F) # (M, B) + + contrib = vals[:, None] * G[rows, :] + Cblock = segment_sum(contrib, cols, num_segments=S) # (S, B) + + width = jnp.minimum(B, jnp.maximum(0, S - start)) + Cblock = Cblock * (self.col_offsets < width)[None, :] + + return lax.dynamic_update_slice(C, Cblock, (0, start)) + + C_pad = lax.fori_loop(0, n_blocks, body, C0) + C = C_pad[:, :S] + return 0.5 * (C + C.T) + + def curvature_matrix_off_diag_from( + self, rows0, cols0, vals0, rows1, cols1, vals1, *, S0: int, S1: int + ): + """ + Compute the off-diagonal curvature block F01 = A0ᵀ W A1. + + Given two sparse mapping operators: + - A0 : (M_rect × S0) + - A1 : (M_rect × S1) + + this method computes: + + F01 = A0ᵀ W A1 + + using the same FFT-backed W operator as in the diagonal computation. + + The computation proceeds in column blocks over A1: + 1) Assemble Fbatch = A1[:, start:start+B] on the rectangular grid. + 2) Apply W: Gbatch = W(Fbatch). + 3) Project with A0ᵀ via segment_sum over cols0. + + Parameters + ---------- + rows0, cols0, vals0 + COO triplets for A0. `rows0` are rectangular-grid indices. + rows1, cols1, vals1 + COO triplets for A1. `rows1` are rectangular-grid indices. + S0 + Number of source parameters for mapper/operator 0. + S1 + Number of source parameters for mapper/operator 1. + + Returns + ------- + ndarray + Off-diagonal curvature block of shape (S0, S1). + + Notes + ----- + - The result is *not* symmetrized here because it is not square in general. + If you need the symmetric counterpart, use F10 = F01ᵀ when A0/A1 share W. + - Padding to `S1_pad = ceil(S1/B)*B` ensures `dynamic_update_slice` is always legal. 
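+
+        Examples
+        --------
+        Dense reference for intuition (illustrative only; neither A0 nor A1 is formed by this method),
+        with the COO triplets and sizes as documented above and `operator` an instance of this class:
+
+            import jax.numpy as jnp
+
+            M_rect = operator.y_shape * operator.x_shape
+            A0 = jnp.zeros((M_rect, S0)).at[rows0, cols0].add(vals0)
+            A1 = jnp.zeros((M_rect, S1)).at[rows1, cols1].add(vals1)
+            F01_dense = A0.T @ operator.apply_operator(A1)   # (S0, S1)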
+ """ + import jax.numpy as jnp + from jax import lax + from jax.ops import segment_sum + + rows0 = jnp.asarray(rows0, dtype=jnp.int32) + cols0 = jnp.asarray(cols0, dtype=jnp.int32) + vals0 = jnp.asarray(vals0, dtype=jnp.float64) + + rows1 = jnp.asarray(rows1, dtype=jnp.int32) + cols1 = jnp.asarray(cols1, dtype=jnp.int32) + vals1 = jnp.asarray(vals1, dtype=jnp.float64) + + y_shape, x_shape = self.y_shape, self.x_shape + M = y_shape * x_shape + B = self.batch_size + + n_blocks = (S1 + B - 1) // B + S1_pad = n_blocks * B + + F01_0 = jnp.zeros((S0, S1_pad), dtype=jnp.float64) + + def body(block_i, F01): + start = block_i * B + + in_block = (cols1 >= start) & (cols1 < (start + B)) + bc = jnp.where(in_block, cols1 - start, 0).astype(jnp.int32) + v = jnp.where(in_block, vals1, 0.0) + + F = jnp.zeros((M, B), dtype=jnp.float64) + F = F.at[rows1, bc].add(v) + + G = self.apply_operator(F) # (M, B) + + contrib = vals0[:, None] * G[rows0, :] + block = segment_sum(contrib, cols0, num_segments=S0) + + width = jnp.minimum(B, jnp.maximum(0, S1 - start)) + block = block * (self.col_offsets < width)[None, :] + + return lax.dynamic_update_slice(F01, block, (0, start)) + + F01_pad = lax.fori_loop(0, n_blocks, body, F01_0) + return F01_pad[:, :S1] + + def curvature_matrix_off_diag_func_list_from( + self, + curvature_weights, # (M_pix, n_funcs) + fft_index_for_masked_pixel, # (M_pix,) slim -> rect(flat) + rows, + cols, + vals, # triplets where rows are RECT indices + *, + S: int, + ): + """ + Compute the mapper–linear-function off-diagonal block: Aᵀ Hᵀ (curvature_weights). + + This method computes the off-diagonal block between a sparse mapper/operator A + and a set of fixed linear functions (e.g. linear light profiles) whose operated + images have already been PSF-convolved and noise-weighted. + + The input `curvature_weights` is assumed to already contain: + + curvature_weights = (H B) / noise^2 + + evaluated on the *slim masked* grid, for each linear function column. + + The returned matrix is: + + off_diag = Aᵀ [ Hᵀ (curvature_weights_native) ] + + which has shape (S, n_funcs). + + Parameters + ---------- + curvature_weights + Array of shape (M_pix, n_funcs) on the *slim masked* grid. + Each column corresponds to one linear function already convolved by H and + weighted by inverse variance. + fft_index_for_masked_pixel + Array of shape (M_pix,) mapping slim masked pixel indices to rectangular + FFT-grid flat indices. Used to scatter values onto the rectangular grid. + rows, cols, vals + COO triplets for the mapper A, where: + - `rows` are rectangular-grid indices (flat), shape (nnz,) + - `cols` are source pixel indices, shape (nnz,) + - `vals` are mapping weights, shape (nnz,) + S + Number of source pixels / parameters in the mapper. + + Returns + ------- + ndarray + Off-diagonal block of shape (S, n_funcs). + + Notes + ----- + - This method performs only one convolution per function column (batched), + using the flipped PSF (Hᵀ), because `curvature_weights` is assumed to already + include the forward blur H. + - No batch_size parameter is required here because we are convolving `n_funcs` + columns (often small) rather than sweeping over source pixels S. 
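+
+        Examples
+        --------
+        Dense reference for intuition (illustrative only), with the inputs as documented above,
+        `operator` an instance of this class and `A` the mapper's dense mapping operator:
+
+            import jax.numpy as jnp
+
+            M_rect = operator.y_shape * operator.x_shape
+            A = jnp.zeros((M_rect, S)).at[rows, cols].add(vals)
+
+            # Scatter the slim curvature weights onto the rectangular grid, backproject them with the
+            # flipped PSF (H^T, written here as the hypothetical helper `apply_flipped_psf`, standing
+            # for the rFFT correlation used in the body of this method), then project with A^T.
+            W_rect = (
+                jnp.zeros((M_rect, n_funcs))
+                .at[fft_index_for_masked_pixel, :]
+                .set(curvature_weights)
+            )
+            off_diag_dense = A.T @ apply_flipped_psf(W_rect)   # (S, n_funcs)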
+        """
+        import jax.numpy as jnp
+        from jax.ops import segment_sum
+
+        curvature_weights = jnp.asarray(curvature_weights, dtype=jnp.float64)
+        fft_index_for_masked_pixel = jnp.asarray(
+            fft_index_for_masked_pixel, dtype=jnp.int32
+        )
+
+        rows = jnp.asarray(rows, dtype=jnp.int32)
+        cols = jnp.asarray(cols, dtype=jnp.int32)
+        vals = jnp.asarray(vals, dtype=jnp.float64)
+
+        y_shape, x_shape = self.y_shape, self.x_shape
+        Ky, Kx = self.Ky, self.Kx
+        Fy, Fx = self.fft_shape
+        M_rect = y_shape * x_shape
+
+        M_pix, n_funcs = curvature_weights.shape
+
+        # 1) scatter slim -> rect(flat)
+        grid_flat = jnp.zeros((M_rect, n_funcs), dtype=jnp.float64)
+        grid_flat = grid_flat.at[fft_index_for_masked_pixel, :].set(curvature_weights)
+
+        # 2) apply H^T: conv with flipped PSF
+        images = grid_flat.T.reshape((n_funcs, y_shape, x_shape))  # (B=n_funcs, Hy, Hx)
+
+        # --- rFFT-based convolution (same padding and cropping as `apply_operator`) ---
+        images_pad = jnp.pad(images, ((0, 0), (0, Fy - y_shape), (0, Fx - x_shape)))
+        Fhat = jnp.fft.rfft2(images_pad, s=(Fy, Fx))
+        out_pad = jnp.fft.irfft2(Fhat * self.Khat_flip_r[None, :, :], s=(Fy, Fx))
+
+        cy, cx = Ky // 2, Kx // 2
+        back_native = out_pad[
+            :, cy : cy + y_shape, cx : cx + x_shape
+        ]  # (n_funcs, Hy, Hx)
+
+        # 3) gather at mapper rows (rect coords)
+        back_flat = back_native.reshape((n_funcs, M_rect)).T  # (M_rect, n_funcs)
+        back_at_rows = back_flat[rows, :]  # (nnz, n_funcs)
+
+        # 4) accumulate to source pixels
+        contrib = vals[:, None] * back_at_rows
+        return segment_sum(contrib, cols, num_segments=S)  # (S, n_funcs)
diff --git a/autoarray/inversion/inversion/imaging/sparse.py b/autoarray/inversion/inversion/imaging/sparse.py
new file mode 100644
index 000000000..828ead97d
--- /dev/null
+++ b/autoarray/inversion/inversion/imaging/sparse.py
@@ -0,0 +1,562 @@
+import numpy as np
+from typing import Dict, List, Optional, Union
+
+from autoconf import cached_property
+
+from autoarray.dataset.imaging.dataset import Imaging
+from autoarray.inversion.inversion.dataset_interface import DatasetInterface
+from autoarray.inversion.inversion.imaging.abstract import AbstractInversionImaging
+from autoarray.inversion.linear_obj.linear_obj import LinearObj
+from autoarray.inversion.inversion.settings import SettingsInversion
+from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList
+from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper
+from autoarray.preloads import Preloads
+from autoarray.structures.arrays.uniform_2d import Array2D
+
+from autoarray.inversion.inversion.imaging import inversion_imaging_util
+
+
+class InversionImagingSparse(AbstractInversionImaging):
+    def __init__(
+        self,
+        dataset: Union[Imaging, DatasetInterface],
+        linear_obj_list: List[LinearObj],
+        settings: SettingsInversion = SettingsInversion(),
+        preloads: Preloads = None,
+        xp=np,
+    ):
+        """
+        Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations
+        to be solved (see `inversion.inversion.abstract.AbstractInversion`) for a full description.
+
+        A linear object describes the mappings between values in observed `data` and the linear object's model via its
+        `mapping_matrix`. This class constructs linear equations for `Imaging` objects, where the data is an image
+        and the mappings may include a convolution operation described by the imaging data's PSF.
+
+        This class uses the sparse linear algebra formalism, which speeds up the construction of the simultaneous
+        linear equations by bypassing the construction of a `mapping_matrix`.
+
+        Parameters
+        ----------
+        dataset
+            The dataset containing the image data, noise-map and psf which is fitted by the inversion.
+        linear_obj_list
+            The linear objects used to reconstruct the data's observed values. If multiple linear objects are passed
+            the simultaneous linear equations are combined and solved simultaneously.
+        """
+
+        super().__init__(
+            dataset=dataset,
+            linear_obj_list=linear_obj_list,
+            settings=settings,
+            preloads=preloads,
+            xp=xp,
+        )
+
+    @cached_property
+    def psf_weighted_data(self):
+        return inversion_imaging_util.psf_weighted_data_from(
+            weight_map_native=self.dataset.sparse_operator.weight_map.array,
+            kernel_native=self.psf.stored_native,
+            native_index_for_slim_index=self.data.mask.derive_indexes.native_for_slim,
+            xp=self._xp,
+        )
+
+    @property
+    def _data_vector_mapper(self) -> np.ndarray:
+        """
+        Returns the `data_vector` of all mappers, a 1D vector whose values are solved for by the simultaneous
+        linear equations constructed by this object. The object is described in full in the method `data_vector`.
+
+        This method is used to compute part of the `data_vector` if there are also linear function list objects
+        in the inversion, and is separated into its own method to enable preloading of the mapper `data_vector`.
+        """
+
+        if not self.has(cls=AbstractMapper):
+            return None
+
+        data_vector = self._xp.zeros(self.total_params)
+
+        mapper_list = self.cls_list_from(cls=AbstractMapper)
+        mapper_param_range = self.param_range_list_from(cls=AbstractMapper)
+
+        for mapper_index, mapper in enumerate(mapper_list):
+
+            rows, cols, vals = mapper.sparse_triplets_data
+
+            data_vector_mapper = (
+                inversion_imaging_util.data_vector_via_psf_weighted_data_from(
+                    psf_weighted_data=self.psf_weighted_data,
+                    rows=rows,
+                    cols=cols,
+                    vals=vals,
+                    S=mapper.params,
+                )
+            )
+            param_range = mapper_param_range[mapper_index]
+
+            start = param_range[0]
+            end = param_range[1]
+
+            if self._xp is np:
+                data_vector[start:end] = data_vector_mapper
+            else:
+                data_vector = data_vector.at[start:end].set(data_vector_mapper)
+
+        return data_vector
+
+    @cached_property
+    def data_vector(self) -> np.ndarray:
+        """
+        Returns the `data_vector`, a 1D vector whose values are solved for by the simultaneous linear equations
+        constructed by this object.
+
+        The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf, where the
+        data vector is given by equation (4) and the letter D.
+
+        If there are multiple linear objects a `data_vector` is computed for each one, which are concatenated
+        ensuring their values are solved for simultaneously.
+
+        The calculation is described in more detail in `inversion_util.psf_weighted_data_from`.
+        """
+        if self.has(cls=AbstractLinearObjFuncList):
+            return self._data_vector_func_list_and_mapper
+        elif self.total(cls=AbstractMapper) == 1:
+            return self._data_vector_x1_mapper
+        return self._data_vector_multi_mapper
+
+    @property
+    def _data_vector_x1_mapper(self) -> np.ndarray:
+        """
+        Returns the `data_vector`, a 1D vector whose values are solved for by the simultaneous linear equations
+        constructed by this object. The object is described in full in the method `data_vector`.
+
+        This method computes the `data_vector` when there is a single mapper object in the `Inversion`,
+        which circumvents `np.concatenate` for speed up.
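+
+        In matrix form the returned vector is D = A^T w, where A is the mapper's sparse mapping operator
+        (encoded by its `sparse_triplets_data` COO triplets) and w is `psf_weighted_data`, i.e. the
+        PSF-correlated, noise-weighted image described in `inversion_imaging_util.psf_weighted_data_from`.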
+        """
+        linear_obj = self.linear_obj_list[0]
+
+        rows, cols, vals = linear_obj.sparse_triplets_data
+
+        return inversion_imaging_util.data_vector_via_psf_weighted_data_from(
+            psf_weighted_data=self.psf_weighted_data,
+            rows=rows,
+            cols=cols,
+            vals=vals,
+            S=linear_obj.params,
+        )
+
+    @property
+    def _data_vector_multi_mapper(self) -> np.ndarray:
+        """
+        Returns the `data_vector`, a 1D vector whose values are solved for by the simultaneous linear equations
+        constructed by this object. The object is described in full in the method `data_vector`.
+
+        This method computes the `data_vector` when there are multiple mapper objects in the `Inversion`,
+        which computes the `data_vector` of each object and concatenates them.
+        """
+
+        data_vector_list = []
+
+        for mapper in self.cls_list_from(cls=AbstractMapper):
+
+            rows, cols, vals = mapper.sparse_triplets_data
+
+            data_vector_mapper = (
+                inversion_imaging_util.data_vector_via_psf_weighted_data_from(
+                    psf_weighted_data=self.psf_weighted_data,
+                    rows=rows,
+                    cols=cols,
+                    vals=vals,
+                    S=mapper.params,
+                )
+            )
+
+            data_vector_list.append(data_vector_mapper)
+
+        return self._xp.concatenate(data_vector_list)
+
+    @property
+    def _data_vector_func_list_and_mapper(self) -> np.ndarray:
+        """
+        Returns the `data_vector`, a 1D vector whose values are solved for by the simultaneous linear equations
+        constructed by this object. The object is described in full in the method `data_vector`.
+
+        This method computes the `data_vector` when there are one or more mapper objects in the `Inversion`,
+        which are combined with linear function list objects.
+
+        The `data_vector` corresponding to all mapper objects is computed first, in a separate function. This
+        separation of functions enables the `data_vector` to be preloaded in certain circumstances.
+        """
+
+        data_vector = self._xp.array(self._data_vector_mapper)
+
+        linear_func_param_range = self.param_range_list_from(
+            cls=AbstractLinearObjFuncList
+        )
+
+        for linear_func_index, linear_func in enumerate(
+            self.cls_list_from(cls=AbstractLinearObjFuncList)
+        ):
+            operated_mapping_matrix = self.linear_func_operated_mapping_matrix_dict[
+                linear_func
+            ]
+
+            diag = inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from(
+                blurred_mapping_matrix=operated_mapping_matrix,
+                image=self.data.array,
+                noise_map=self.noise_map.array,
+            )
+
+            param_range = linear_func_param_range[linear_func_index]
+
+            start = param_range[0]
+            end = param_range[1]
+
+            if self._xp is np:
+                data_vector[start:end] = diag
+            else:
+                data_vector = data_vector.at[start:end].set(diag)
+
+        return data_vector
+
+    @cached_property
+    def curvature_matrix(self) -> np.ndarray:
+        """
+        Returns the `curvature_matrix`, a 2D matrix which uses the mappings between the data and the linear objects to
+        construct the simultaneous linear equations.
+
+        The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf, where the
+        curvature matrix is given by equation (4) and the letter F.
+
+        This function computes F using the sparse linear algebra formalism, which is faster as it precomputes the PSF
+        convolution of different noise-map pixels (see `curvature_matrix_diag_via_sparse_operator_from`).
+
+        If there are multiple linear objects the curvature_matrices are combined to ensure their values are solved
+        for simultaneously. In the sparse linear algebra formalism this requires us to consider the mappings between
+        data and every linear object, meaning that the linear algebra has both on and off diagonal terms.
+
+        The `curvature_matrix` computed here is overwritten in memory when the regularization matrix is added to it,
+        because for large matrices this avoids overhead. For this reason, care is required if `curvature_matrix` is
+        accessed after the `curvature_reg_matrix` has been computed, as the array may have been modified in place.
+        """
+        if self.has(cls=AbstractLinearObjFuncList):
+            curvature_matrix = self._curvature_matrix_func_list_and_mapper
+        elif self.total(cls=AbstractMapper) == 1:
+            curvature_matrix = self._curvature_matrix_x1_mapper
+        else:
+            curvature_matrix = self._curvature_matrix_multi_mapper
+
+        curvature_matrix = inversion_imaging_util.curvature_matrix_mirrored_from(
+            curvature_matrix=curvature_matrix,
+            xp=self._xp,
+        )
+
+        if len(self.no_regularization_index_list) > 0:
+            curvature_matrix = (
+                inversion_imaging_util.curvature_matrix_with_added_to_diag_from(
+                    curvature_matrix=curvature_matrix,
+                    value=self.settings.no_regularization_add_to_curvature_diag_value,
+                    no_regularization_index_list=self.no_regularization_index_list,
+                    xp=self._xp,
+                )
+            )
+
+        return curvature_matrix
+
+    @property
+    def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]:
+        """
+        Returns the diagonal regions of the `curvature_matrix`, a 2D matrix which uses the mappings between the data
+        and the linear objects to construct the simultaneous linear equations. The object is described in full in
+        the method `curvature_matrix`.
+
+        This method computes the diagonal entries of all mapper objects in the `curvature_matrix`. It is separate from
+        other calculations to enable preloading of this calculation.
+        """
+
+        if not self.has(cls=AbstractMapper):
+            return None
+
+        curvature_matrix = self._xp.zeros((self.total_params, self.total_params))
+
+        mapper_list = self.cls_list_from(cls=AbstractMapper)
+        mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)
+
+        for i in range(len(mapper_list)):
+            mapper_i = mapper_list[i]
+            mapper_param_range_i = mapper_param_range_list[i]
+
+            rows, cols, vals = mapper_i.sparse_triplets_curvature
+
+            diag = self.dataset.sparse_operator.curvature_matrix_diag_from(
+                rows=rows,
+                cols=cols,
+                vals=vals,
+                S=mapper_i.params,
+            )
+
+            start, end = mapper_param_range_i
+
+            if self._xp is np:
+                curvature_matrix[start:end, start:end] = diag
+            else:
+                curvature_matrix = curvature_matrix.at[start:end, start:end].set(diag)
+
+        return curvature_matrix
+
+    def _curvature_matrix_off_diag_from(
+        self, mapper_0: AbstractMapper, mapper_1: AbstractMapper
+    ) -> np.ndarray:
+        """
+        The `curvature_matrix` is a 2D matrix which uses the mappings between the data and the linear objects to
+        construct the simultaneous linear equations.
+
+        The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf, where the
+        curvature matrix is given by equation (4) and the letter F.
+
+        This function computes the off-diagonal terms of F using the sparse linear algebra formalism.
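+
+        In matrix form this block is F01 = A0^T W A1, where A0 and A1 are the two mappers' sparse mapping
+        operators and W is the PSF precision operator applied by the dataset's sparse operator. The mirrored
+        [S1, S0] block is not computed here; it is filled in downstream when the assembled curvature matrix
+        is passed through `curvature_matrix_mirrored_from`.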
+        """
+
+        rows0, cols0, vals0 = mapper_0.sparse_triplets_curvature
+        rows1, cols1, vals1 = mapper_1.sparse_triplets_curvature
+
+        S0 = mapper_0.params
+        S1 = mapper_1.params
+
+        return self.dataset.sparse_operator.curvature_matrix_off_diag_from(
+            rows0=rows0,
+            cols0=cols0,
+            vals0=vals0,
+            rows1=rows1,
+            cols1=cols1,
+            vals1=vals1,
+            S0=S0,
+            S1=S1,
+        )
+
+    @property
+    def _curvature_matrix_x1_mapper(self) -> np.ndarray:
+        """
+        Returns the `curvature_matrix`, a 2D matrix which uses the mappings between the data and the linear objects to
+        construct the simultaneous linear equations. The object is described in full in the method `curvature_matrix`.
+
+        This method computes the `curvature_matrix` when there is a single mapper object in the `Inversion`,
+        which circumvents the multi-mapper assembly for speed up.
+        """
+        return self._curvature_matrix_mapper_diag
+
+    @property
+    def _curvature_matrix_multi_mapper(self) -> np.ndarray:
+        """
+        Returns the `curvature_matrix`, a 2D matrix which uses the mappings between the data and the linear objects to
+        construct the simultaneous linear equations. The object is described in full in the method `curvature_matrix`.
+
+        This method computes the `curvature_matrix` when there are multiple mapper objects in the `Inversion`,
+        by computing each one (and their off-diagonal matrices) and combining them into a single matrix.
+        """
+
+        curvature_matrix = self._curvature_matrix_mapper_diag
+
+        if self.total(cls=AbstractMapper) == 1:
+            return curvature_matrix
+
+        mapper_list = self.cls_list_from(cls=AbstractMapper)
+        mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)
+
+        for i in range(len(mapper_list)):
+            mapper_i = mapper_list[i]
+            mapper_param_range_i = mapper_param_range_list[i]
+
+            for j in range(i + 1, len(mapper_list)):
+                mapper_j = mapper_list[j]
+                mapper_param_range_j = mapper_param_range_list[j]
+
+                off_diag = self._curvature_matrix_off_diag_from(
+                    mapper_0=mapper_i, mapper_1=mapper_j
+                )
+
+                if self._xp is np:
+                    curvature_matrix[
+                        mapper_param_range_i[0] : mapper_param_range_i[1],
+                        mapper_param_range_j[0] : mapper_param_range_j[1],
+                    ] = off_diag
+                else:
+                    curvature_matrix = curvature_matrix.at[
+                        mapper_param_range_i[0] : mapper_param_range_i[1],
+                        mapper_param_range_j[0] : mapper_param_range_j[1],
+                    ].set(off_diag)
+
+        return curvature_matrix
+
+    @property
+    def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray:
+        """
+        The `curvature_matrix` is a 2D matrix which uses the mappings between the data and the linear objects to
+        construct the simultaneous linear equations.
+
+        The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf, where the
+        curvature matrix is given by equation (4) and the letter F.
+
+        This function computes F when the inversion contains both mapper objects and linear function objects, adding
+        the mapper / linear function off-diagonal blocks and the linear function blocks using the sparse linear
+        algebra formalism.
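+
+        Schematically, with the parameter ranges given by `param_range_list_from` (shown here assuming the
+        mapper parameters precede the linear function parameters), the assembled matrix has the block layout:
+
+            F = [ F_mapper_mapper    F_mapper_func ]
+                [ (mirrored)         F_func_func   ]
+
+        where `F_mapper_mapper` is taken from `_curvature_matrix_multi_mapper`, `F_mapper_func` is computed
+        via `curvature_matrix_off_diag_func_list_from` and `F_func_func` via dense dot products of the
+        noise-weighted operated mapping matrices; the mirrored lower block is filled in downstream via
+        `curvature_matrix_mirrored_from`.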
+ """ + + curvature_matrix = self._curvature_matrix_multi_mapper + + mapper_list = self.cls_list_from(cls=AbstractMapper) + mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper) + + linear_func_list = self.cls_list_from(cls=AbstractLinearObjFuncList) + linear_func_param_range_list = self.param_range_list_from( + cls=AbstractLinearObjFuncList + ) + + for i in range(len(mapper_list)): + mapper = mapper_list[i] + mapper_param_range = mapper_param_range_list[i] + + for func_index, linear_func in enumerate(linear_func_list): + linear_func_param_range = linear_func_param_range_list[func_index] + + curvature_weights = ( + self.linear_func_operated_mapping_matrix_dict[linear_func] + / self.noise_map[:, None] ** 2 + ) + + rows, cols, vals = mapper.sparse_triplets_curvature + + off_diag = self.dataset.sparse_operator.curvature_matrix_off_diag_func_list_from( + curvature_weights=curvature_weights, + fft_index_for_masked_pixel=self.mask.fft_index_for_masked_pixel, + rows=rows, + cols=cols, + vals=vals, + S=mapper.params, + ) + + if self._xp is np: + + curvature_matrix[ + mapper_param_range[0] : mapper_param_range[1], + linear_func_param_range[0] : linear_func_param_range[1], + ] = off_diag + else: + + curvature_matrix = curvature_matrix.at[ + mapper_param_range[0] : mapper_param_range[1], + linear_func_param_range[0] : linear_func_param_range[1], + ].set(off_diag) + + for index_0, linear_func_0 in enumerate(linear_func_list): + + linear_func_param_range_0 = linear_func_param_range_list[index_0] + + weighted_vector_0 = ( + self.linear_func_operated_mapping_matrix_dict[linear_func_0] + / self.noise_map[:, None] + ) + + for index_1, linear_func_1 in enumerate(linear_func_list): + linear_func_param_range_1 = linear_func_param_range_list[index_1] + + weighted_vector_1 = ( + self.linear_func_operated_mapping_matrix_dict[linear_func_1] + / self.noise_map[:, None] + ) + + diag = self._xp.dot( + weighted_vector_0.T, + weighted_vector_1, + ) + + if self._xp is np: + + curvature_matrix[ + linear_func_param_range_0[0] : linear_func_param_range_0[1], + linear_func_param_range_1[0] : linear_func_param_range_1[1], + ] = diag + + else: + + curvature_matrix = curvature_matrix.at[ + linear_func_param_range_0[0] : linear_func_param_range_0[1], + linear_func_param_range_1[0] : linear_func_param_range_1[1], + ].set(diag) + + return curvature_matrix + + @property + def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: + """ + When constructing the simultaneous linear equations (via vectors and matrices) the quantities of each individual + linear object (e.g. their `mapping_matrix`) are combined into single ndarrays via stacking. This does not track + which quantities belong to which linear objects, therefore the linear equation's solutions (which are returned + as ndarrays) do not contain information on which linear object(s) they correspond to. + + For example, consider if two `Mapper` objects with 50 and 100 source pixels are used in an `Inversion`. + The `reconstruction` (which contains the solved for source pixels values) is an ndarray of shape [150], but + the ndarray itself does not track which values belong to which `Mapper`. + + This function converts an ndarray of a `reconstruction` to a dictionary of ndarrays containing each linear + object's reconstructed data values, where the keys are the instances of each mapper in the inversion. 
+
+        The sparse linear algebra formalism bypasses the calculation of the `mapping_matrix`, which therefore cannot
+        be used to map the reconstruction's values to the image-plane. Instead, the unique data-to-pixelization
+        mappings are used, including the 2D convolution operation after mapping is complete.
+
+        Parameters
+        ----------
+        reconstruction
+            The reconstruction (in the source frame) whose values are mapped to a dictionary of values for each
+            individual mapper (in the image-plane).
+        """
+
+        mapped_reconstructed_data_dict = {}
+
+        reconstruction_dict = self.source_quantity_dict_from(
+            source_quantity=self.reconstruction
+        )
+
+        for linear_obj in self.linear_obj_list:
+            reconstruction = reconstruction_dict[linear_obj]
+
+            if isinstance(linear_obj, AbstractMapper):
+
+                rows, cols, vals = linear_obj.sparse_triplets_curvature
+
+                mapped_reconstructed_image = inversion_imaging_util.mapped_reconstructed_image_via_sparse_operator_from(
+                    reconstruction=reconstruction,
+                    rows=rows,
+                    cols=cols,
+                    vals=vals,
+                    fft_index_for_masked_pixel=self.mask.fft_index_for_masked_pixel,
+                    data_shape=self.mask.shape_native,
+                )
+
+                mapped_reconstructed_image = Array2D(
+                    values=mapped_reconstructed_image, mask=self.mask
+                )
+
+                mapped_reconstructed_image = self.psf.convolved_image_from(
+                    image=mapped_reconstructed_image, blurring_image=None, xp=self._xp
+                ).array
+
+                mapped_reconstructed_image = Array2D(
+                    values=mapped_reconstructed_image, mask=self.mask
+                )
+
+            else:
+
+                operated_mapping_matrix = self.linear_func_operated_mapping_matrix_dict[
+                    linear_obj
+                ]
+
+                mapped_reconstructed_image = self._xp.sum(
+                    reconstruction * operated_mapping_matrix, axis=1
+                )
+
+                mapped_reconstructed_image = Array2D(
+                    values=mapped_reconstructed_image, mask=self.mask
+                )
+
+            mapped_reconstructed_data_dict[linear_obj] = mapped_reconstructed_image
+
+        return mapped_reconstructed_data_dict
diff --git a/autoarray/inversion/inversion/imaging_numba/__init__.py b/autoarray/inversion/inversion/imaging_numba/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/autoarray/inversion/inversion/imaging/inversion_imaging_numba_util.py b/autoarray/inversion/inversion/imaging_numba/inversion_imaging_numba_util.py
similarity index 71%
rename from autoarray/inversion/inversion/imaging/inversion_imaging_numba_util.py
rename to autoarray/inversion/inversion/imaging_numba/inversion_imaging_numba_util.py
index 963e98e8e..7b2ba9718 100644
--- a/autoarray/inversion/inversion/imaging/inversion_imaging_numba_util.py
+++ b/autoarray/inversion/inversion/imaging_numba/inversion_imaging_numba_util.py
@@ -6,23 +6,23 @@
 
 
 @numba_util.jit()
-def w_tilde_data_imaging_from(
+def psf_weighted_data_from(
     image_native: np.ndarray,
     noise_map_native: np.ndarray,
    kernel_native: np.ndarray,
    native_index_for_slim_index,
 ) -> np.ndarray:
     """
-    The matrix w_tilde is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF convolution of
+    The sparse linear algebra uses a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF convolution of
     every pair of image pixels given the noise map. This can be used to efficiently compute the curvature matrix via
     the mappings between image and source pixels, in a way that omits having to perform the PSF convolution on every
     individual source pixel. This provides a significant speed up for inversions of imaging datasets.
 
-    When w_tilde is used to perform an inversion, the mapping matrices are not computed, meaning that they cannot be
-    used to compute the data vector.
This method creates the vector `w_tilde_data` which allows for the data + When this is used to perform an inversion, the mapping matrices are not computed, meaning that they cannot be + used to compute the data vector. This method creates the vector `psf_weighted_data` which allows for the data vector to be computed efficiently without the mapping matrix. - The matrix w_tilde_data is dimensions [image_pixels] and encodes the PSF convolution with the `weight_map`, + The matrix psf_weighted_data is dimensions [image_pixels] and encodes the PSF convolution with the `weight_map`, where the weights are the image-pixel values divided by the noise-map values squared: weight = image / noise**2.0 @@ -30,11 +30,11 @@ def w_tilde_data_imaging_from( Parameters ---------- image_native - The two dimensional masked image of values which `w_tilde_data` is computed from. + The two dimensional masked image of values which `psf_weighted_data` is computed from. noise_map_native - The two dimensional masked noise-map of values which `w_tilde_data` is computed from. + The two dimensional masked noise-map of values which `psf_weighted_data` is computed from. kernel_native - The two dimensional PSF kernel that `w_tilde_data` encodes the convolution of. + The two dimensional PSF kernel that `psf_weighted_data` encodes the convolution of. native_index_for_slim_index An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. @@ -50,7 +50,7 @@ def w_tilde_data_imaging_from( image_pixels = len(native_index_for_slim_index) - w_tilde_data = np.zeros((image_pixels,)) + psf_weighted_data = np.zeros((image_pixels,)) weight_map_native = image_native / noise_map_native**2.0 @@ -68,17 +68,17 @@ def w_tilde_data_imaging_from( if not np.isnan(weight_value): value += kernel_native[k0_y, k0_x] * weight_value - w_tilde_data[ip0] = value + psf_weighted_data[ip0] = value - return w_tilde_data + return psf_weighted_data @numba_util.jit() -def w_tilde_curvature_imaging_from( +def psf_precision_operator_from( noise_map_native: np.ndarray, kernel_native: np.ndarray, native_index_for_slim_index ) -> np.ndarray: """ - The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF + The `psf_precision_operator` matrix is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF convolution of every pair of image pixels given the noise map. This can be used to efficiently compute the curvature matrix via the mappings between image and source pixels, in a way that omits having to perform the PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging @@ -86,15 +86,15 @@ def w_tilde_curvature_imaging_from( The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, making it impossible to store in memory and its use in linear algebra calculations extremely. The method - `w_tilde_curvature_preload_imaging_from` describes a compressed representation that overcomes this hurdles. It is - advised `w_tilde` and this method are only used for testing. + `psf_precision_operator_sparse_from` describes a compressed representation that overcomes this hurdles. It is + advised `psf_precision_operator` and this method are only used for testing. Parameters ---------- noise_map_native - The two dimensional masked noise-map of values which w_tilde is computed from. 
+ The two dimensional masked noise-map of values which psf_precision_operator is computed from. kernel_native - The two dimensional PSF kernel that w_tilde encodes the convolution of. + The two dimensional PSF kernel that psf_precision_operator encodes the convolution of. native_index_for_slim_index An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. @@ -106,15 +106,15 @@ def w_tilde_curvature_imaging_from( """ image_pixels = len(native_index_for_slim_index) - w_tilde_curvature = np.zeros((image_pixels, image_pixels)) + psf_precision_operator = np.zeros((image_pixels, image_pixels)) - for ip0 in range(w_tilde_curvature.shape[0]): + for ip0 in range(psf_precision_operator.shape[0]): ip0_y, ip0_x = native_index_for_slim_index[ip0] - for ip1 in range(ip0, w_tilde_curvature.shape[1]): + for ip1 in range(ip0, psf_precision_operator.shape[1]): ip1_y, ip1_x = native_index_for_slim_index[ip1] - w_tilde_curvature[ip0, ip1] += w_tilde_curvature_value_from( + psf_precision_operator[ip0, ip1] += psf_precision_value_from( value_native=noise_map_native, kernel_native=kernel_native, ip0_y=ip0_y, @@ -123,19 +123,19 @@ def w_tilde_curvature_imaging_from( ip1_x=ip1_x, ) - for ip0 in range(w_tilde_curvature.shape[0]): - for ip1 in range(ip0, w_tilde_curvature.shape[1]): - w_tilde_curvature[ip1, ip0] = w_tilde_curvature[ip0, ip1] + for ip0 in range(psf_precision_operator.shape[0]): + for ip1 in range(ip0, psf_precision_operator.shape[1]): + psf_precision_operator[ip1, ip0] = psf_precision_operator[ip0, ip1] - return w_tilde_curvature + return psf_precision_operator @numba_util.jit() -def w_tilde_curvature_preload_imaging_from( +def psf_precision_operator_sparse_from( noise_map_native: np.ndarray, kernel_native: np.ndarray, native_index_for_slim_index ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ - The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF + The matrix `psf_precision_operator` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF convolution of every pair of image pixels on the noise map. This can be used to efficiently compute the curvature matrix via the mappings between image and source pixels, in a way that omits having to repeat the PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging @@ -143,25 +143,25 @@ def w_tilde_curvature_preload_imaging_from( The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, making it impossible to store in memory and its use in linear algebra calculations slow. This methods creates - a sparse matrix that can compute the matrix `w_tilde_curvature` efficiently, albeit the linear algebra calculations + a sparse matrix that can compute the matrix `psf_precision_operator` efficiently, albeit the linear algebra calculations in PyAutoArray bypass this matrix entirely to go straight to the curvature matrix. - for dataset data, w_tilde is a sparse matrix, whereby non-zero entries are only contained for pairs of image pixels + for imaging data, psf_precision_operator is a sparse matrix, whereby non-zero entries are only contained for pairs of image pixels where the two pixels overlap due to the kernel size. For example, if the kernel size is (11, 11) and two image pixels are separated by more than 20 pixels, the kernel will never convolve flux between the two pixels. 
Two image pixels will only share a convolution if they are within `kernel_overlap_size = 2 * kernel_shape - 1` pixels within one another. - Thus, a `w_tilde_curvature_preload` matrix of dimensions [image_pixels, kernel_overlap_size ** 2] can be computed + Thus, a `psf_precision_operator_preload` matrix of dimensions [image_pixels, kernel_overlap_size ** 2] can be computed which significantly reduces the memory consumption by removing the sparsity. Because the dimensions of the second - axes is no longer `image_pixels`, a second matrix `w_tilde_indexes` must also be computed containing the slim image - pixel indexes of every entry of `w_tilde_preload`. + axes is no longer `image_pixels`, a second matrix `psf_precision_indexes` must also be computed containing the slim image + pixel indexes of every entry of `psf_precision_operator`. - In order for the preload to store half the number of values, owing to the symmetry of the `w_tilde_curvature` + In order for the preload to store half the number of values, owing to the symmetry of the `psf_precision_operator` matrix, the image pixel pairs corresponding to the same image pixel are divided by two. This ensures that when the curvature matrix is computed these pixels are not double-counted. - The values stored in `w_tilde_curvature_preload` represent the convolution of overlapping noise-maps given the + The values stored in `psf_precision_operator_preload` represent the convolution of overlapping noise-maps given the PSF kernel. It is common for many values to be neglibly small. Removing these values can speed up the inversion and reduce memory at the expense of a numerically irrelevent change of solution. @@ -171,12 +171,12 @@ def w_tilde_curvature_preload_imaging_from( Parameters ---------- noise_map_native - The two dimensional masked noise-map of values which `w_tilde_curvature` is computed from. + The two dimensional masked noise-map of values which `psf_precision_operator` is computed from. signal_to_noise_map_native The two dimensional masked signal-to-noise-map from which the threshold discarding low S/N image pixel pairs is used. kernel_native - The two dimensional PSF kernel that `w_tilde_curvature` encodes the convolution of. + The two dimensional PSF kernel that `psf_precision_operator` encodes the convolution of. native_index_for_slim_index An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. 
@@ -193,19 +193,19 @@ def w_tilde_curvature_preload_imaging_from( 2 * kernel_native.shape[1] - 1 ) - curvature_preload_tmp = np.zeros((image_pixels, kernel_overlap_size)) - curvature_indexes_tmp = np.zeros((image_pixels, kernel_overlap_size)) - curvature_lengths = np.zeros(image_pixels) + psf_precision_operator_tmp = np.zeros((image_pixels, kernel_overlap_size)) + psf_precision_indexes_tmp = np.zeros((image_pixels, kernel_overlap_size)) + psf_precision_lengths = np.zeros(image_pixels) for ip0 in range(image_pixels): ip0_y, ip0_x = native_index_for_slim_index[ip0] kernel_index = 0 - for ip1 in range(ip0, curvature_preload_tmp.shape[0]): + for ip1 in range(ip0, psf_precision_operator_tmp.shape[0]): ip1_y, ip1_x = native_index_for_slim_index[ip1] - noise_value = w_tilde_curvature_value_from( + noise_value = psf_precision_value_from( value_native=noise_map_native, kernel_native=kernel_native, ip0_y=ip0_y, @@ -218,31 +218,31 @@ def w_tilde_curvature_preload_imaging_from( noise_value /= 2.0 if noise_value > 0.0: - curvature_preload_tmp[ip0, kernel_index] = noise_value - curvature_indexes_tmp[ip0, kernel_index] = ip1 + psf_precision_operator_tmp[ip0, kernel_index] = noise_value + psf_precision_indexes_tmp[ip0, kernel_index] = ip1 kernel_index += 1 - curvature_lengths[ip0] = kernel_index + psf_precision_lengths[ip0] = kernel_index - curvature_total_pairs = int(np.sum(curvature_lengths)) + psf_precision_total_pairs = int(np.sum(psf_precision_lengths)) - curvature_preload = np.zeros((curvature_total_pairs)) - curvature_indexes = np.zeros((curvature_total_pairs)) + psf_precision_operator = np.zeros((psf_precision_total_pairs)) + psf_precision_indexes = np.zeros((psf_precision_total_pairs)) index = 0 for i in range(image_pixels): - for data_index in range(int(curvature_lengths[i])): - curvature_preload[index] = curvature_preload_tmp[i, data_index] - curvature_indexes[index] = curvature_indexes_tmp[i, data_index] + for data_index in range(int(psf_precision_lengths[i])): + psf_precision_operator[index] = psf_precision_operator_tmp[i, data_index] + psf_precision_indexes[index] = psf_precision_indexes_tmp[i, data_index] index += 1 - return (curvature_preload, curvature_indexes, curvature_lengths) + return (psf_precision_operator, psf_precision_indexes, psf_precision_lengths) @numba_util.jit() -def w_tilde_curvature_value_from( +def psf_precision_value_from( value_native: np.ndarray, kernel_native: np.ndarray, ip0_y, @@ -252,7 +252,7 @@ def w_tilde_curvature_value_from( renormalize=False, ) -> float: """ - Compute the value of an entry of the `w_tilde_curvature` matrix, where this entry encodes the PSF convolution of + Compute the value of an entry of the `psf_precision_operator` matrix, where this entry encodes the PSF convolution of the noise-map between two image pixels. The calculation is performed by over-laying the PSF kernel over two noise-map pixels in 2D. For all pixels where @@ -264,15 +264,15 @@ def w_tilde_curvature_value_from( pixels and can therefore be used to efficiently calculate the curvature_matrix that is used in the linear algebra calculation of an inversion. - The sum of all values where kernel pixels overlap is returned to give the `w_tilde` value. + The sum of all values where kernel pixels overlap is returned to give the `psf_precision_operator` value. Parameters ---------- value_native - A two dimensional masked array of values (e.g. a noise-map, signal to noise map) which the w_tilde curvature + A two dimensional masked array of values (e.g. 
a noise-map, signal to noise map) which the psf_precision_operator curvature values are computed from. kernel_native - The two dimensional PSF kernel that w_tilde encodes the convolution of. + The two dimensional PSF kernel that psf_precision_operator encodes the convolution of. ip0_y The y index of the first image pixel in the image pixel pair. ip0_x @@ -285,7 +285,7 @@ def w_tilde_curvature_value_from( Returns ------- float - The w_tilde value that encodes the value of PSF convolution between a pair of image pixels. + The psf_precision_operator value that encodes the value of PSF convolution between a pair of image pixels. """ @@ -374,15 +374,15 @@ def data_vector_via_blurred_mapping_matrix_from( @numba_util.jit() -def data_vector_via_w_tilde_data_imaging_from( - w_tilde_data: np.ndarray, +def data_vector_via_psf_weighted_data_from( + psf_weighted_data: np.ndarray, data_to_pix_unique: np.ndarray, data_weights: np.ndarray, pix_lengths: np.ndarray, pix_pixels: int, ) -> np.ndarray: """ - Returns the data vector `D` from the `w_tilde_data` matrix (see `w_tilde_data_imaging_from`), which encodes the + Returns the data vector `D` from the `psf_weighted_data` matrix (see `psf_weighted_data_from`), which encodes the the 1D image `d` and 1D noise-map values `\sigma` (see Warren & Dye 2003). This uses the array `data_to_pix_unique`, which describes the unique mappings of every set of image sub-pixels to @@ -391,7 +391,7 @@ def data_vector_via_w_tilde_data_imaging_from( Parameters ---------- - w_tilde_data + psf_weighted_data A matrix that encodes the PSF convolution values between the imaging divided by the noise map**2 that enables efficient calculation of the data vector. data_to_pix_unique @@ -407,7 +407,7 @@ def data_vector_via_w_tilde_data_imaging_from( The total number of pixels in the pixelization that reconstructs the data. """ - data_pixels = w_tilde_data.shape[0] + data_pixels = psf_weighted_data.shape[0] data_vector = np.zeros(pix_pixels) @@ -416,7 +416,7 @@ def data_vector_via_w_tilde_data_imaging_from( data_0_weight = data_weights[data_0, pix_0_index] pix_0 = data_to_pix_unique[data_0, pix_0_index] - data_vector[pix_0] += data_0_weight * w_tilde_data[data_0] + data_vector[pix_0] += data_0_weight * psf_weighted_data[data_0] return data_vector @@ -469,27 +469,27 @@ def curvature_matrix_mirrored_from( @numba_util.jit() -def curvature_matrix_via_w_tilde_curvature_preload_imaging_from( - curvature_preload: np.ndarray, - curvature_indexes: np.ndarray, - curvature_lengths: np.ndarray, +def curvature_matrix_via_sparse_operator_from( + psf_precision_operator: np.ndarray, + psf_precision_indexes: np.ndarray, + psf_precision_lengths: np.ndarray, data_to_pix_unique: np.ndarray, data_weights: np.ndarray, pix_lengths: np.ndarray, pix_pixels: int, ) -> np.ndarray: """ - Returns the curvature matrix `F` (see Warren & Dye 2003) by computing it using `w_tilde_preload` - (see `w_tilde_preload_interferometer_from`) for an imaging inversion. + Returns the curvature matrix `F` (see Warren & Dye 2003) by computing it using `psf_precision_operator` + (see `psf_precision_operator_from`) for an imaging inversion. 
- To compute the curvature matrix via w_tilde the following matrix multiplication is normally performed: + To compute the curvature matrix via psf_precision_operator the following matrix multiplication is normally performed: - curvature_matrix = mapping_matrix.T * w_tilde * mapping matrix + curvature_matrix = mapping_matrix.T * psf_precision_operator * mapping matrix This function speeds this calculation up in two ways: - 1) Instead of using `w_tilde` (dimensions [image_pixels, image_pixels] it uses `w_tilde_preload` (dimensions - [image_pixels, kernel_overlap]). The massive reduction in the size of this matrix in memory allows for much fast + 1) Instead of using `psf_precision_operator` (dimensions [image_pixels, image_pixels] it uses a compressed + `psf_precision_operator` (dimensions [image_pixels, kernel_overlap]). The massive reduction in the size of this matrix in memory allows for much fast computation. 2) It omits the `mapping_matrix` and instead uses directly the 1D vector that maps every image pixel to a source @@ -498,13 +498,13 @@ def curvature_matrix_via_w_tilde_curvature_preload_imaging_from( Parameters ---------- - curvature_preload + psf_precision_operator A matrix that precomputes the values for fast computation of the curvature matrix in a memory efficient way. - curvature_indexes + psf_precision_indexes The image-pixel indexes of the values stored in the w tilde preload matrix, which are used to compute the weights of the data values when computing the curvature matrix. - curvature_lengths - The number of image pixels in every row of `w_tilde_curvature`, which is iterated over when computing the + psf_precision_lengths + The number of image pixels in every row of `psf_precision_operator`, which is iterated over when computing the curvature matrix. data_to_pix_unique An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D) to its unique set of @@ -524,16 +524,16 @@ def curvature_matrix_via_w_tilde_curvature_preload_imaging_from( The curvature matrix `F` (see Warren & Dye 2003). 
""" - data_pixels = curvature_lengths.shape[0] + data_pixels = psf_precision_lengths.shape[0] curvature_matrix = np.zeros((pix_pixels, pix_pixels)) curvature_index = 0 for data_0 in range(data_pixels): - for data_1_index in range(curvature_lengths[data_0]): - data_1 = curvature_indexes[curvature_index] - w_tilde_value = curvature_preload[curvature_index] + for data_1_index in range(psf_precision_lengths[data_0]): + data_1 = psf_precision_indexes[curvature_index] + psf_precision_value = psf_precision_operator[curvature_index] for pix_0_index in range(pix_lengths[data_0]): data_0_weight = data_weights[data_0, pix_0_index] @@ -544,7 +544,7 @@ def curvature_matrix_via_w_tilde_curvature_preload_imaging_from( pix_1 = data_to_pix_unique[data_1, pix_1_index] curvature_matrix[pix_0, pix_1] += ( - data_0_weight * data_1_weight * w_tilde_value + data_0_weight * data_1_weight * psf_precision_value ) curvature_index += 1 @@ -561,10 +561,10 @@ def curvature_matrix_via_w_tilde_curvature_preload_imaging_from( @numba_util.jit() -def curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( - curvature_preload: np.ndarray, - curvature_indexes: np.ndarray, - curvature_lengths: np.ndarray, +def curvature_matrix_off_diags_via_sparse_operator_from( + psf_precision_operator: np.ndarray, + psf_precision_indexes: np.ndarray, + psf_precision_lengths: np.ndarray, data_to_pix_unique_0: np.ndarray, data_weights_0: np.ndarray, pix_lengths_0: np.ndarray, @@ -576,15 +576,15 @@ def curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( ) -> np.ndarray: """ Returns the off diagonal terms in the curvature matrix `F` (see Warren & Dye 2003) by computing them - using `w_tilde_preload` (see `w_tilde_preload_interferometer_from`) for an imaging inversion. + using `psf_precision_operator` (see `psf_precision_operator_from`) for an imaging inversion. When there is more than one mapper in the inversion, its `mapping_matrix` is extended to have dimensions [data_pixels, sum(source_pixels_in_each_mapper)]. The curvature matrix therefore will have dimensions [sum(source_pixels_in_each_mapper), sum(source_pixels_in_each_mapper)]. - To compute the curvature matrix via w_tilde the following matrix multiplication is normally performed: + To compute the curvature matrix via the psf precision operator following matrix multiplication is normally performed: - curvature_matrix = mapping_matrix.T * w_tilde * mapping matrix + curvature_matrix = mapping_matrix.T * psf_precision_operator * mapping matrix When the `mapping_matrix` consists of multiple mappers from different planes, this means that shared data mappings between source-pixels in different mappers must be accounted for when computing the `curvature_matrix`. These @@ -592,17 +592,17 @@ def curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( This function evaluates these off-diagonal terms, by using the w-tilde curvature preloads and the unique data-to-pixelization mappings of each mapper. It behaves analogous to the - function `curvature_matrix_via_w_tilde_curvature_preload_imaging_from`. + function `curvature_matrix_via_sparse_operator_from`. Parameters ---------- - curvature_preload + psf_precision_operator A matrix that precomputes the values for fast computation of the curvature matrix in a memory efficient way. - curvature_indexes + psf_precision_indexes The image-pixel indexes of the values stored in the w tilde preload matrix, which are used to compute the weights of the data values when computing the curvature matrix. 
- curvature_lengths - The number of image pixels in every row of `w_tilde_curvature`, which is iterated over when computing the + psf_precision_lengths + The number of image pixels in every row of `psf_precision_operator`, which is iterated over when computing the curvature matrix. data_to_pix_unique An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D) to its unique set of @@ -622,16 +622,16 @@ def curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( The curvature matrix `F` (see Warren & Dye 2003). """ - data_pixels = curvature_lengths.shape[0] + data_pixels = psf_precision_lengths.shape[0] curvature_matrix = np.zeros((pix_pixels_0, pix_pixels_1)) curvature_index = 0 for data_0 in range(data_pixels): - for data_1_index in range(curvature_lengths[data_0]): - data_1 = curvature_indexes[curvature_index] - w_tilde_value = curvature_preload[curvature_index] + for data_1_index in range(psf_precision_lengths[data_0]): + data_1 = psf_precision_indexes[curvature_index] + psf_precision_value = psf_precision_operator[curvature_index] for pix_0_index in range(pix_lengths_0[data_0]): data_0_weight = data_weights_0[data_0, pix_0_index] @@ -642,7 +642,7 @@ def curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( pix_1 = data_to_pix_unique_1[data_1, pix_1_index] curvature_matrix[pix_0, pix_1] += ( - data_0_weight * data_1_weight * w_tilde_value + data_0_weight * data_1_weight * psf_precision_value ) curvature_index += 1 @@ -650,99 +650,6 @@ def curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( return curvature_matrix -@numba_util.jit() -def curvature_matrix_off_diags_via_data_linear_func_matrix_from( - data_linear_func_matrix: np.ndarray, - data_to_pix_unique: np.ndarray, - data_weights: np.ndarray, - pix_lengths: np.ndarray, - pix_pixels: int, -): - """ - Returns the off diagonal terms in the curvature matrix `F` (see Warren & Dye 2003) between a mapper object - and a linear func object, using the preloaded `data_linear_func_matrix` of the values of the linear functions. - - - If a linear function in an inversion is fixed, its values can be evaluated and preloaded beforehand. For every - data pixel, the PSF convolution with this preloaded linear function can also be preloaded, in a matrix of - shape [data_pixels, 1]. - - When mapper objects and linear functions are used simultaneously in an inversion, this preloaded matrix - significantly speed up the computation of their off-diagonal terms in the curvature matrix. - - This function performs this efficient calcluation via the preloaded `data_linear_func_matrix`. - - Parameters - ---------- - data_linear_func_matrix - A matrix of shape [data_pixels, total_fixed_linear_functions] that for each data pixel, maps it to the sum of - the values of a linear object function convolved with the PSF kernel at the data pixel. - data_to_pix_unique - The indexes of all pixels that each data pixel maps to (see the `Mapper` object). - data_weights - The weights of all pixels that each data pixel maps to (see the `Mapper` object). - pix_lengths - The number of pixelization pixels that each data pixel maps to (see the `Mapper` object). - pix_pixels - The number of pixelization pixels in the pixelization (see the `Mapper` object). 
- """ - - linear_func_pixels = data_linear_func_matrix.shape[1] - - off_diag = np.zeros((pix_pixels, linear_func_pixels)) - - data_pixels = data_weights.shape[0] - - for data_0 in range(data_pixels): - for pix_0_index in range(pix_lengths[data_0]): - data_0_weight = data_weights[data_0, pix_0_index] - pix_0 = data_to_pix_unique[data_0, pix_0_index] - - for linear_index in range(linear_func_pixels): - off_diag[pix_0, linear_index] += ( - data_linear_func_matrix[data_0, linear_index] * data_0_weight - ) - - return off_diag - - -@numba_util.jit() -def convolve_with_kernel_native(curvature_native, psf_kernel): - """ - Convolve each function slice of curvature_native with psf_kernel using direct sliding window. - - Parameters - ---------- - curvature_native : ndarray (ny, nx, n_funcs) - Curvature weights expanded to the native grid, 0 in masked regions. - psf_kernel : ndarray (ky, kx) - The PSF kernel. - - Returns - ------- - blurred_native : ndarray (ny, nx, n_funcs) - The curvature weights convolved with the PSF. - """ - ny, nx, n_funcs = curvature_native.shape - ky, kx = psf_kernel.shape - cy, cx = ky // 2, kx // 2 # kernel center - - blurred_native = np.zeros_like(curvature_native) - - for f in range(n_funcs): # parallelize over functions - for y in range(ny): - for x in range(nx): - acc = 0.0 - for dy in range(ky): - for dx in range(kx): - yy = y + dy - cy - xx = x + dx - cx - if 0 <= yy < ny and 0 <= xx < nx: - acc += psf_kernel[dy, dx] * curvature_native[yy, xx, f] - blurred_native[y, x, f] = acc - return blurred_native - - @numba_util.jit() def curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from( data_to_pix_unique: np.ndarray, @@ -827,6 +734,43 @@ def curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from( return off_diag +@numba_util.jit() +def convolve_with_kernel_native(curvature_native, psf_kernel): + """ + Convolve each function slice of curvature_native with psf_kernel using direct sliding window. + + Parameters + ---------- + curvature_native : ndarray (ny, nx, n_funcs) + Curvature weights expanded to the native grid, 0 in masked regions. + psf_kernel : ndarray (ky, kx) + The PSF kernel. + + Returns + ------- + blurred_native : ndarray (ny, nx, n_funcs) + The curvature weights convolved with the PSF. + """ + ny, nx, n_funcs = curvature_native.shape + ky, kx = psf_kernel.shape + cy, cx = ky // 2, kx // 2 # kernel center + + blurred_native = np.zeros_like(curvature_native) + + for f in range(n_funcs): # parallelize over functions + for y in range(ny): + for x in range(nx): + acc = 0.0 + for dy in range(ky): + for dx in range(kx): + yy = y + dy - cy + xx = x + dx - cx + if 0 <= yy < ny and 0 <= xx < nx: + acc += psf_kernel[dy, dx] * curvature_native[yy, xx, f] + blurred_native[y, x, f] = acc + return blurred_native + + @numba_util.jit() def mapped_reconstructed_data_via_image_to_pix_unique_from( data_to_pix_unique: np.ndarray, @@ -925,3 +869,80 @@ def relocated_grid_via_jit_from(grid, border_grid): ) return grid_relocated + + +class SparseLinAlgImagingNumba: + def __init__( + self, + psf_precision_operator_sparse: np.ndarray, + indexes: np.ndim, + lengths: np.ndarray, + noise_map: np.ndarray, + psf: np.ndarray, + mask: np.ndarray, + ): + """ + Packages together all derived data quantities necessary to fit `Imaging` data using an ` Inversion` via the + sparse linear algebra formalism. 
+
+        The sparse linear algebra formalism performs the linear algebra in a way that speeds up the construction of the
+        simultaneous linear equations by bypassing the construction of a `mapping_matrix` and precomputing the
+        blurring operations performed using the imaging's PSF.
+
+        Parameters
+        ----------
+        psf_precision_operator_sparse
+            A matrix which uses the imaging's noise-map and PSF to preload as much of the computation of the
+            curvature matrix as possible.
+        indexes
+            The image-pixel indexes of the curvature preload matrix, which are used to compute the curvature matrix
+            efficiently when performing an inversion.
+        lengths
+            The number of packed entries each image-pixel row of the preload contains, again used to compute the
+            curvature matrix efficiently.
+        """
+
+        self.psf_precision_operator_sparse = psf_precision_operator_sparse
+        self.indexes = indexes
+        self.lengths = lengths
+        self.noise_map = noise_map
+        self.psf = psf
+        self.mask = mask
+
+    @property
+    def psf_precision_operator(self):
+        """
+        The matrix `psf_precision_operator` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF
+        convolution of every pair of image pixels given the noise map. This can be used to efficiently compute the
+        curvature matrix via the mappings between image and source pixels, in a way that omits having to perform the
+        PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging
+        datasets.
+
+        The limitation of this matrix is that at dimensions [image_pixels, image_pixels] it can exceed many 10s of GB's,
+        making it impossible to store in memory and its use in linear algebra calculations extremely slow. The method
+        `psf_precision_operator_sparse_from` describes a compressed representation that overcomes this hurdle. It is
+        advised that `psf_precision_operator` and this property are only used for testing.
+
+        Parameters
+        ----------
+        noise_map_native
+            The two dimensional masked noise-map of values which psf_precision_operator is computed from.
+        kernel_native
+            The two dimensional PSF kernel that psf_precision_operator encodes the convolution of.
+        native_index_for_slim_index
+            An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array.
+
+        Returns
+        -------
+        ndarray
+            A matrix that encodes the noise-map weighted PSF convolution values between every pair of image pixels,
+            enabling efficient calculation of the curvature matrix.
+ """ + + return psf_precision_operator_from( + noise_map_native=self.noise_map.native.array, + kernel_native=self.psf.native.array, + native_index_for_slim_index=np.array( + self.mask.derive_indexes.native_for_slim + ).astype("int"), + ) diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging_numba/sparse.py similarity index 85% rename from autoarray/inversion/inversion/imaging/w_tilde.py rename to autoarray/inversion/inversion/imaging_numba/sparse.py index bc2fa8bf0..8a3f6dab1 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging_numba/sparse.py @@ -4,26 +4,25 @@ from autoconf import cached_property from autoarray.dataset.imaging.dataset import Imaging -from autoarray.dataset.imaging.w_tilde import WTildeImaging from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.inversion.imaging.abstract import AbstractInversionImaging from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper +from autoarray.preloads import Preloads from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray import exc -from autoarray.inversion.inversion.imaging import inversion_imaging_numba_util +from autoarray.inversion.inversion.imaging_numba import inversion_imaging_numba_util -class InversionImagingWTilde(AbstractInversionImaging): +class InversionImagingSparseNumba(AbstractInversionImaging): def __init__( self, dataset: Union[Imaging, DatasetInterface], - w_tilde: WTildeImaging, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), + preloads: Preloads = None, xp=np, ): """ @@ -41,34 +40,26 @@ def __init__( ---------- dataset The dataset containing the image data, noise-map and psf which is fitted by the inversion. - w_tilde - An object containing matrices that construct the linear equations via the w-tilde formalism which bypasses - the mapping matrix. linear_obj_list The linear objects used to reconstruct the data's observed values. If multiple linear objects are passed the simultaneous linear equations are combined and solved simultaneously. 
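# --- Illustrative sketch (standalone Python, not part of the diff) ----------
# Why the dense `psf_precision_operator` above is reserved for testing: it
# couples every pair of masked image pixels, so its memory cost grows as
# n_pixels**2. A minimal dense construction, assuming the W = H^T N^-1 H
# structure quoted later in this patch for the FFT formalism, with H a linear
# blurring operator and N the diagonal noise covariance; `blurring_matrix_from`
# is a hypothetical 1D helper written only for this sketch.

import numpy as np

def blurring_matrix_from(n_pixels, kernel_1d):
    # Dense 1D analogue of a PSF blurring operator with zero padding.
    H = np.zeros((n_pixels, n_pixels))
    c = kernel_1d.size // 2
    for i in range(n_pixels):
        for j, k in enumerate(kernel_1d):
            col = i + j - c
            if 0 <= col < n_pixels:
                H[i, col] = k
    return H

n_pixels = 50
noise_map = np.full(n_pixels, 2.0)
H = blurring_matrix_from(n_pixels, kernel_1d=np.array([0.25, 0.5, 0.25]))

# W_ij sums, over every data pixel, the PSF responses of pixels i and j
# weighted by the inverse noise variance, so it is symmetric by construction.
W = H.T @ np.diag(1.0 / noise_map**2) @ H
assert np.allclose(W, W.T)

# The memory estimate that motivates the packed / sparse representations:
print(f"dense W for 1e5 image pixels ~ {(1e5 ** 2 * 8) / 1e9:.0f} GB")
# ----------------------------------------------------------------------------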
""" - try: - import numba - except ModuleNotFoundError: - raise exc.InversionException( - "Inversion w-tilde functionality (pixelized reconstructions) is " - "disabled if numba is not installed.\n\n" - "This is because the run-times without numba are too slow.\n\n" - "Please install numba, which is described at the following web page:\n\n" - "https://pyautolens.readthedocs.io/en/latest/installation/overview.html" - ) - super().__init__( - dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, xp=xp + dataset=dataset, + linear_obj_list=linear_obj_list, + settings=settings, + preloads=preloads, + xp=xp, ) - self.w_tilde = dataset.w_tilde + @property + def sparse_operator(self): + return self.dataset.sparse_operator @cached_property - def w_tilde_data(self): - return inversion_imaging_numba_util.w_tilde_data_imaging_from( + def psf_weighted_data(self): + return inversion_imaging_numba_util.psf_weighted_data_from( image_native=np.array(self.data.native.array), noise_map_native=self.noise_map.native.array, kernel_native=self.psf.native.array, @@ -94,9 +85,10 @@ def _data_vector_mapper(self) -> np.ndarray: mapper_param_range = self.param_range_list_from(cls=AbstractMapper) for mapper_index, mapper in enumerate(mapper_list): + data_vector_mapper = ( - inversion_imaging_numba_util.data_vector_via_w_tilde_data_imaging_from( - w_tilde_data=self.w_tilde_data, + inversion_imaging_numba_util.data_vector_via_psf_weighted_data_from( + psf_weighted_data=self.psf_weighted_data, data_to_pix_unique=np.array( mapper.unique_mappings.data_to_pix_unique ), @@ -123,7 +115,7 @@ def data_vector(self) -> np.ndarray: If there are multiple linear objects a `data_vector` is computed for ech one, which are concatenated ensuring their values are solved for simultaneously. - The calculation is described in more detail in `inversion_util.w_tilde_data_imaging_from`. + The calculation is described in more detail in `inversion_util.psf_weighted_data_from`. """ if self.has(cls=AbstractLinearObjFuncList): return self._data_vector_func_list_and_mapper @@ -142,8 +134,8 @@ def _data_vector_x1_mapper(self) -> np.ndarray: """ linear_obj = self.linear_obj_list[0] - return inversion_imaging_numba_util.data_vector_via_w_tilde_data_imaging_from( - w_tilde_data=self.w_tilde_data, + return inversion_imaging_numba_util.data_vector_via_psf_weighted_data_from( + psf_weighted_data=self.psf_weighted_data, data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique, data_weights=linear_obj.unique_mappings.data_weights, pix_lengths=linear_obj.unique_mappings.pix_lengths, @@ -160,18 +152,25 @@ def _data_vector_multi_mapper(self) -> np.ndarray: which computes the `data_vector` of each object and concatenates them. 
""" - return np.concatenate( - [ - inversion_imaging_numba_util.data_vector_via_w_tilde_data_imaging_from( - w_tilde_data=self.w_tilde_data, - data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique, - data_weights=linear_obj.unique_mappings.data_weights, - pix_lengths=linear_obj.unique_mappings.pix_lengths, - pix_pixels=linear_obj.params, + data_vector_list = [] + + for mapper in self.cls_list_from(cls=AbstractMapper): + + rows, cols, vals = mapper.sparse_triplets_data + + data_vector_mapper = ( + inversion_imaging_numba_util.data_vector_via_psf_weighted_data_from( + psf_weighted_data=self.psf_weighted_data, + rows=rows, + cols=cols, + vals=vals, + S=mapper.params, ) - for linear_obj in self.linear_obj_list - ] - ) + ) + + data_vector_list.append(data_vector_mapper) + + return np.concatenate(data_vector_list) @property def _data_vector_func_list_and_mapper(self) -> np.ndarray: @@ -207,7 +206,7 @@ def _data_vector_func_list_and_mapper(self) -> np.ndarray: param_range = linear_func_param_range[linear_func_index] - data_vector[param_range[0] : param_range[1],] = diag + data_vector[param_range[0] : param_range[1]] = diag return data_vector @@ -220,8 +219,8 @@ def curvature_matrix(self) -> np.ndarray: The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf, where the curvature matrix given by equation (4) and the letter F. - This function computes F using the w_tilde formalism, which is faster as it precomputes the PSF convolution - of different noise-map pixels (see `curvature_matrix_via_w_tilde_curvature_preload_imaging_from`). + This function computes F using the sparse_operator formalism, which is faster as it precomputes the PSF convolution + of different noise-map pixels (see `curvature_matrix_via_sparse_operator_from`). If there are multiple linear objects the curvature_matrices are combined to ensure their values are solved for simultaneously. In the w-tilde formalism this requires us to consider the mappings between data and every @@ -240,7 +239,7 @@ def curvature_matrix(self) -> np.ndarray: curvature_matrix = self._curvature_matrix_multi_mapper curvature_matrix = inversion_imaging_numba_util.curvature_matrix_mirrored_from( - curvature_matrix=curvature_matrix + curvature_matrix=curvature_matrix, ) if len(self.no_regularization_index_list) > 0: @@ -277,10 +276,10 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: mapper_i = mapper_list[i] mapper_param_range_i = mapper_param_range_list[i] - diag = inversion_imaging_numba_util.curvature_matrix_via_w_tilde_curvature_preload_imaging_from( - curvature_preload=self.w_tilde.curvature_preload, - curvature_indexes=self.w_tilde.indexes, - curvature_lengths=self.w_tilde.lengths, + diag = inversion_imaging_numba_util.curvature_matrix_via_sparse_operator_from( + psf_precision_operator=self.sparse_operator.psf_precision_operator_sparse, + psf_precision_indexes=self.sparse_operator.indexes, + psf_precision_lengths=self.sparse_operator.lengths, data_to_pix_unique=np.array( mapper_i.unique_mappings.data_to_pix_unique ), @@ -309,13 +308,13 @@ def _curvature_matrix_off_diag_from( The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf, where the curvature matrix given by equation (4) and the letter F. - This function computes the off-diagonal terms of F using the w_tilde formalism. + This function computes the off-diagonal terms of F using the sparse_operator formalism. 
""" - curvature_matrix_off_diag_0 = inversion_imaging_numba_util.curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( - curvature_preload=self.w_tilde.curvature_preload, - curvature_indexes=self.w_tilde.indexes, - curvature_lengths=self.w_tilde.lengths, + curvature_matrix_off_diag_0 = inversion_imaging_numba_util.curvature_matrix_off_diags_via_sparse_operator_from( + psf_precision_operator=self.sparse_operator.psf_precision_operator_sparse, + psf_precision_indexes=self.sparse_operator.indexes, + psf_precision_lengths=self.sparse_operator.lengths, data_to_pix_unique_0=mapper_0.unique_mappings.data_to_pix_unique, data_weights_0=mapper_0.unique_mappings.data_weights, pix_lengths_0=mapper_0.unique_mappings.pix_lengths, @@ -326,10 +325,10 @@ def _curvature_matrix_off_diag_from( pix_pixels_1=mapper_1.params, ) - curvature_matrix_off_diag_1 = inversion_imaging_numba_util.curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( - curvature_preload=self.w_tilde.curvature_preload, - curvature_indexes=self.w_tilde.indexes, - curvature_lengths=self.w_tilde.lengths, + curvature_matrix_off_diag_1 = inversion_imaging_numba_util.curvature_matrix_off_diags_via_sparse_operator_from( + psf_precision_operator=self.sparse_operator.psf_precision_operator_sparse, + psf_precision_indexes=self.sparse_operator.indexes, + psf_precision_lengths=self.sparse_operator.lengths, data_to_pix_unique_0=mapper_1.unique_mappings.data_to_pix_unique, data_weights_0=mapper_1.unique_mappings.data_weights, pix_lengths_0=mapper_1.unique_mappings.pix_lengths, @@ -399,7 +398,7 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray: The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf, where the curvature matrix given by equation (4) and the letter F. - This function computes the diagonal terms of F using the w_tilde formalism. + This function computes the diagonal terms of F using the sparse_operator formalism. 
""" curvature_matrix = self._curvature_matrix_multi_mapper @@ -419,7 +418,7 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray: for func_index, linear_func in enumerate(linear_func_list): linear_func_param_range = linear_func_param_range_list[func_index] - curvature_weights = ( + data_linear_func_matrix = ( self.linear_func_operated_mapping_matrix_dict[linear_func] / self.noise_map[:, None] ** 2 ) @@ -429,7 +428,7 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray: data_weights=mapper.unique_mappings.data_weights, pix_lengths=mapper.unique_mappings.pix_lengths, pix_pixels=mapper.params, - curvature_weights=np.array(curvature_weights), + curvature_weights=np.array(data_linear_func_matrix), mask=self.mask.array, psf_kernel=self.psf.native.array, ) @@ -440,6 +439,7 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray: ] = off_diag for index_0, linear_func_0 in enumerate(linear_func_list): + linear_func_param_range_0 = linear_func_param_range_list[index_0] weighted_vector_0 = ( @@ -503,6 +503,7 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: reconstruction = reconstruction_dict[linear_obj] if isinstance(linear_obj, AbstractMapper): + mapped_reconstructed_image = inversion_imaging_numba_util.mapped_reconstructed_data_via_image_to_pix_unique_from( data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique, data_weights=linear_obj.unique_mappings.data_weights, @@ -515,7 +516,8 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: ) mapped_reconstructed_image = self.psf.convolved_image_from( - image=mapped_reconstructed_image, blurring_image=None, xp=self._xp + image=mapped_reconstructed_image, + blurring_image=None, ).array mapped_reconstructed_image = Array2D( diff --git a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py index 898cdf171..fef730bb3 100644 --- a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py +++ b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py @@ -1,9 +1,8 @@ from dataclasses import dataclass import logging import numpy as np -from tqdm import tqdm -import os import time +from pathlib import Path logger = logging.getLogger(__name__) @@ -78,7 +77,7 @@ def _report_memory(arr): pass -def w_tilde_curvature_preload_interferometer_from( +def nufft_precision_operator_from( noise_map_real: np.ndarray, uv_wavelengths: np.ndarray, shape_masked_pixels_2d, @@ -133,18 +132,19 @@ def w_tilde_curvature_preload_interferometer_from( ------------------------------------------------------------------------------- Full Description (Original Documentation) ------------------------------------------------------------------------------- - The matrix w_tilde is a matrix of dimensions [unmasked_image_pixels, unmasked_image_pixels] that encodes the - NUFFT of every pair of image pixels given the noise map. This can be used to efficiently compute the curvature - matrix via the mapping matrix, in a way that omits having to perform the NUFFT on every individual source pixel. - This provides a significant speed up for inversions of interferometer datasets with large number of visibilities. + The matrix `translation_invariant_nufft` a matrix of dimensions [unmasked_image_pixels, unmasked_image_pixels] + that encodes the NUFFT of every pair of image pixels given the noise map. 
This can be used to efficiently compute + the curvature matrix via the mapping matrix, in a way that omits having to perform the NUFFT on every individual + source pixel. This provides a significant speed up for inversions of interferometer datasets with large number of + visibilities. The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, making it impossible to store in memory and its use in linear algebra calculations extremely. This methods creates - a preload matrix that can compute the matrix w_tilde via an efficient preloading scheme which exploits the + a preload matrix that can compute the matrix via an efficient preloading scheme which exploits the symmetries in the NUFFT. - To compute w_tilde, one first defines a real space mask where every False entry is an unmasked pixel which is - used in the calculation, for example: + To compute `translation_invariant_nufft`, one first defines a real space mask where every False entry is an + unmasked pixel which is used in the calculation, for example: IxIxIxIxIxIxIxIxIxIxI IxIxIxIxIxIxIxIxIxIxI This is an imaging.Mask2D, where: @@ -171,7 +171,7 @@ def w_tilde_curvature_preload_interferometer_from( IxIxIxIxIxIxIxIxIxIxI IxIxIxIxIxIxIxIxIxIxI - In the standard calculation of `w_tilde` it is a matrix of + In the standard calculation of `translation_invariant_nufft` it is a matrix of dimensions [unmasked_image_pixels, unmasked_pixel_images], therefore for the example mask above it would be dimensions [9, 9]. One performs a double for loop over `unmasked_image_pixels`, using the (y,x) spatial offset between every possible pair of unmasked image pixels to precompute values that depend on the properties of the NUFFT. @@ -185,27 +185,27 @@ def w_tilde_curvature_preload_interferometer_from( - The values of pixels paired with themselves are also computed repeatedly for the standard calculation (e.g. 9 times using the mask above). - The `curvature_preload` method instead only computes each value once. To do this, it stores the preload values in a + The `nufft_precision_operator` method instead only computes each value once. To do this, it stores the preload values in a matrix of dimensions [shape_masked_pixels_y, shape_masked_pixels_x, 2], where `shape_masked_pixels` is the (y,x) size of the vertical and horizontal extent of unmasked pixels, e.g. the spatial extent over which the real space grid extends. - Each entry in the matrix `curvature_preload[:,:,0]` provides the the precomputed NUFFT value mapping an image pixel + Each entry in the matrix `nufft_precision_operator[:,:,0]` provides the the precomputed NUFFT value mapping an image pixel to a pixel offset by that much in the y and x directions, for example: - - curvature_preload[0,0,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and + - nufft_precision_operator[0,0,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and in the x direction by 0 - the values of pixels paired with themselves. 
- - curvature_preload[1,0,0] gives the precomputed values of image pixels that are offset in the y direction by 1 and + - nufft_precision_operator[1,0,0] gives the precomputed values of image pixels that are offset in the y direction by 1 and in the x direction by 0 - the values of pixel pairs [0,3], [1,4], [2,5], [3,6], [4,7] and [5,8] - - curvature_preload[0,1,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and + - nufft_precision_operator[0,1,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and in the x direction by 1 - the values of pixel pairs [0,1], [1,2], [3,4], [4,5], [6,7] and [7,9]. Flipped pairs: The above preloaded values pair all image pixel NUFFT values when a pixel is to the right and / or down of the first image pixel. However, one must also precompute pairs where the paired pixel is to the left of the host - pixels. These pairings are stored in `curvature_preload[:,:,1]`, and the ordering of these pairings is flipped in the - x direction to make it straight forward to use this matrix when computing w_tilde. + pixels. These pairings are stored in `nufft_precision_operator[:,:,1]`, and the ordering of these pairings is flipped in the + x direction to make it straight forward to use this matrix when computing the nufft weighted noise. Notes ----- @@ -227,7 +227,7 @@ def w_tilde_curvature_preload_interferometer_from( Fourier transformed is computed. """ if use_jax: - return w_tilde_curvature_preload_interferometer_via_jax_from( + return nufft_precision_operator_via_jax_from( noise_map_real=noise_map_real, uv_wavelengths=uv_wavelengths, shape_masked_pixels_2d=shape_masked_pixels_2d, @@ -235,7 +235,7 @@ def w_tilde_curvature_preload_interferometer_from( chunk_k=chunk_k, ) - return w_tilde_curvature_preload_interferometer_via_np_from( + return nufft_precision_operator_via_np_from( noise_map_real=noise_map_real, uv_wavelengths=uv_wavelengths, shape_masked_pixels_2d=shape_masked_pixels_2d, @@ -246,7 +246,7 @@ def w_tilde_curvature_preload_interferometer_from( ) -def w_tilde_curvature_preload_interferometer_via_np_from( +def nufft_precision_operator_via_np_from( noise_map_real: np.ndarray, uv_wavelengths: np.ndarray, shape_masked_pixels_2d, @@ -259,7 +259,7 @@ def w_tilde_curvature_preload_interferometer_via_np_from( """ NumPy/CPU implementation of the interferometer W-tilde curvature preload. - See `w_tilde_curvature_preload_interferometer_from` for full description. + See `nufft_precision_operator_from` for full description. 
""" if chunk_k <= 0: raise ValueError("chunk_k must be a positive integer") @@ -280,7 +280,9 @@ def w_tilde_curvature_preload_interferometer_via_np_from( ku = 2.0 * np.pi * uv_wavelengths[:, 0] kv = 2.0 * np.pi * uv_wavelengths[:, 1] - out = np.zeros((2 * y_shape, 2 * x_shape), dtype=np.float64) + translation_invariant_kernel = np.zeros( + (2 * y_shape, 2 * x_shape), dtype=np.float64 + ) # Corner coordinates y00, x00 = gy[0, 0], gx[0, 0] @@ -332,36 +334,38 @@ def accum_from_corner_np(y_ref, x_ref, gy_block, gx_block): # ----------------------------- # Main quadrant (+,+) # ----------------------------- - out[:y_shape, :x_shape] = accum_from_corner_np(y00, x00, gy, gx) + translation_invariant_kernel[:y_shape, :x_shape] = accum_from_corner_np( + y00, x00, gy, gx + ) # ----------------------------- # Flip in x (+,-) # ----------------------------- if x_shape > 1: block = accum_from_corner_np(y0m, x0m, gy[:, ::-1], gx[:, ::-1]) - out[:y_shape, -1:-(x_shape):-1] = block[:, 1:] + translation_invariant_kernel[:y_shape, -1:-(x_shape):-1] = block[:, 1:] # ----------------------------- # Flip in y (-,+) # ----------------------------- if y_shape > 1: block = accum_from_corner_np(ym0, xm0, gy[::-1, :], gx[::-1, :]) - out[-1:-(y_shape):-1, :x_shape] = block[1:, :] + translation_invariant_kernel[-1:-(y_shape):-1, :x_shape] = block[1:, :] # ----------------------------- # Flip in x and y (-,-) # ----------------------------- if (y_shape > 1) and (x_shape > 1): block = accum_from_corner_np(ymm, xmm, gy[::-1, ::-1], gx[::-1, ::-1]) - out[-1:-(y_shape):-1, -1:-(x_shape):-1] = block[1:, 1:] + translation_invariant_kernel[-1:-(y_shape):-1, -1:-(x_shape):-1] = block[1:, 1:] if pbar is not None: pbar.close() - return out + return translation_invariant_kernel -def w_tilde_curvature_preload_interferometer_via_jax_from( +def nufft_precision_operator_via_jax_from( noise_map_real: np.ndarray, uv_wavelengths: np.ndarray, shape_masked_pixels_2d, @@ -377,7 +381,7 @@ def w_tilde_curvature_preload_interferometer_via_jax_from( - uses a compiled for-loop (lax.fori_loop) over fixed-size visibility chunks - does not support progress bars or memory reporting (those require Python loops) - See `w_tilde_curvature_preload_interferometer_from` for full description. + See `nufft_precision_operator_from` for full description. """ import jax import jax.numpy as jnp @@ -477,24 +481,27 @@ def body(i, acc_): ) t0 = time.time() - out = _compute_all_quadrants_jit(gy, gx, chunk_k=chunk_k) - out.block_until_ready() # ensure timing includes actual device execution + translation_invariant_kernel = _compute_all_quadrants_jit(gy, gx, chunk_k=chunk_k) + translation_invariant_kernel.block_until_ready() # ensure timing includes actual device execution t1 = time.time() logger.info("INTERFEROMETER - Finished W-Tilde (JAX) in %.3f seconds", (t1 - t0)) - return np.asarray(out) + return np.asarray(translation_invariant_kernel) -def w_tilde_via_preload_from(curvature_preload, native_index_for_slim_index): +def nufft_weighted_noise_via_sparse_operator_from( + translation_invariant_kernel, native_index_for_slim_index +): """ - Use the preloaded w_tilde matrix (see `curvature_preload_interferometer_from`) to compute - w_tilde (see `w_tilde_interferometer_from`) efficiently. + Use the `translation_invariant_kernel` (see `nufft_precision_operator_from`) to compute + the `nufft_weighted_noise` efficiently. Parameters ---------- - curvature_preload - The preloaded values of the NUFFT that enable efficient computation of w_tilde. 
+ translation_invariant_kernel + The preloaded translation invariant values of the NUFFT that enable efficient computation of the + NUFFT weighted noise matrix. native_index_for_slim_index An array of shape [total_unmasked_pixels*sub_size] that maps every unmasked sub-pixel to its corresponding native 2D pixel using its (y,x) pixel indexes. @@ -508,7 +515,7 @@ def w_tilde_via_preload_from(curvature_preload, native_index_for_slim_index): slim_size = len(native_index_for_slim_index) - w_tilde_via_preload = np.zeros((slim_size, slim_size)) + nufft_weighted_noise = np.zeros((slim_size, slim_size)) for i in range(slim_size): i_y, i_x = native_index_for_slim_index[i] @@ -519,167 +526,314 @@ def w_tilde_via_preload_from(curvature_preload, native_index_for_slim_index): y_diff = j_y - i_y x_diff = j_x - i_x - w_tilde_via_preload[i, j] = curvature_preload[y_diff, x_diff] + nufft_weighted_noise[i, j] = translation_invariant_kernel[y_diff, x_diff] for i in range(slim_size): for j in range(i, slim_size): - w_tilde_via_preload[j, i] = w_tilde_via_preload[i, j] + nufft_weighted_noise[j, i] = nufft_weighted_noise[i, j] - return w_tilde_via_preload + return nufft_weighted_noise @dataclass(frozen=True) -class WTildeFFTState: +class InterferometerSparseOperator: """ Fully static FFT / geometry state for W~ curvature. Safe to cache as long as: - - curvature_preload is fixed + - nufft_precision_operator is fixed - mask / rectangle definition is fixed - dtype is fixed - batch_size is fixed """ + dirty_image: np.ndarray y_shape: int x_shape: int M: int batch_size: int w_dtype: "jax.numpy.dtype" Khat: "jax.Array" # (2y, 2x), complex + """ + Cached FFT operator state for fast interferometer curvature-matrix assembly. + + This class packages *static* quantities needed to apply the interferometer + W~ operator efficiently using FFTs, so that repeated likelihood evaluations + do not redo expensive precomputation. + + Conceptually, the interferometer W~ operator is a translationally-invariant + linear operator on a rectangular real-space grid, constructed from the + `nufft_precision_operator` (a 2D array of correlation values on pixel offsets). + By taking an FFT of this preload, the operator can be applied to batches of + images via elementwise multiplication in Fourier space: + + apply_W(F) = IFFT( FFT(F_pad) * Khat ) + + where `F_pad` is a (2y, 2x) padded version of `F` and `Khat = FFT(nufft_precision_operator)`. + + The curvature matrix for a pixelization (mapper) is then assembled from sparse + mapping triplets without forming dense mapping matrices: + + C = A^T W A + + where A is the sparse mapping from source pixels to image pixels. + + Caching / validity + ------------------ + Instances are safe to cache and reuse as long as all of the following remain fixed: + + - `nufft_precision_operator` (hence `Khat`) + - the definition of the rectangular FFT grid (y_shape, x_shape) + - dtype / precision (float32 vs float64) + - `batch_size` + + Parameters stored + ----------------- + dirty_image + Convenience field for associated dirty image data (not used directly in + curvature assembly in this method). Stored as a NumPy array to match + upstream interfaces. + y_shape, x_shape + Shape of the *rectangular* real-space grid (not the masked slim grid). + M + Number of rectangular pixels, M = y_shape * x_shape. + batch_size + Number of source-pixel columns assembled and operated on per block. + Larger batch sizes improve throughput on GPU but increase memory usage. 
+ w_dtype + Floating-point dtype for weights and accumulations (e.g. float64). + Khat + FFT of the curvature preload, shape (2y_shape, 2x_shape), complex. + This is the frequency-domain representation of the W~ operator kernel. + """ + @classmethod + def from_nufft_precision_operator( + cls, + nufft_precision_operator: np.ndarray, + dirty_image: np.ndarray, + *, + batch_size: int = 128, + ): + """ + Construct an `InterferometerSparseOperator` from a curvature-preload array. + + This is the standard factory used in interferometer inversions. + + The curvature preload is assumed to be defined on a (2y, 2x) rectangular + grid of pixel offsets, where y and x correspond to the *unmasked extent* + of the real-space grid. The preload is FFT'd once to obtain `Khat`, which + is then reused for every subsequent curvature matrix build. + + Parameters + ---------- + nufft_precision_operator + Real-valued array of shape (2y, 2x) encoding the W~ operator in real + space as a function of pixel offsets. The shape must be even in both + axes so that y_shape = H2//2 and x_shape = W2//2 are integers. + dirty_image + The dirty image associated with the dataset (or any convenient + reference image). Not required for curvature computation itself, + but commonly stored alongside the state for debugging / plotting. + batch_size + Number of source-pixel columns processed per block when assembling + the curvature matrix. Higher values typically improve GPU efficiency + but increase intermediate memory usage. + + Returns + ------- + InterferometerSparseOperator + Immutable cached state object containing shapes and FFT kernel `Khat`. + + Raises + ------ + ValueError + If `nufft_precision_operator` does not have even shape in both dimensions. + """ + import jax.numpy as jnp + + H2, W2 = nufft_precision_operator.shape + if (H2 % 2) != 0 or (W2 % 2) != 0: + raise ValueError( + f"nufft_precision_operator must have even shape (2y,2x). Got {nufft_precision_operator.shape}." + ) + + y_shape = H2 // 2 + x_shape = W2 // 2 + M = y_shape * x_shape + + Khat = jnp.fft.fft2(nufft_precision_operator) + + return InterferometerSparseOperator( + dirty_image=dirty_image, + y_shape=y_shape, + x_shape=x_shape, + M=M, + batch_size=int(batch_size), + w_dtype=nufft_precision_operator.dtype, + Khat=Khat, + ) -def w_tilde_fft_state_from( - curvature_preload: np.ndarray, - *, - batch_size: int = 128, -) -> WTildeFFTState: - import jax.numpy as jnp + def curvature_matrix_via_sparse_operator_from( + self, + pix_indexes_for_sub_slim_index: np.ndarray, + pix_weights_for_sub_slim_index: np.ndarray, + pix_pixels: int, + fft_index_for_masked_pixel: np.ndarray, + ): + """ + Assemble the curvature matrix C = Aᵀ W A using sparse triplets and the FFT W~ operator. + + This method computes the mapper (pixelization) curvature matrix without + forming a dense mapping matrix. Instead, it uses fixed-length mapping + arrays (pixel indexes + weights per masked pixel) which define a sparse + mapping operator A in COO-like form. + + Algorithm outline + ----------------- + Let S be the number of source pixels and M be the number of rectangular + real-space pixels. + + 1) Build a fixed-length COO stream from the mapping arrays: + rows_rect[k] : rectangular pixel index (0..M-1) + cols[k] : source pixel index (0..S-1) + vals[k] : mapping weight + Invalid mappings (cols < 0 or cols >= S) are masked out. + + 2) Process source-pixel columns in blocks of width `batch_size`: + - Scatter the block’s source columns into a dense (M, batch_size) array F. 
+ - Apply the W~ operator by FFT: + G = apply_W(F) + - Project back with Aᵀ via segmented reductions: + C[:, start:start+B] = Aᵀ G + + 3) Symmetrize the result: + C <- 0.5 * (C + Cᵀ) + + Parameters + ---------- + pix_indexes_for_sub_slim_index + Integer array of shape (M_masked, Pmax). + For each masked (slim) image pixel, stores the source-pixel indices + involved in the interpolation / mapping stencil. Invalid entries + should be set to -1. + pix_weights_for_sub_slim_index + Floating array of shape (M_masked, Pmax). + Weights corresponding to `pix_indexes_for_sub_slim_index`. + These should already include any oversampling normalisation (e.g. + sub-pixel fractions) required by the mapper. + pix_pixels + Number of source pixels, S. + fft_index_for_masked_pixel + Integer array of shape (M_masked,). + Maps each masked (slim) image pixel index to its corresponding + rectangular-grid flat index (0..M-1). This embeds the masked pixel + ordering into the FFT-friendly rectangular grid. + + Returns + ------- + jax.Array + Curvature matrix of shape (S, S), symmetric. + + Notes + ----- + - The inner computation is written in JAX and is intended to be jitted. + For best performance, keep `batch_size` fixed (static) across calls. + - Choosing `batch_size` as a divisor of S avoids a smaller tail block, + but correctness does not require that if the implementation masks the tail. + - This method uses FFTs on padded (2y, 2x) arrays; memory use scales with + batch_size and grid size. + """ + + import jax.numpy as jnp + from jax.ops import segment_sum + + # ------------------------- + # Pull static quantities from state + # ------------------------- + y_shape = self.y_shape + x_shape = self.x_shape + M = self.M + batch_size = self.batch_size + Khat = self.Khat + w_dtype = self.w_dtype + + # ------------------------- + # Basic shape checks (NumPy side, safe) + # ------------------------- + M_masked, Pmax = pix_indexes_for_sub_slim_index.shape + S = int(pix_pixels) + + # ------------------------- + # JAX core (unchanged COO logic) + # ------------------------- + def _curvature_rect_jax( + pix_idx: jnp.ndarray, # (M_masked, Pmax) + pix_wts: jnp.ndarray, # (M_masked, Pmax) + rect_map: jnp.ndarray, # (M_masked,) + ) -> jnp.ndarray: + rect_map = jnp.asarray(rect_map) + + nnz_full = M_masked * Pmax + + # Flatten mapping arrays into a fixed-length COO stream + rows_mask = jnp.repeat( + jnp.arange(M_masked, dtype=jnp.int32), Pmax + ) # (nnz_full,) + cols = pix_idx.reshape((nnz_full,)).astype(jnp.int32) + vals = pix_wts.reshape((nnz_full,)).astype(w_dtype) + + # Validity mask + valid = (cols >= 0) & (cols < S) + + # Embed masked rows into rectangular rows + rows_rect = rect_map[rows_mask].astype(jnp.int32) + + # Make cols / vals safe + cols_safe = jnp.where(valid, cols, 0) + vals_safe = jnp.where(valid, vals, 0.0) + + def apply_operator_fft_batch(Fbatch_flat: jnp.ndarray) -> jnp.ndarray: + B = Fbatch_flat.shape[1] + F_img = Fbatch_flat.T.reshape((B, y_shape, x_shape)) + F_pad = jnp.pad( + F_img, ((0, 0), (0, y_shape), (0, x_shape)) + ) # (B,2y,2x) + Fhat = jnp.fft.fft2(F_pad) + Ghat = Fhat * Khat[None, :, :] + G_pad = jnp.fft.ifft2(Ghat) + G = jnp.real(G_pad[:, :y_shape, :x_shape]) + return G.reshape((B, M)).T # (M,B) + + def compute_block(start_col: int) -> jnp.ndarray: + in_block = (cols_safe >= start_col) & ( + cols_safe < start_col + batch_size + ) + in_use = valid & in_block - H2, W2 = curvature_preload.shape - if (H2 % 2) != 0 or (W2 % 2) != 0: - raise ValueError( - f"curvature_preload must have even shape 
(2y,2x). Got {curvature_preload.shape}." - ) + bc = jnp.where(in_use, cols_safe - start_col, 0).astype(jnp.int32) + v = jnp.where(in_use, vals_safe, 0.0) - y_shape = H2 // 2 - x_shape = W2 // 2 - M = y_shape * x_shape + Fbatch = jnp.zeros((M, batch_size), dtype=w_dtype) + Fbatch = Fbatch.at[rows_rect, bc].add(v) - Khat = jnp.fft.fft2(curvature_preload) + Gbatch = apply_operator_fft_batch(Fbatch) + G_at_rows = Gbatch[rows_rect, :] - return WTildeFFTState( - y_shape=y_shape, - x_shape=x_shape, - M=M, - batch_size=int(batch_size), - w_dtype=curvature_preload.dtype, - Khat=Khat, - ) + contrib = vals_safe[:, None] * G_at_rows + return segment_sum(contrib, cols_safe, num_segments=S) + # Assemble curvature + C = jnp.zeros((S, S), dtype=w_dtype) + for start in range(0, S, batch_size): + Cblock = compute_block(start) + width = min(batch_size, S - start) + C = C.at[:, start : start + width].set(Cblock[:, :width]) -def curvature_matrix_via_w_tilde_interferometer_from( - *, - fft_state: WTildeFFTState, - pix_indexes_for_sub_slim_index: np.ndarray, - pix_weights_for_sub_slim_index: np.ndarray, - pix_pixels: int, - rect_index_for_mask_index: np.ndarray, -): - """ - Compute curvature matrix for an interferometer inversion using a precomputed FFT state. + return 0.5 * (C + C.T) - IMPORTANT - --------- - - COO construction is unchanged from the known-working implementation - - Only FFT- and geometry-related quantities are taken from `fft_state` - """ - import jax - import jax.numpy as jnp - from jax.ops import segment_sum - - # ------------------------- - # Pull static quantities from state - # ------------------------- - y_shape = fft_state.y_shape - x_shape = fft_state.x_shape - M = fft_state.M - batch_size = fft_state.batch_size - Khat = fft_state.Khat - w_dtype = fft_state.w_dtype - - # ------------------------- - # Basic shape checks (NumPy side, safe) - # ------------------------- - M_masked, Pmax = pix_indexes_for_sub_slim_index.shape - S = int(pix_pixels) - - # ------------------------- - # JAX core (unchanged COO logic) - # ------------------------- - def _curvature_rect_jax( - pix_idx: jnp.ndarray, # (M_masked, Pmax) - pix_wts: jnp.ndarray, # (M_masked, Pmax) - rect_map: jnp.ndarray, # (M_masked,) - ) -> jnp.ndarray: - - rect_map = jnp.asarray(rect_map) - - nnz_full = M_masked * Pmax - - # Flatten mapping arrays into a fixed-length COO stream - rows_mask = jnp.repeat( - jnp.arange(M_masked, dtype=jnp.int32), Pmax - ) # (nnz_full,) - cols = pix_idx.reshape((nnz_full,)).astype(jnp.int32) - vals = pix_wts.reshape((nnz_full,)).astype(w_dtype) - - # Validity mask - valid = (cols >= 0) & (cols < S) - - # Embed masked rows into rectangular rows - rows_rect = rect_map[rows_mask].astype(jnp.int32) - - # Make cols / vals safe - cols_safe = jnp.where(valid, cols, 0) - vals_safe = jnp.where(valid, vals, 0.0) - - def apply_W_fft_batch(Fbatch_flat: jnp.ndarray) -> jnp.ndarray: - B = Fbatch_flat.shape[1] - F_img = Fbatch_flat.T.reshape((B, y_shape, x_shape)) - F_pad = jnp.pad(F_img, ((0, 0), (0, y_shape), (0, x_shape))) # (B,2y,2x) - Fhat = jnp.fft.fft2(F_pad) - Ghat = Fhat * Khat[None, :, :] - G_pad = jnp.fft.ifft2(Ghat) - G = jnp.real(G_pad[:, :y_shape, :x_shape]) - return G.reshape((B, M)).T # (M,B) - - def compute_block(start_col: int) -> jnp.ndarray: - in_block = (cols_safe >= start_col) & (cols_safe < start_col + batch_size) - in_use = valid & in_block - - bc = jnp.where(in_use, cols_safe - start_col, 0).astype(jnp.int32) - v = jnp.where(in_use, vals_safe, 0.0) - - Fbatch = jnp.zeros((M, batch_size), 
dtype=w_dtype) - Fbatch = Fbatch.at[rows_rect, bc].add(v) - - Gbatch = apply_W_fft_batch(Fbatch) - G_at_rows = Gbatch[rows_rect, :] - - contrib = vals_safe[:, None] * G_at_rows - return segment_sum(contrib, cols_safe, num_segments=S) - - # Assemble curvature - C = jnp.zeros((S, S), dtype=w_dtype) - for start in range(0, S, batch_size): - Cblock = compute_block(start) - width = min(batch_size, S - start) - C = C.at[:, start : start + width].set(Cblock[:, :width]) - - return 0.5 * (C + C.T) - - return _curvature_rect_jax( - pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index, - rect_index_for_mask_index, - ) + return _curvature_rect_jax( + pix_indexes_for_sub_slim_index, + pix_weights_for_sub_slim_index, + fft_index_for_masked_pixel, + ) diff --git a/autoarray/inversion/inversion/interferometer/w_tilde.py b/autoarray/inversion/inversion/interferometer/sparse.py similarity index 87% rename from autoarray/inversion/inversion/interferometer/w_tilde.py rename to autoarray/inversion/inversion/interferometer/sparse.py index 7e1eb4ff8..20b3746ec 100644 --- a/autoarray/inversion/inversion/interferometer/w_tilde.py +++ b/autoarray/inversion/inversion/interferometer/sparse.py @@ -6,7 +6,6 @@ from autoarray.inversion.inversion.interferometer.abstract import ( AbstractInversionInterferometer, ) -from autoarray.dataset.interferometer.w_tilde import WTildeInterferometer from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper @@ -15,14 +14,10 @@ from autoarray.inversion.inversion.interferometer import inversion_interferometer_util -from autoarray import exc - - -class InversionInterferometerWTilde(AbstractInversionInterferometer): +class InversionInterferometerSparse(AbstractInversionInterferometer): def __init__( self, dataset: Union[Interferometer, DatasetInterface], - w_tilde: WTildeInterferometer, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), xp=np, @@ -46,15 +41,10 @@ def __init__( transformer The transformer which performs a non-uniform fast Fourier transform operations on the mapping matrix with the interferometer data's transformer. - w_tilde - An object containing matrices that construct the linear equations via the w-tilde formalism which bypasses - the mapping matrix. linear_obj_list The linear objects used to reconstruct the data's observed values. If multiple linear objects are passed the simultaneous linear equations are combined and solved simultaneously. """ - self.w_tilde = w_tilde - super().__init__( dataset=dataset, linear_obj_list=linear_obj_list, @@ -76,9 +66,11 @@ def data_vector(self) -> np.ndarray: If there are multiple linear objects the `data_vectors` are concatenated ensuring their values are solved for simultaneously. - The calculation is described in more detail in `inversion_util.w_tilde_data_interferometer_from`. + The calculation is described in more detail in `inversion_util.weighted_data_interferometer_from`. """ - return self._xp.dot(self.mapping_matrix.T, self.w_tilde.dirty_image) + return self._xp.dot( + self.mapping_matrix.T, self.dataset.sparse_operator.dirty_image + ) @property def curvature_matrix(self) -> np.ndarray: @@ -104,17 +96,16 @@ def curvature_matrix_diag(self) -> np.ndarray: The linear algebra is described in the paper https://arxiv.org/pdf/astro-ph/0302587.pdf, where the curvature matrix given by equation (4) and the letter F. 
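# --- Illustrative sketch (standalone Python, not part of the diff) ----------
# The interferometer operator applied above is translation invariant: the
# coupling of two image pixels depends only on their (dy, dx) offset, so the
# whole [pixels, pixels] matrix is encoded by a (2*y, 2*x) offset kernel and
# its FFT `Khat`. Applying the operator to an image then reduces to zero
# padding, a pointwise product in Fourier space and a crop, as in
# `apply_operator_fft_batch`. The check below builds the dense matrix by
# offset indexing (as in `nufft_weighted_noise_via_sparse_operator_from`)
# from a toy kernel made even in the offset, as the physical kernel is
# expected to be since the operator it encodes is symmetric.

import numpy as np

rng = np.random.default_rng(4)
y_shape, x_shape = 4, 5

kernel = rng.normal(size=(2 * y_shape, 2 * x_shape))
# Symmetrise so that kernel[dy, dx] == kernel[-dy, -dx] under periodic indexing.
kernel = 0.5 * (kernel + np.roll(kernel[::-1, ::-1], shift=(1, 1), axis=(0, 1)))

# Dense operator on the rectangular grid, indexed by pixel offsets.
coords = [(y, x) for y in range(y_shape) for x in range(x_shape)]
W = np.zeros((len(coords), len(coords)))
for i, (iy, ix) in enumerate(coords):
    for j, (jy, jx) in enumerate(coords):
        W[i, j] = kernel[jy - iy, jx - ix]  # negative indexes wrap, as in the patch

# FFT application of the same operator to a random image.
image = rng.normal(size=(y_shape, x_shape))
Khat = np.fft.fft2(kernel)
padded = np.zeros((2 * y_shape, 2 * x_shape))
padded[:y_shape, :x_shape] = image
applied = np.real(np.fft.ifft2(np.fft.fft2(padded) * Khat))[:y_shape, :x_shape]

assert np.allclose(applied.ravel(), W @ image.ravel())
# ----------------------------------------------------------------------------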
- This function computes the diagonal terms of F using the w_tilde formalism. + This function computes the diagonal terms of F using the sparse linear algebra formalism. """ mapper = self.cls_list_from(cls=AbstractMapper)[0] - return inversion_interferometer_util.curvature_matrix_via_w_tilde_interferometer_from( - fft_state=self.w_tilde.fft_state, + return self.dataset.sparse_operator.curvature_matrix_via_sparse_operator_from( pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, pix_pixels=self.linear_obj_list[0].params, - rect_index_for_mask_index=self.w_tilde.rect_index_for_mask_index, + fft_index_for_masked_pixel=self.mask.fft_index_for_masked_pixel, ) @property diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index 2190bf81c..854c331bd 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -8,20 +8,20 @@ from autoarray.util.fnnls import fnnls_cholesky -def curvature_matrix_via_w_tilde_from( - w_tilde: np.ndarray, mapping_matrix: np.ndarray, xp=np +def curvature_matrix_diag_via_psf_weighted_noise_from( + psf_weighted_noise: np.ndarray, mapping_matrix: np.ndarray, xp=np ) -> np.ndarray: """ - Returns the curvature matrix `F` (see Warren & Dye 2003) from `w_tilde`. + Returns the curvature matrix `F` (see Warren & Dye 2003) from the `psf_weighted_noise`. - The dimensions of `w_tilde` are [image_pixels, image_pixels], meaning that for datasets with many image pixels - this matrix can take up 10's of GB of memory. The calculation of the `curvature_matrix` via this function will - therefore be very slow, and the method `curvature_matrix_via_w_tilde_curvature_preload_imaging_from` should be used + The dimensions of `psf_weighted_noise` are [image_pixels, image_pixels], meaning that for datasets with many image + pixels this matrix can take up 10's of GB of memory. The calculation of the `curvature_matrix` via this function + will therefore be very slow, and the method `curvature_matrix_diag_via_sparse_operator_from` should be used instead. Parameters ---------- - w_tilde + psf_weighted_noise A matrix of dimensions [image_pixels, image_pixels] that encodes the convolution or NUFFT of every image pixel pair on the noise map. mapping_matrix @@ -32,7 +32,7 @@ def curvature_matrix_via_w_tilde_from( ndarray The curvature matrix `F` (see Warren & Dye 2003). """ - return xp.dot(mapping_matrix.T, xp.dot(w_tilde, mapping_matrix)) + return xp.dot(mapping_matrix.T, xp.dot(psf_weighted_noise, mapping_matrix)) def curvature_matrix_with_added_to_diag_from( @@ -126,12 +126,14 @@ def mapped_reconstructed_data_via_mapping_matrix_from( return xp.dot(mapping_matrix, reconstruction) -def mapped_reconstructed_data_via_w_tilde_from( - w_tilde: np.ndarray, mapping_matrix: np.ndarray, reconstruction: np.ndarray +def mapped_reconstructed_data_via_psf_weighted_noise_from( + psf_weighted_noise: np.ndarray, + mapping_matrix: np.ndarray, + reconstruction: np.ndarray, ) -> np.ndarray: """ Returns the reconstructed data vector from the unblurred mapping matrix `M`, - the reconstruction vector `s`, and the PSF convolution operator `w_tilde`. + the reconstruction vector `s`, and the PSF convolution operator `psf_weighted_noise`. 
Equivalent to: reconstructed = (W @ M) @ s @@ -139,7 +141,7 @@ def mapped_reconstructed_data_via_w_tilde_from( Parameters ---------- - w_tilde + psf_weighted_noise Array of shape [image_pixels, image_pixels], the PSF convolution operator. mapping_matrix Array of shape [image_pixels, source_pixels], unblurred mapping matrix. @@ -151,7 +153,7 @@ def mapped_reconstructed_data_via_w_tilde_from( ndarray The reconstructed data vector of shape [image_pixels]. """ - return w_tilde @ (mapping_matrix @ reconstruction) + return psf_weighted_noise @ (mapping_matrix @ reconstruction) def reconstruction_positive_negative_from( diff --git a/autoarray/inversion/inversion/settings.py b/autoarray/inversion/inversion/settings.py index 4798e4e04..73494226d 100644 --- a/autoarray/inversion/inversion/settings.py +++ b/autoarray/inversion/inversion/settings.py @@ -14,8 +14,6 @@ def __init__( positive_only_uses_p_initial: Optional[bool] = None, use_border_relocator: Optional[bool] = None, no_regularization_add_to_curvature_diag_value: float = None, - use_w_tilde_numpy: bool = False, - use_source_loop: bool = False, tolerance: float = 1e-8, maxiter: int = 250, ): @@ -36,11 +34,6 @@ def __init__( no_regularization_add_to_curvature_diag_value If a linear func object does not have a corresponding regularization, this value is added to its diagonal entries of the curvature regularization matrix to ensure the matrix is positive-definite. - use_w_tilde_numpy - If True, the curvature_matrix is computed via numpy matrix multiplication (as opposed to numba functions - which exploit sparsity to do the calculation normally in a more efficient way). - use_source_loop - Shhhh its a secret. tolerance For an interferometer inversion using the linear operators method, sets the tolerance of the solver (this input does nothing for dataset data and other interferometer methods). 
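# --- Illustrative sketch (standalone Python, not part of the diff) ----------
# Shapes and identities behind the two dense helpers renamed above, on toy
# arrays: the curvature matrix is F = M^T W M (symmetric when W is), and
# grouping the reconstruction as W @ (M @ s) rather than (W @ M) @ s gives the
# same vector while never forming the [image_pixels, source_pixels]
# intermediate product.

import numpy as np

rng = np.random.default_rng(5)
n_img, n_src = 50, 8

W = rng.normal(size=(n_img, n_img))
W = 0.5 * (W + W.T)             # symmetric operator
M = rng.random((n_img, n_src))  # mapping matrix
s = rng.normal(size=n_src)      # reconstruction vector

F = M.T @ W @ M
assert F.shape == (n_src, n_src)
assert np.allclose(F, F.T)

assert np.allclose(W @ (M @ s), (W @ M) @ s)
# ----------------------------------------------------------------------------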
@@ -57,8 +50,6 @@ def __init__( self.tolerance = tolerance self.maxiter = maxiter - self.use_w_tilde_numpy = use_w_tilde_numpy - self.use_source_loop = use_source_loop @property def use_positive_only_solver(self): diff --git a/autoarray/inversion/mock/mock_inversion_imaging.py b/autoarray/inversion/mock/mock_inversion_imaging.py index 4418ba398..283d3152f 100644 --- a/autoarray/inversion/mock/mock_inversion_imaging.py +++ b/autoarray/inversion/mock/mock_inversion_imaging.py @@ -3,7 +3,6 @@ from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.inversion.imaging.mapping import InversionImagingMapping -from autoarray.inversion.inversion.imaging.w_tilde import InversionImagingWTilde from autoarray.inversion.inversion.settings import SettingsInversion @@ -70,43 +69,3 @@ def data_linear_func_matrix_dict(self) -> Dict: return super().data_linear_func_matrix_dict return self._data_linear_func_matrix_dict - - -class MockWTildeImaging: - pass - - -class MockInversionImagingWTilde(InversionImagingWTilde): - def __init__( - self, - data=None, - noise_map=None, - psf=None, - w_tilde=None, - linear_obj_list=None, - curvature_matrix_mapper_diag=None, - settings: SettingsInversion = None, - ): - dataset = DatasetInterface( - data=data, - noise_map=noise_map, - psf=psf, - ) - - settings = settings or SettingsInversion() - - super().__init__( - dataset=dataset, - w_tilde=w_tilde or MockWTildeImaging(), - linear_obj_list=linear_obj_list, - settings=settings, - ) - - self.__curvature_matrix_mapper_diag = curvature_matrix_mapper_diag - - @property - def curvature_matrix_mapper_diag(self): - if self.__curvature_matrix_mapper_diag is None: - return super()._curvature_matrix_mapper_diag - - return self.__curvature_matrix_mapper_diag diff --git a/autoarray/inversion/pixelization/border_relocator.py b/autoarray/inversion/pixelization/border_relocator.py index 05b825fc0..b5d9fe873 100644 --- a/autoarray/inversion/pixelization/border_relocator.py +++ b/autoarray/inversion/pixelization/border_relocator.py @@ -269,7 +269,10 @@ def relocated_grid_from(grid, border_grid, xp=np): class BorderRelocator: def __init__( - self, mask: Mask2D, sub_size: Union[int, Array2D], use_w_tilde: bool = False + self, + mask: Mask2D, + sub_size: Union[int, Array2D], + use_sparse_operator: bool = False, ): """ Relocates source plane coordinates that trace outside the mask’s border in the source-plane back onto the @@ -327,7 +330,7 @@ def __init__( self.sub_border_grid = sub_grid[self.sub_border_slim] - self.use_w_tilde = use_w_tilde + self.use_sparse_operator = use_sparse_operator def relocated_grid_from(self, grid: Grid2D, xp=np) -> Grid2D: """ @@ -356,7 +359,7 @@ def relocated_grid_from(self, grid: Grid2D, xp=np) -> Grid2D: if len(self.sub_border_grid) == 0: return grid - if self.use_w_tilde is False or xp.__name__.startswith("jax"): + if self.use_sparse_operator is False or xp.__name__.startswith("jax"): values = relocated_grid_from( grid=grid.array, border_grid=grid.array[self.border_slim], xp=xp @@ -370,7 +373,7 @@ def relocated_grid_from(self, grid: Grid2D, xp=np) -> Grid2D: else: - from autoarray.inversion.inversion.imaging import ( + from autoarray.inversion.inversion.imaging_numba import ( inversion_imaging_numba_util, ) @@ -408,7 +411,7 @@ def relocated_mesh_grid_from( if len(self.sub_border_grid) == 0: return mesh_grid - if self.use_w_tilde is False or xp.__name__.startswith("jax"): + if self.use_sparse_operator is False or xp.__name__.startswith("jax"): relocated_grid = 
relocated_grid_from( grid=mesh_grid.array, diff --git a/autoarray/inversion/pixelization/mappers/abstract.py b/autoarray/inversion/pixelization/mappers/abstract.py index e13ec3470..8c8ed067e 100644 --- a/autoarray/inversion/pixelization/mappers/abstract.py +++ b/autoarray/inversion/pixelization/mappers/abstract.py @@ -268,6 +268,120 @@ def mapping_matrix(self) -> np.ndarray: xp=self._xp, ) + @cached_property + def sparse_triplets_data(self): + """ + Sparse triplet representation of the (unblurred) mapping operator on the *slim data grid*. + + This property returns the mapping between image-plane subpixels and source pixels in + sparse COO triplet form: + + (rows, cols, vals) + + where each triplet encodes one non-zero entry of the mapping matrix: + + A[row, col] += val + + The returned indices correspond to: + + - `rows`: slim masked image pixel indices (one per subpixel contribution) + - `cols`: source pixel indices in the pixelization + - `vals`: interpolation weights, including oversampling normalization + + This representation is used for efficient computation of quantities such as the + data vector: + + D = Aᵀ d + + without ever forming the dense mapping matrix explicitly. + + Notes + ----- + - This version keeps `rows` in *slim masked pixel coordinates*, which is the natural + indexing convention for data-vector calculations using `psf_operated_data`. + - The triplets contain only non-zero contributions, making them significantly faster + than dense matrix operations. + + Returns + ------- + rows : ndarray of shape (nnz,) + Slim masked image pixel index for each non-zero mapping entry. + + cols : ndarray of shape (nnz,) + Source pixel index for each mapping entry. + + vals : ndarray of shape (nnz,) + Mapping weight for each entry, including subpixel normalization. + """ + + rows, cols, vals = mapper_util.sparse_triplets_from( + pix_indexes_for_sub=self.pix_indexes_for_sub_slim_index, + pix_weights_for_sub=self.pix_weights_for_sub_slim_index, + slim_index_for_sub=self.slim_index_for_sub_slim_index, + fft_index_for_masked_pixel=self.mapper_grids.mask.fft_index_for_masked_pixel, + sub_fraction_slim=self.over_sampler.sub_fraction.array, + xp=self._xp, + ) + + return rows, cols, vals + + @cached_property + def sparse_triplets_curvature(self): + """ + Sparse triplet representation of the mapping operator on the *rectangular FFT grid*. + + This property returns the same sparse mapping triplets as `sparse_triplets_data`, + but with the row indices converted from slim masked pixel coordinates into the + rectangular FFT indexing system used in the w-tilde curvature formalism. + + This is required because curvature matrix calculations involve applying the + PSF precision operator: + + W = Hᵀ N⁻¹ H + + via FFT-based convolution on a rectangular grid. Therefore the mapping operator + must be expressed in terms of rectangular pixel indices. + + Specifically: + + - `rows` are converted from slim masked pixel indices into FFT-grid indices via: + + rows_rect = fft_index_for_masked_pixel[rows_slim] + + The resulting triplets are used in curvature matrix assembly: + + F = Aᵀ W A + + Notes + ----- + - Use `sparse_triplets_data` for data-vector calculations. + - Use `sparse_triplets_curvature` for curvature matrix calculations with FFT-based + PSF operators. + + Returns + ------- + rows : ndarray of shape (nnz,) + Rectangular FFT-grid pixel index for each mapping entry. + + cols : ndarray of shape (nnz,) + Source pixel index for each mapping entry. 
+ + vals : ndarray of shape (nnz,) + Mapping weight for each entry. + """ + + rows, cols, vals = mapper_util.sparse_triplets_from( + pix_indexes_for_sub=self.pix_indexes_for_sub_slim_index, + pix_weights_for_sub=self.pix_weights_for_sub_slim_index, + slim_index_for_sub=self.slim_index_for_sub_slim_index, + fft_index_for_masked_pixel=self.mapper_grids.mask.fft_index_for_masked_pixel, + sub_fraction_slim=self.over_sampler.sub_fraction.array, + xp=self._xp, + return_rows_slim=False, + ) + + return rows, cols, vals + def pixel_signals_from(self, signal_scale: float, xp=np) -> np.ndarray: """ Returns the signal in each pixelization pixel, where this signal is an estimate of the expected signal diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index 06857dc69..4474335c6 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -442,6 +442,104 @@ def adaptive_pixel_signals_from( return pixel_signals**signal_scale +def sparse_triplets_from( + pix_indexes_for_sub, # (M_sub, P) + pix_weights_for_sub, # (M_sub, P) + slim_index_for_sub, # (M_sub,) + fft_index_for_masked_pixel, # (N_unmasked,) + sub_fraction_slim, # (N_unmasked,) + *, + return_rows_slim: bool = True, + xp=np, +): + """ + Build sparse source→image mapping triplets (rows, cols, vals) + for a fixed-size interpolation stencil. + + This supports both: + - NumPy (xp=np) + - JAX (xp=jax.numpy) + + Parameters + ---------- + pix_indexes_for_sub + Source pixel indices for each subpixel (M_sub, P) + pix_weights_for_sub + Interpolation weights for each subpixel (M_sub, P) + slim_index_for_sub + Mapping subpixel -> slim image pixel index (M_sub,) + fft_index_for_masked_pixel + Mapping slim pixel -> rectangular FFT-grid pixel index (N_unmasked,) + sub_fraction_slim + Oversampling normalization per slim pixel (N_unmasked,) + xp + Backend module (np or jnp) + + Returns + ------- + rows : (nnz,) int32 + Rectangular FFT grid row index per mapping entry + cols : (nnz,) int32 + Source pixel index per mapping entry + vals : (nnz,) float64 + Mapping weight per entry including sub_fraction normalization + """ + # ---------------------------- + # NumPy path (HOST) + # ---------------------------- + if xp is np: + pix_indexes_for_sub = np.asarray(pix_indexes_for_sub, dtype=np.int32) + pix_weights_for_sub = np.asarray(pix_weights_for_sub, dtype=np.float64) + slim_index_for_sub = np.asarray(slim_index_for_sub, dtype=np.int32) + fft_index_for_masked_pixel = np.asarray( + fft_index_for_masked_pixel, dtype=np.int32 + ) + sub_fraction_slim = np.asarray(sub_fraction_slim, dtype=np.float64) + + M_sub, P = pix_indexes_for_sub.shape + + sub_ids = np.repeat(np.arange(M_sub, dtype=np.int32), P) # (M_sub*P,) + + cols = pix_indexes_for_sub.reshape(-1) # int32 + vals = pix_weights_for_sub.reshape(-1) # float64 + + slim_rows = slim_index_for_sub[sub_ids] # int32 + vals = vals * sub_fraction_slim[slim_rows] # float64 + + if return_rows_slim: + return slim_rows, cols, vals + + rows = fft_index_for_masked_pixel[slim_rows] + return rows, cols, vals + + # ---------------------------- + # JAX path (DEVICE) + # ---------------------------- + # We intentionally avoid np.asarray anywhere here. + # Assume xp is jax.numpy (or a compatible array module). 
+ pix_indexes_for_sub = xp.asarray(pix_indexes_for_sub, dtype=xp.int32) + pix_weights_for_sub = xp.asarray(pix_weights_for_sub, dtype=xp.float64) + slim_index_for_sub = xp.asarray(slim_index_for_sub, dtype=xp.int32) + fft_index_for_masked_pixel = xp.asarray(fft_index_for_masked_pixel, dtype=xp.int32) + sub_fraction_slim = xp.asarray(sub_fraction_slim, dtype=xp.float64) + + M_sub, P = pix_indexes_for_sub.shape + + sub_ids = xp.repeat(xp.arange(M_sub, dtype=xp.int32), P) + + cols = pix_indexes_for_sub.reshape(-1) + vals = pix_weights_for_sub.reshape(-1) + + slim_rows = slim_index_for_sub[sub_ids] + vals = vals * sub_fraction_slim[slim_rows] + + if return_rows_slim: + return slim_rows, cols, vals + + rows = fft_index_for_masked_pixel[slim_rows] + return rows, cols, vals + + def mapping_matrix_from( pix_indexes_for_sub_slim_index: np.ndarray, pix_size_for_sub_slim_index: np.ndarray, diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 9be77f2e6..70ecd7cb1 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Tuple, Union +from autoconf import cached_property + from autoarray.structures.abstract_structure import Structure if TYPE_CHECKING: @@ -619,6 +621,48 @@ def from_fits( def shape_native(self) -> Tuple[int, ...]: return self.shape + @cached_property + def fft_index_for_masked_pixel(self) -> np.ndarray: + """ + Return a mapping from masked-pixel (slim) indices to flat indices + on the rectangular FFT grid. + + This array is used to translate between: + + - "masked pixel space" (a compact 1D indexing over unmasked pixels) + - the 2D rectangular grid on which FFT-based convolutions are performed + + The FFT grid is assumed to be rectangular and already suitable for FFTs + (e.g. padded and centered appropriately). Masked pixels are present on + this grid but are ignored in computations via zero-weighting. + + Returns + ------- + np.ndarray + A 1D array of shape (N_unmasked,), where element `i` gives the flat + (row-major) index into the FFT grid corresponding to the `i`-th + unmasked pixel in slim ordering. + + Notes + ----- + - The slim ordering is defined as the order returned by `np.where(~mask)`. + - The flat FFT index is computed assuming row-major (C-style) ordering: + flat_index = y * width + x + - This method is intentionally backend-agnostic and can be used by both + imaging and interferometer curvature pipelines. + """ + # Boolean mask defined on the rectangular FFT grid + mask_fft = self + + # Coordinates of unmasked pixels in the FFT grid + ys, xs = np.where(~mask_fft) + + # Width of the FFT grid (number of columns) + width = mask_fft.shape[1] + + # Convert (y, x) coordinates to flat row-major indices + return (ys * width + xs).astype(np.int32) + def trimmed_array_from(self, padded_array, image_shape) -> Array2D: """ Map a padded 1D array of values to its original 2D array, trimming all edge values. 
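
Two small notes on the code above. First, the JAX branch of `sparse_triplets_from` requests `xp.float64`; under JAX this only takes effect when 64-bit mode is enabled (for example via `jax.config.update("jax_enable_x64", True)`), otherwise the arrays fall back to float32. Second, the new `fft_index_for_masked_pixel` property is plain row-major index arithmetic; the same computation on a toy mask:

    import numpy as np

    # A 3x4 rectangular FFT grid; False marks the unmasked pixels.
    mask = np.array(
        [
            [True, True, True, True],
            [True, False, False, True],
            [True, True, True, True],
        ]
    )

    ys, xs = np.where(~mask)   # slim ordering: row-major over unmasked pixels
    width = mask.shape[1]      # number of FFT-grid columns

    fft_index_for_masked_pixel = (ys * width + xs).astype(np.int32)
    print(fft_index_for_masked_pixel)  # [5 6]: slim pixel 0 -> flat index 5, slim pixel 1 -> flat index 6
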
diff --git a/autoarray/mock.py b/autoarray/mock.py index d6dbe40b0..86bf91aec 100644 --- a/autoarray/mock.py +++ b/autoarray/mock.py @@ -8,7 +8,6 @@ from autoarray.inversion.mock.mock_linear_obj_func_list import MockLinearObjFuncList from autoarray.inversion.mock.mock_inversion import MockInversion from autoarray.inversion.mock.mock_inversion_imaging import MockInversionImaging -from autoarray.inversion.mock.mock_inversion_imaging import MockInversionImagingWTilde from autoarray.inversion.mock.mock_inversion_interferometer import ( MockInversionInterferometer, ) diff --git a/autoarray/util/__init__.py b/autoarray/util/__init__.py index 5342717be..afeb1ccb4 100644 --- a/autoarray/util/__init__.py +++ b/autoarray/util/__init__.py @@ -21,7 +21,7 @@ from autoarray.inversion.inversion.imaging import ( inversion_imaging_util as inversion_imaging, ) -from autoarray.inversion.inversion.imaging import ( +from autoarray.inversion.inversion.imaging_numba import ( inversion_imaging_numba_util as inversion_imaging_numba, ) from autoarray.inversion.inversion.interferometer import ( diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index 8aa6ea3c2..663e4ea39 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -193,12 +193,6 @@ def test__apply_mask(imaging_7x7, mask_2d_7x7, psf_3x3): assert type(masked_imaging_7x7.psf) == aa.Kernel2D - masked_imaging_7x7 = masked_imaging_7x7.apply_w_tilde() - - assert masked_imaging_7x7.w_tilde.curvature_preload.shape == (35,) - assert masked_imaging_7x7.w_tilde.indexes.shape == (35,) - assert masked_imaging_7x7.w_tilde.lengths.shape == (9,) - def test__apply_noise_scaling(imaging_7x7, mask_2d_7x7): masked_imaging_7x7 = imaging_7x7.apply_noise_scaling( diff --git a/test_autoarray/dataset/interferometer/test_dataset.py b/test_autoarray/dataset/interferometer/test_dataset.py index 7f46534c1..9d07382d6 100644 --- a/test_autoarray/dataset/interferometer/test_dataset.py +++ b/test_autoarray/dataset/interferometer/test_dataset.py @@ -149,55 +149,3 @@ def test__different_interferometer_without_mock_objects__customize_constructor_i assert (dataset.data == 1.0 + 1.0j * np.ones((19,))).all() assert (dataset.noise_map == 2.0 + 2.0j * np.ones((19,))).all() assert (dataset.uv_wavelengths == 3.0 * np.ones((19, 2))).all() - - -def test__curvature_preload_metadata_from( - visibilities_7, - visibilities_noise_map_7, - uv_wavelengths_7x2, - mask_2d_7x7, -): - - dataset = aa.Interferometer( - data=visibilities_7, - noise_map=visibilities_noise_map_7, - uv_wavelengths=uv_wavelengths_7x2, - real_space_mask=mask_2d_7x7, - ) - - dataset = dataset.apply_w_tilde(use_jax=False) - - file = f"{test_data_path}/curvature_preload_metadata" - - dataset.w_tilde.save_curvature_preload( - file=file, - overwrite=True, - ) - - curvature_preload = aa.load_curvature_preload_if_compatible( - file=file, real_space_mask=dataset.real_space_mask - ) - - assert curvature_preload[0, 0] == pytest.approx(1.75, 1.0e-4) - - real_space_mask_changed = np.array( - [ - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, True], - [True, True, False, False, False, True, True], - [True, True, False, True, False, True, True], - [True, True, False, False, False, True, True], - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, True], - ] - ) - - real_space_mask_changed = aa.Mask2D( - mask=real_space_mask_changed, pixel_scales=(1.0, 1.0) - ) - - 
with pytest.raises(ValueError): - - curvature_preload = aa.load_curvature_preload_if_compatible( - file=file, real_space_mask=real_space_mask_changed - ) diff --git a/test_autoarray/inversion/inversion/imaging/test_imaging.py b/test_autoarray/inversion/inversion/imaging/test_imaging.py index d4464545d..6a9c89446 100644 --- a/test_autoarray/inversion/inversion/imaging/test_imaging.py +++ b/test_autoarray/inversion/inversion/imaging/test_imaging.py @@ -1,6 +1,11 @@ import autoarray as aa -from autoarray.dataset.imaging.w_tilde import WTildeImaging -from autoarray.inversion.inversion.imaging.w_tilde import InversionImagingWTilde + +from autoarray.inversion.inversion.imaging.inversion_imaging_util import ( + ImagingSparseOperator, +) +from autoarray.inversion.inversion.imaging.sparse import ( + InversionImagingSparse, +) from autoarray import exc diff --git a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py index 82ea0d6d8..3069a890c 100644 --- a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py +++ b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py @@ -3,8 +3,8 @@ import pytest -def test__w_tilde_imaging_from(): - noise_map_2d = np.array( +def test__psf_weighted_noise_imaging_from(): + noise_map = np.array( [ [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], @@ -17,13 +17,13 @@ def test__w_tilde_imaging_from(): native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]]) - w_tilde = aa.util.inversion_imaging_numba.w_tilde_curvature_imaging_from( - noise_map_native=noise_map_2d, + psf_weighted_noise = aa.util.inversion_imaging_numba.psf_precision_operator_from( + noise_map_native=noise_map, kernel_native=kernel, native_index_for_slim_index=native_index_for_slim_index, ) - assert w_tilde == pytest.approx( + assert psf_weighted_noise == pytest.approx( np.array( [ [2.5, 1.625, 0.5, 0.375], @@ -36,41 +36,56 @@ def test__w_tilde_imaging_from(): ) -def test__w_tilde_data_imaging_from(): - image_2d = np.array( - [ +def test__psf_weighted_data_from(): + + mask = aa.Mask2D( + mask=[ + [True, True, True, True], + [True, False, False, True], + [True, False, False, True], + [True, True, True, True], + ], + pixel_scales=(1.0, 1.0), + ) + + data = aa.Array2D( + values=[ [0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0], - ] + ], + mask=mask, ) - noise_map_2d = np.array( - [ + noise_map = aa.Array2D( + values=[ [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0], - ] + ], + mask=mask, ) kernel = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 2.0, 0.0]]) native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]]) - w_tilde_data = aa.util.inversion_imaging.w_tilde_data_imaging_from( - image_native=image_2d, - noise_map_native=noise_map_2d, + weight_map = data / (noise_map**2) + weight_map = aa.Array2D(values=weight_map, mask=mask) + + psf_weighted_data = aa.util.inversion_imaging.psf_weighted_data_from( + weight_map_native=weight_map.native.array, kernel_native=kernel, native_index_for_slim_index=native_index_for_slim_index, ) - assert (w_tilde_data == np.array([5.0, 5.0, 1.5, 1.5])).all() + assert (psf_weighted_data == np.array([5.0, 5.0, 1.5, 1.5])).all() -def test__w_tilde_curvature_preload_imaging_from(): - noise_map_2d = np.array( +def test__psf_precision_operator_sparse_from(): + noise_map = np.array( [ [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], @@ -84,26 +99,26 
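
The rewritten `test__psf_weighted_data_from` above builds a weight map d / sigma^2 and passes it to `psf_weighted_data_from`. As a standalone cross-check of the expected values, and assuming the operation reduces to cross-correlating the PSF kernel with the weight map at each unmasked pixel (scipy is used here only for the sketch):

    import numpy as np
    from scipy.ndimage import correlate

    data = np.array(
        [[0.0, 0.0, 0.0, 0.0],
         [0.0, 2.0, 1.0, 0.0],
         [0.0, 1.0, 2.0, 0.0],
         [0.0, 0.0, 0.0, 0.0]]
    )
    noise = np.array(
        [[0.0, 0.0, 0.0, 0.0],
         [0.0, 1.0, 1.0, 0.0],
         [0.0, 1.0, 2.0, 0.0],
         [0.0, 0.0, 0.0, 0.0]]
    )
    kernel = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 2.0, 0.0]])

    # Weight map d / sigma^2, with the masked (zero-noise) border set to zero.
    weight_map = np.divide(data, noise**2, out=np.zeros_like(data), where=noise > 0)

    # Cross-correlate (no kernel flip) and read off the four unmasked pixels in slim order.
    psf_weighted = correlate(weight_map, kernel, mode="constant", cval=0.0)
    print(psf_weighted[1:3, 1:3].ravel())  # [5.  5.  1.5 1.5], matching the assertion in the test
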
@@ def test__w_tilde_curvature_preload_imaging_from(): native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]]) ( - w_tilde_preload, - w_tilde_indexes, - w_tilde_lengths, - ) = aa.util.inversion_imaging_numba.w_tilde_curvature_preload_imaging_from( - noise_map_native=noise_map_2d, + psf_weighted_noise_preload, + psf_weighted_noise_indexes, + psf_weighted_noise_lengths, + ) = aa.util.inversion_imaging_numba.psf_precision_operator_sparse_from( + noise_map_native=noise_map, kernel_native=kernel, native_index_for_slim_index=native_index_for_slim_index, ) - assert w_tilde_preload == pytest.approx( + assert psf_weighted_noise_preload == pytest.approx( np.array( [1.25, 1.625, 0.5, 0.375, 0.65625, 0.125, 0.0625, 0.25, 0.375, 0.15625] ), 1.0e-4, ) - assert w_tilde_indexes == pytest.approx( + assert psf_weighted_noise_indexes == pytest.approx( np.array([0, 1, 2, 3, 1, 2, 3, 2, 3, 3]), 1.0e-4 ) - assert w_tilde_lengths == pytest.approx(np.array([4, 3, 2, 1]), 1.0e-4) + assert psf_weighted_noise_lengths == pytest.approx(np.array([4, 3, 2, 1]), 1.0e-4) def test__data_vector_via_blurred_mapping_matrix_from(): @@ -168,7 +183,7 @@ def test__data_vector_via_blurred_mapping_matrix_from(): assert (data_vector == np.array([2.0, 3.0, 1.0])).all() -def test__data_vector_via_w_tilde_data_two_methods_agree(): +def test__data_vector_via_weighted_data_two_methods_agree(): mask = aa.Mask2D.circular(shape_native=(51, 51), pixel_scales=0.1, radius=2.0) image = np.random.uniform(size=mask.shape_native) @@ -188,6 +203,7 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): # TODO : Use pytest.parameterize for sub_size in range(1, 3): + grid = aa.Grid2D.from_mask(mask=mask, over_sample_size=sub_size) mapper_grids = pixelization.mapper_grids_from( @@ -215,44 +231,40 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): ) ) - w_tilde_data = aa.util.inversion_imaging.w_tilde_data_imaging_from( - image_native=image.native.array, - noise_map_native=noise_map.native.array, - kernel_native=kernel.native.array, - native_index_for_slim_index=mask.derive_indexes.native_for_slim.astype( - "int" - ), + rows, cols, vals = aa.util.mapper.sparse_triplets_from( + pix_indexes_for_sub=mapper.pix_indexes_for_sub_slim_index, + pix_weights_for_sub=mapper.pix_weights_for_sub_slim_index, + slim_index_for_sub=mapper.slim_index_for_sub_slim_index, + fft_index_for_masked_pixel=mask.fft_index_for_masked_pixel, + sub_fraction_slim=mapper.over_sampler.sub_fraction.array, ) - ( - data_to_pix_unique, - data_weights, - pix_lengths, - ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( - data_pixels=w_tilde_data.shape[0], - pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index.astype( + weight_map = image.array / (noise_map.array**2) + weight_map = aa.Array2D(values=weight_map, mask=noise_map.mask) + + psf_weighted_data = aa.util.inversion_imaging.psf_weighted_data_from( + weight_map_native=weight_map.native.array, + kernel_native=kernel.native.array, + native_index_for_slim_index=mask.derive_indexes.native_for_slim.astype( "int" ), - pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, - pix_pixels=mapper.params, - sub_size=grid.over_sample_size.array, ) - data_vector_via_w_tilde = ( - aa.util.inversion_imaging_numba.data_vector_via_w_tilde_data_imaging_from( - w_tilde_data=w_tilde_data, - data_to_pix_unique=data_to_pix_unique.astype("int"), - data_weights=data_weights, - 
pix_lengths=pix_lengths.astype("int"), - pix_pixels=pixelization.pixels, + data_vector_via_psf_weighted_noise = ( + aa.util.inversion_imaging.data_vector_via_psf_weighted_data_from( + psf_weighted_data=psf_weighted_data, + rows=rows, + cols=cols, + vals=vals, + S=pixelization.pixels, ) ) - assert data_vector_via_w_tilde == pytest.approx(data_vector, 1.0e-4) + assert data_vector_via_psf_weighted_noise == pytest.approx(data_vector, 1.0e-4) -def test__curvature_matrix_via_w_tilde_two_methods_agree(): +def test__curvature_matrix_via_psf_weighted_noise_two_methods_agree(): + mask = aa.Mask2D.circular(shape_native=(51, 51), pixel_scales=0.1, radius=2.0) noise_map = np.random.uniform(size=mask.shape_native) @@ -264,9 +276,15 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): psf = kernel - pixelization = aa.mesh.RectangularUniform(shape=(20, 20)) + sparse_operator = aa.ImagingSparseOperator.from_noise_map_and_psf( + data=noise_map, + noise_map=noise_map, + psf=psf.native, + ) + + mesh = aa.mesh.RectangularAdaptDensity(shape=(20, 20)) - mapper_grids = pixelization.mapper_grids_from( + mapper_grids = mesh.mapper_grids_from( mask=mask, border_relocator=None, source_plane_data_grid=mask.derive_grid.unmasked, @@ -276,14 +294,26 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): mapping_matrix = mapper.mapping_matrix - w_tilde = aa.util.inversion_imaging_numba.w_tilde_curvature_imaging_from( - noise_map_native=noise_map.native.array, - kernel_native=kernel.native.array, - native_index_for_slim_index=mask.derive_indexes.native_for_slim.astype("int"), + rows, cols, vals = aa.util.mapper.sparse_triplets_from( + pix_indexes_for_sub=mapper.pix_indexes_for_sub_slim_index, + pix_weights_for_sub=mapper.pix_weights_for_sub_slim_index, + slim_index_for_sub=mapper.slim_index_for_sub_slim_index, + fft_index_for_masked_pixel=mask.fft_index_for_masked_pixel, + sub_fraction_slim=mapper.over_sampler.sub_fraction.array, + return_rows_slim=False, ) - curvature_matrix_via_w_tilde = aa.util.inversion.curvature_matrix_via_w_tilde_from( - w_tilde=w_tilde, mapping_matrix=mapping_matrix + curvature_matrix_via_sparse_operator = sparse_operator.curvature_matrix_diag_from( + rows, + cols, + vals, + S=mesh.shape[0] * mesh.shape[1], + ) + + curvature_matrix_via_sparse_operator = ( + aa.util.inversion_imaging.curvature_matrix_mirrored_from( + curvature_matrix=curvature_matrix_via_sparse_operator, + ) ) blurred_mapping_matrix = psf.convolved_mapping_matrix_from( @@ -294,84 +324,7 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): mapping_matrix=blurred_mapping_matrix, noise_map=noise_map, ) - assert curvature_matrix_via_w_tilde == pytest.approx(curvature_matrix, abs=1.0e-4) - -def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): - mask = aa.Mask2D.circular(shape_native=(51, 51), pixel_scales=0.1, radius=2.0) - - noise_map = np.random.uniform(size=mask.shape_native) - noise_map = aa.Array2D(values=noise_map, mask=mask) - - kernel = aa.Kernel2D.from_gaussian( - shape_native=(7, 7), pixel_scales=mask.pixel_scales, sigma=1.0, normalize=True + assert curvature_matrix_via_sparse_operator == pytest.approx( + curvature_matrix, rel=1.0e-3 ) - - psf = kernel - - pixelization = aa.mesh.RectangularUniform(shape=(20, 20)) - - for sub_size in range(1, 2, 3): - grid = aa.Grid2D.from_mask(mask=mask, over_sample_size=sub_size) - - mapper_grids = pixelization.mapper_grids_from( - mask=mask, - border_relocator=None, - source_plane_data_grid=grid, - ) - - mapper = aa.Mapper( - 
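
The two-methods-agree tests above (for both the data vector and the curvature matrix) lean on a simple identity: if the PSF precision operator is W = Hᵀ N⁻¹ H and the blurred mapping matrix is B = H M, then Mᵀ W M = Bᵀ N⁻¹ B, so the sparse-operator route and the dense blurred-mapping-matrix route must agree up to numerics. A quick numerical check of that identity with stand-in matrices (H, N and M here are random placeholders, not library objects):

    import numpy as np

    rng = np.random.default_rng(0)

    H = rng.normal(size=(6, 6))                            # stand-in PSF blurring operator
    N_inv = np.diag(1.0 / rng.uniform(1.0, 2.0, 6) ** 2)   # diagonal inverse noise covariance
    M = rng.normal(size=(6, 3))                            # stand-in mapping matrix

    W = H.T @ N_inv @ H                                    # precision operator ("psf_weighted_noise")
    B = H @ M                                              # blurred mapping matrix

    F_via_W = M.T @ W @ M
    F_via_B = B.T @ N_inv @ B

    print(np.allclose(F_via_W, F_via_B))                   # True: the two routes are algebraically identical
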
mapper_grids=mapper_grids, - regularization=None, - ) - - mapping_matrix = mapper.mapping_matrix - - ( - w_tilde_preload, - w_tilde_indexes, - w_tilde_lengths, - ) = aa.util.inversion_imaging_numba.w_tilde_curvature_preload_imaging_from( - noise_map_native=noise_map.native.array, - kernel_native=kernel.native.array, - native_index_for_slim_index=mask.derive_indexes.native_for_slim.astype( - "int" - ), - ) - - ( - data_to_pix_unique, - data_weights, - pix_lengths, - ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( - data_pixels=w_tilde_lengths.shape[0], - pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index, - pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, - pix_pixels=mapper.params, - sub_size=grid.over_sample_size.array, - ) - - curvature_matrix_via_w_tilde = aa.util.inversion_imaging_numba.curvature_matrix_via_w_tilde_curvature_preload_imaging_from( - curvature_preload=w_tilde_preload, - curvature_indexes=w_tilde_indexes.astype("int"), - curvature_lengths=w_tilde_lengths.astype("int"), - data_to_pix_unique=data_to_pix_unique.astype("int"), - data_weights=data_weights, - pix_lengths=pix_lengths.astype("int"), - pix_pixels=pixelization.pixels, - ) - - blurred_mapping_matrix = psf.convolved_mapping_matrix_from( - mapping_matrix=mapping_matrix, - mask=mask, - ) - - curvature_matrix = aa.util.inversion.curvature_matrix_via_mapping_matrix_from( - mapping_matrix=np.array(blurred_mapping_matrix), - noise_map=np.array(noise_map), - ) - - assert curvature_matrix_via_w_tilde == pytest.approx( - curvature_matrix, abs=1.0e-4 - ) diff --git a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py index df9ee014e..a23ac182e 100644 --- a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py +++ b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py @@ -68,7 +68,7 @@ def test__data_vector_via_transformed_mapping_matrix_from(): assert (data_vector_complex_via_blurred == data_vector_via_transformed).all() -def test__curvature_matrix_via_curvature_preload_from(): +def test__curvature_matrix_via_psf_precision_operator_from(): noise_map = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) uv_wavelengths = np.array( [[0.0001, 2.0, 3000.0, 50000.0, 200000.0], [3000.0, 2.0, 0.0001, 10.0, 5000.0]] @@ -90,8 +90,8 @@ def test__curvature_matrix_via_curvature_preload_from(): ] ) - curvature_preload = ( - aa.util.inversion_interferometer.w_tilde_curvature_preload_interferometer_from( + nufft_precision_operator = ( + aa.util.inversion_interferometer.nufft_precision_operator_from( noise_map_real=noise_map, uv_wavelengths=uv_wavelengths, shape_masked_pixels_2d=(3, 3), @@ -103,13 +103,17 @@ def test__curvature_matrix_via_curvature_preload_from(): [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [2, 0], [2, 1], [2, 2]] ) - w_tilde = aa.util.inversion_interferometer.w_tilde_via_preload_from( - curvature_preload=curvature_preload, - native_index_for_slim_index=native_index_for_slim_index, + psf_weighted_noise = ( + aa.util.inversion_interferometer.nufft_weighted_noise_via_sparse_operator_from( + translation_invariant_kernel=nufft_precision_operator, + native_index_for_slim_index=native_index_for_slim_index, + ) ) - curvature_matrix_via_w_tilde = aa.util.inversion.curvature_matrix_via_w_tilde_from( - w_tilde=w_tilde, 
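
For the interferometer path above, the reason a single translation-invariant kernel (the `nufft_precision_operator`) can stand in for the full curvature weighting is that, for a direct Fourier transform with equal real and imaginary noise, W_ij = sum_k cos(2 pi u_k . (r_i - r_j)) / sigma_k^2 depends only on the pixel separation. A small sketch of that invariance (the sign and (u, v) ordering conventions here are illustrative, not the library's):

    import numpy as np

    def curvature_weight(r_i, r_j, uv_wavelengths, noise_map_real):
        """W_ij for a direct Fourier transform, assuming equal real and imaginary noise."""
        dy, dx = r_i[0] - r_j[0], r_i[1] - r_j[1]
        phase = 2.0 * np.pi * (uv_wavelengths[:, 0] * dx + uv_wavelengths[:, 1] * dy)
        return np.sum(np.cos(phase) / noise_map_real**2)

    uv = np.array([[0.1, 0.3], [0.2, 0.7], [0.5, 0.4]])  # toy (u, v) baselines in wavelengths
    sigma = np.array([1.0, 2.0, 3.0])

    # The weight depends only on the separation, which is why a preloaded kernel over
    # pixel offsets can reproduce the weighting between any pair of masked pixels.
    print(curvature_weight((0.0, 0.0), (1.0, 1.0), uv, sigma))
    print(curvature_weight((3.0, 2.0), (4.0, 3.0), uv, sigma))  # same separation, same weight
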
mapping_matrix=mapping_matrix + curvature_matrix_via_nufft_weighted_noise = ( + aa.util.inversion.curvature_matrix_diag_via_psf_weighted_noise_from( + psf_weighted_noise=psf_weighted_noise, mapping_matrix=mapping_matrix + ) ) pix_indexes_for_sub_slim_index = np.array( @@ -118,222 +122,20 @@ def test__curvature_matrix_via_curvature_preload_from(): pix_weights_for_sub_slim_index = np.ones(shape=(9, 1)) - w_tilde = aa.WTildeInterferometer( - curvature_preload=curvature_preload, + sparse_operator = aa.InterferometerSparseOperator.from_nufft_precision_operator( + nufft_precision_operator=nufft_precision_operator, dirty_image=None, - real_space_mask=grid.mask, ) - curvature_matrix_via_preload = aa.util.inversion_interferometer.curvature_matrix_via_w_tilde_interferometer_from( - fft_state=w_tilde.fft_state, - pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, - rect_index_for_mask_index=w_tilde.rect_index_for_mask_index, - pix_pixels=3, + curvature_matrix_via_preload = ( + sparse_operator.curvature_matrix_via_sparse_operator_from( + pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, + pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, + fft_index_for_masked_pixel=grid.mask.fft_index_for_masked_pixel, + pix_pixels=3, + ) ) - assert curvature_matrix_via_w_tilde == pytest.approx( + assert curvature_matrix_via_nufft_weighted_noise == pytest.approx( curvature_matrix_via_preload, 1.0e-4 ) - - -def test__identical_inversion_values_for_two_methods(): - real_space_mask = aa.Mask2D.all_false( - shape_native=(7, 7), - pixel_scales=0.1, - ) - - grid = aa.Grid2D.from_mask(mask=real_space_mask, over_sample_size=1) - - mesh = aa.mesh.Delaunay() - - mesh_grid = aa.Grid2D.no_mask( - values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]], - shape_native=(3, 2), - pixel_scales=1.0, - ) - - mesh_grid = aa.Grid2DIrregular(values=mesh_grid) - - mapper_grids = mesh.mapper_grids_from( - mask=real_space_mask, - border_relocator=None, - source_plane_data_grid=grid, - source_plane_mesh_grid=mesh_grid, - ) - - reg = aa.reg.Constant(coefficient=1.0) - - mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=reg) - - visibilities = aa.Visibilities( - visibilities=[ - 1.0 + 0.0j, - 1.0 + 0.0j, - 1.0 + 0.0j, - 1.0 + 0.0j, - 1.0 + 0.0j, - 1.0 + 0.0j, - 1.0 + 0.0j, - ] - ) - noise_map = aa.VisibilitiesNoiseMap.ones(shape_slim=(7,)) - uv_wavelengths = np.ones(shape=(7, 2)) - - dataset = aa.Interferometer( - data=visibilities, - noise_map=noise_map, - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - transformer_class=aa.TransformerDFT, - ) - - dataset_w_tilde = dataset.apply_w_tilde() - - inversion_w_tilde = aa.Inversion( - dataset=dataset_w_tilde, - linear_obj_list=[mapper], - settings=aa.SettingsInversion(use_positive_only_solver=True), - ) - - inversion_mapping_matrices = aa.Inversion( - dataset=dataset, - linear_obj_list=[mapper], - settings=aa.SettingsInversion(use_positive_only_solver=True), - ) - - assert (inversion_w_tilde.data == inversion_mapping_matrices.data).all() - assert (inversion_w_tilde.noise_map == inversion_mapping_matrices.noise_map).all() - assert ( - inversion_w_tilde.linear_obj_list[0] - == inversion_mapping_matrices.linear_obj_list[0] - ) - assert ( - inversion_w_tilde.regularization_list[0] - == inversion_mapping_matrices.regularization_list[0] - ) - assert ( - inversion_w_tilde.regularization_matrix - == inversion_mapping_matrices.regularization_matrix - 
).all() - - assert inversion_w_tilde.data_vector == pytest.approx( - inversion_mapping_matrices.data_vector, abs=1.0e-2 - ) - assert inversion_w_tilde.curvature_matrix == pytest.approx( - inversion_mapping_matrices.curvature_matrix, abs=1.0e-2 - ) - assert inversion_w_tilde.curvature_reg_matrix == pytest.approx( - inversion_mapping_matrices.curvature_reg_matrix, abs=1.0e-2 - ) - - assert inversion_w_tilde.reconstruction == pytest.approx( - inversion_mapping_matrices.reconstruction, abs=1.0e-1 - ) - assert inversion_w_tilde.mapped_reconstructed_image.array == pytest.approx( - inversion_mapping_matrices.mapped_reconstructed_image.array, abs=1.0e-1 - ) - assert inversion_w_tilde.mapped_reconstructed_data.array == pytest.approx( - inversion_mapping_matrices.mapped_reconstructed_data.array, abs=1.0e-1 - ) - - -def test__identical_inversion_source_and_image_loops(): - real_space_mask = aa.Mask2D.all_false( - shape_native=(7, 7), - pixel_scales=0.1, - ) - - grid = aa.Grid2D.from_mask(mask=real_space_mask, over_sample_size=1) - - mesh = aa.mesh.Delaunay() - - mesh_grid = aa.Grid2D.no_mask( - values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]], - shape_native=(3, 2), - pixel_scales=1.0, - ) - - mesh_grid = aa.Grid2DIrregular(values=mesh_grid) - - mapper_grids = mesh.mapper_grids_from( - mask=real_space_mask, - border_relocator=None, - source_plane_data_grid=grid, - source_plane_mesh_grid=mesh_grid, - ) - - reg = aa.reg.Constant(coefficient=0.0) - - mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=reg) - - visibilities = aa.Visibilities( - visibilities=[ - 1.0 + 0.0j, - 1.0 + 0.0j, - 1.0 + 0.0j, - 1.0 + 0.0j, - 1.0 + 0.0j, - 1.0 + 0.0j, - 1.0 + 0.0j, - ] - ) - noise_map = aa.VisibilitiesNoiseMap.ones(shape_slim=(7,)) - uv_wavelengths = np.ones(shape=(7, 2)) - - dataset = aa.Interferometer( - data=visibilities, - noise_map=noise_map, - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - transformer_class=aa.TransformerDFT, - ) - - dataset_w_tilde = dataset.apply_w_tilde() - - inversion_image_loop = aa.Inversion( - dataset=dataset_w_tilde, - linear_obj_list=[mapper], - settings=aa.SettingsInversion( - use_source_loop=False, use_positive_only_solver=True - ), - ) - - inversion_source_loop = aa.Inversion( - dataset=dataset_w_tilde, - linear_obj_list=[mapper], - settings=aa.SettingsInversion( - use_source_loop=True, use_positive_only_solver=True - ), - ) - - assert (inversion_image_loop.data == inversion_source_loop.data).all() - assert (inversion_image_loop.noise_map == inversion_source_loop.noise_map).all() - assert ( - inversion_image_loop.linear_obj_list[0] - == inversion_source_loop.linear_obj_list[0] - ) - assert ( - inversion_image_loop.regularization_list[0] - == inversion_source_loop.regularization_list[0] - ) - assert ( - inversion_image_loop.regularization_matrix - == inversion_source_loop.regularization_matrix - ).all() - - assert inversion_image_loop.curvature_matrix == pytest.approx( - inversion_source_loop.curvature_matrix, 1.0e-8 - ) - assert inversion_image_loop.curvature_reg_matrix == pytest.approx( - inversion_source_loop.curvature_reg_matrix, 1.0e-8 - ) - assert inversion_image_loop.reconstruction == pytest.approx( - inversion_source_loop.reconstruction, 1.0e-2 - ) - assert inversion_image_loop.mapped_reconstructed_image.array == pytest.approx( - inversion_source_loop.mapped_reconstructed_image.array, 1.0e-2 - ) - assert inversion_image_loop.mapped_reconstructed_data.array == pytest.approx( - 
inversion_source_loop.mapped_reconstructed_data.array, 1.0e-2 - ) diff --git a/test_autoarray/inversion/inversion/test_abstract.py b/test_autoarray/inversion/inversion/test_abstract.py index 47c462de9..0d9cf7edb 100644 --- a/test_autoarray/inversion/inversion/test_abstract.py +++ b/test_autoarray/inversion/inversion/test_abstract.py @@ -84,7 +84,7 @@ def test__mapping_matrix(): assert inversion.mapping_matrix == pytest.approx(mapping_matrix, 1.0e-4) -def test__curvature_matrix__via_w_tilde__identical_to_mapping(): +def test__curvature_matrix__via_sparse_operator__identical_to_mapping(): mask = aa.Mask2D( mask=[ [True, True, True, True, True, True, True], @@ -131,10 +131,10 @@ def test__curvature_matrix__via_w_tilde__identical_to_mapping(): masked_dataset = dataset.apply_mask(mask=mask) - masked_dataset_w_tilde = masked_dataset.apply_w_tilde() + masked_dataset_sparse_operator = masked_dataset.apply_sparse_operator_cpu() - inversion_w_tilde = aa.Inversion( - dataset=masked_dataset_w_tilde, + inversion_sparse_operator = aa.Inversion( + dataset=masked_dataset_sparse_operator, linear_obj_list=[mapper_0, mapper_1], ) @@ -143,12 +143,12 @@ def test__curvature_matrix__via_w_tilde__identical_to_mapping(): linear_obj_list=[mapper_0, mapper_1], ) - assert inversion_w_tilde.curvature_matrix == pytest.approx( + assert inversion_sparse_operator.curvature_matrix == pytest.approx( inversion_mapping.curvature_matrix, 1.0e-4 ) -def test__curvature_matrix_via_w_tilde__includes_source_interpolation__identical_to_mapping(): +def test__curvature_matrix_via_sparse_operator__includes_source_interpolation__identical_to_mapping(): mask = aa.Mask2D( mask=[ [True, True, True, True, True, True, True], @@ -206,10 +206,10 @@ def test__curvature_matrix_via_w_tilde__includes_source_interpolation__identical masked_dataset = dataset.apply_mask(mask=mask) - masked_dataset_w_tilde = masked_dataset.apply_w_tilde() + masked_dataset_sparse_operator = masked_dataset.apply_sparse_operator_cpu() - inversion_w_tilde = aa.Inversion( - dataset=masked_dataset_w_tilde, + inversion_sparse_operator = aa.Inversion( + dataset=masked_dataset_sparse_operator, linear_obj_list=[mapper_0, mapper_1], ) @@ -218,7 +218,7 @@ def test__curvature_matrix_via_w_tilde__includes_source_interpolation__identical linear_obj_list=[mapper_0, mapper_1], ) - assert inversion_w_tilde.curvature_matrix == pytest.approx( + assert inversion_sparse_operator.curvature_matrix == pytest.approx( inversion_mapping.curvature_matrix, 1.0e-4 ) diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index 25cea16e3..436867a61 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -25,12 +25,14 @@ def test__inversion_imaging__via_linear_obj_func_list(masked_imaging_7x7_no_blur assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) assert inversion.reconstruction == pytest.approx(np.array([2.0]), 1.0e-4) - # Overwrites use_w_tilde to false. + # Overwrites use_sparse_operator to false. 
- masked_imaging_7x7_no_blur_w_tilde = masked_imaging_7x7_no_blur.apply_w_tilde() + masked_imaging_7x7_no_blur_sparse_operator = ( + masked_imaging_7x7_no_blur.apply_sparse_operator_cpu() + ) inversion = aa.Inversion( - dataset=masked_imaging_7x7_no_blur_w_tilde, + dataset=masked_imaging_7x7_no_blur_sparse_operator, linear_obj_list=[linear_obj], ) @@ -78,15 +80,16 @@ def test__inversion_imaging__via_mapper( # ) assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) - masked_imaging_7x7_no_blur_w_tilde = masked_imaging_7x7_no_blur.apply_w_tilde() + masked_imaging_7x7_no_blur_sparse_operator = ( + masked_imaging_7x7_no_blur.apply_sparse_operator_cpu() + ) inversion = aa.Inversion( - dataset=masked_imaging_7x7_no_blur_w_tilde, + dataset=masked_imaging_7x7_no_blur_sparse_operator, linear_obj_list=[rectangular_mapper_7x7_3x3], ) assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangularUniform) - assert isinstance(inversion, aa.InversionImagingWTilde) assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( 7.257175708246, 1.0e-4 ) @@ -103,12 +106,11 @@ def test__inversion_imaging__via_mapper( assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) inversion = aa.Inversion( - dataset=masked_imaging_7x7_no_blur_w_tilde, + dataset=masked_imaging_7x7_no_blur_sparse_operator, linear_obj_list=[delaunay_mapper_9_3x3], ) assert isinstance(inversion.linear_obj_list[0], aa.MapperDelaunay) - assert isinstance(inversion, aa.InversionImagingWTilde) assert inversion.log_det_curvature_reg_matrix_term == pytest.approx(10.6674, 1.0e-4) assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) @@ -126,10 +128,12 @@ def test__inversion_imaging__via_regularizations( mapper = copy.copy(delaunay_mapper_9_3x3) mapper.regularization = regularization_constant - masked_imaging_7x7_no_blur_w_tilde = masked_imaging_7x7_no_blur.apply_w_tilde() + masked_imaging_7x7_no_blur_sparse_operator = ( + masked_imaging_7x7_no_blur.apply_sparse_operator_cpu() + ) inversion = aa.Inversion( - dataset=masked_imaging_7x7_no_blur_w_tilde, + dataset=masked_imaging_7x7_no_blur_sparse_operator, linear_obj_list=[mapper], ) @@ -143,7 +147,7 @@ def test__inversion_imaging__via_regularizations( mapper.regularization = regularization_adaptive_brightness inversion = aa.Inversion( - dataset=masked_imaging_7x7_no_blur_w_tilde, + dataset=masked_imaging_7x7_no_blur_sparse_operator, linear_obj_list=[mapper], ) @@ -208,10 +212,12 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper( assert inversion.reconstruction_dict[rectangular_mapper_7x7_3x3][0] < 1.0e-4 assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) - masked_imaging_7x7_no_blur_w_tilde = masked_imaging_7x7_no_blur.apply_w_tilde() + masked_imaging_7x7_no_blur_sparse_operator = ( + masked_imaging_7x7_no_blur.apply_sparse_operator_cpu() + ) inversion = aa.Inversion( - dataset=masked_imaging_7x7_no_blur_w_tilde, + dataset=masked_imaging_7x7_no_blur_sparse_operator, linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], settings=aa.SettingsInversion( no_regularization_add_to_curvature_diag_value=False, @@ -220,7 +226,6 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper( assert isinstance(inversion.linear_obj_list[0], aa.m.MockLinearObj) assert isinstance(inversion.linear_obj_list[1], aa.MapperRectangularUniform) - assert isinstance(inversion, aa.InversionImagingWTilde) assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( 
7.2571757082469945, 1.0e-4 ) @@ -268,14 +273,14 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper__force_edge_pixels_t assert isinstance(inversion, aa.InversionImagingMapping) -def test__inversion_imaging__compare_mapping_and_w_tilde_values( +def test__inversion_imaging__compare_mapping_and_sparse_operator_values( masked_imaging_7x7, delaunay_mapper_9_3x3 ): - masked_imaging_7x7_w_tilde = masked_imaging_7x7.apply_w_tilde() + masked_imaging_7x7_sparse_operator = masked_imaging_7x7.apply_sparse_operator_cpu() - inversion_w_tilde = aa.Inversion( - dataset=masked_imaging_7x7_w_tilde, + inversion_sparse_operator = aa.Inversion( + dataset=masked_imaging_7x7_sparse_operator, linear_obj_list=[delaunay_mapper_9_3x3], ) @@ -285,16 +290,16 @@ def test__inversion_imaging__compare_mapping_and_w_tilde_values( settings=aa.SettingsInversion(), ) - assert inversion_w_tilde._curvature_matrix_mapper_diag == pytest.approx( + assert inversion_sparse_operator._curvature_matrix_mapper_diag == pytest.approx( inversion_mapping._curvature_matrix_mapper_diag, 1.0e-4 ) - assert inversion_w_tilde.reconstruction == pytest.approx( + assert inversion_sparse_operator.reconstruction == pytest.approx( inversion_mapping.reconstruction, 1.0e-4 ) - assert inversion_w_tilde.mapped_reconstructed_image.array == pytest.approx( + assert inversion_sparse_operator.mapped_reconstructed_image.array == pytest.approx( inversion_mapping.mapped_reconstructed_image.array, 1.0e-4 ) - assert inversion_w_tilde.log_det_curvature_reg_matrix_term == pytest.approx( + assert inversion_sparse_operator.log_det_curvature_reg_matrix_term == pytest.approx( inversion_mapping.log_det_curvature_reg_matrix_term ) @@ -349,7 +354,7 @@ def test__inversion_imaging__linear_obj_func_and_non_func_give_same_terms( ) -def test__inversion_imaging__linear_obj_func_with_w_tilde( +def test__inversion_imaging__linear_obj_func_with_sparse_operator( masked_imaging_7x7, rectangular_mapper_7x7_3x3, delaunay_mapper_9_3x3, @@ -373,28 +378,29 @@ def test__inversion_imaging__linear_obj_func_with_w_tilde( settings=aa.SettingsInversion(use_positive_only_solver=True), ) - masked_imaging_7x7_w_tilde = masked_imaging_7x7.apply_w_tilde() + masked_imaging_7x7_sparse_operator = masked_imaging_7x7.apply_sparse_operator_cpu() - inversion_w_tilde = aa.Inversion( - dataset=masked_imaging_7x7_w_tilde, + inversion_sparse_operator = aa.Inversion( + dataset=masked_imaging_7x7_sparse_operator, linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], settings=aa.SettingsInversion(use_positive_only_solver=True), ) assert inversion_mapping.data_vector == pytest.approx( - inversion_w_tilde.data_vector, 1.0e-4 + inversion_sparse_operator.data_vector, 1.0e-4 ) + assert inversion_mapping.curvature_matrix == pytest.approx( - inversion_w_tilde.curvature_matrix, 1.0e-4 + inversion_sparse_operator.curvature_matrix, 1.0e-4 ) assert inversion_mapping.curvature_reg_matrix == pytest.approx( - inversion_w_tilde.curvature_reg_matrix, 1.0e-4 + inversion_sparse_operator.curvature_reg_matrix, 1.0e-4 ) assert inversion_mapping.reconstruction == pytest.approx( - inversion_w_tilde.reconstruction, 1.0e-4 + inversion_sparse_operator.reconstruction, 1.0e-4 ) assert inversion_mapping.mapped_reconstructed_image.array == pytest.approx( - inversion_w_tilde.mapped_reconstructed_image.array, 1.0e-4 + inversion_sparse_operator.mapped_reconstructed_image.array, 1.0e-4 ) linear_obj_1 = aa.m.MockLinearObjFuncList( @@ -417,10 +423,10 @@ def test__inversion_imaging__linear_obj_func_with_w_tilde( 
settings=aa.SettingsInversion(), ) - masked_imaging_7x7_w_tilde = masked_imaging_7x7.apply_w_tilde() + masked_imaging_7x7_sparse_operator = masked_imaging_7x7.apply_sparse_operator_cpu() - inversion_w_tilde = aa.Inversion( - dataset=masked_imaging_7x7_w_tilde, + inversion_sparse_operator = aa.Inversion( + dataset=masked_imaging_7x7_sparse_operator, linear_obj_list=[ rectangular_mapper_7x7_3x3, linear_obj, @@ -431,10 +437,10 @@ def test__inversion_imaging__linear_obj_func_with_w_tilde( ) assert inversion_mapping.data_vector == pytest.approx( - inversion_w_tilde.data_vector, 1.0e-4 + inversion_sparse_operator.data_vector, 1.0e-4 ) assert inversion_mapping.curvature_matrix == pytest.approx( - inversion_w_tilde.curvature_matrix, 1.0e-4 + inversion_sparse_operator.curvature_matrix, 1.0e-4 ) diff --git a/test_autoarray/inversion/inversion/test_inversion_util.py b/test_autoarray/inversion/inversion/test_inversion_util.py index 0b1661f2e..cd8873e0b 100644 --- a/test_autoarray/inversion/inversion/test_inversion_util.py +++ b/test_autoarray/inversion/inversion/test_inversion_util.py @@ -3,8 +3,8 @@ import pytest -def test__curvature_matrix_from_w_tilde(): - w_tilde = np.array( +def test__curvature_matrix_diag_via_psf_weighted_noise_from(): + psf_weighted_noise = np.array( [ [1.0, 2.0, 3.0, 4.0], [2.0, 1.0, 2.0, 3.0], @@ -17,8 +17,10 @@ def test__curvature_matrix_from_w_tilde(): [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]] ) - curvature_matrix = aa.util.inversion.curvature_matrix_via_w_tilde_from( - w_tilde=w_tilde, mapping_matrix=mapping_matrix + curvature_matrix = ( + aa.util.inversion.curvature_matrix_diag_via_psf_weighted_noise_from( + psf_weighted_noise=psf_weighted_noise, mapping_matrix=mapping_matrix + ) ) assert ( diff --git a/test_autoarray/inversion/inversion/test_settings_dict.py b/test_autoarray/inversion/inversion/test_settings_dict.py index 9f689c802..12cd0f759 100644 --- a/test_autoarray/inversion/inversion/test_settings_dict.py +++ b/test_autoarray/inversion/inversion/test_settings_dict.py @@ -16,8 +16,6 @@ def make_settings_dict(): "use_positive_only_solver": False, "positive_only_uses_p_initial": False, "no_regularization_add_to_curvature_diag_value": 1e-08, - "use_w_tilde_numpy": False, - "use_source_loop": False, "tolerance": 1e-08, "maxiter": 250, },
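
The renamed `test__curvature_matrix_diag_via_psf_weighted_noise_from` in test_inversion_util.py above feeds an explicit `psf_weighted_noise` matrix W and `mapping_matrix` M into the curvature assembly. Assuming the routine reduces to the standard F = Mᵀ W M (a sketch under that assumption, not the library implementation):

    import numpy as np

    def curvature_matrix_dense(psf_weighted_noise, mapping_matrix):
        """F = M^T W M: curvature of chi^2 with respect to the source-pixel fluxes."""
        return mapping_matrix.T @ psf_weighted_noise @ mapping_matrix

With a single linear object this is the whole curvature matrix; with several, it is one diagonal block of it, which is presumably what the `_diag` suffix refers to.
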