From d4d1994bc632869667dafbd5d261d3045bc3bab7 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 5 Feb 2026 13:09:18 +0000 Subject: [PATCH 01/13] stuff --- autoarray/dataset/imaging/dataset.py | 4 ++++ autoarray/dataset/interferometer/dataset.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 401990c9..8ef91049 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -504,6 +504,10 @@ def apply_sparse_operator( Whether to use JAX to compute W-Tilde. This requires JAX to be installed. """ + logger.info( + "IMAGING - Setting Up Sparse Operator For low Memory Pixelizations." + ) + sparse_operator = ( inversion_imaging_util.ImagingSparseOperator.from_noise_map_and_psf( data=self.data, diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index 922aff96..7e72ddcd 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -198,7 +198,7 @@ def apply_sparse_operator( if nufft_precision_operator is None: logger.info( - "INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, CPU run times may exceed hours." + "INTERFEROMETER - Computing NUFFT Precision Operator; runtime scales with visibility count and mask resolution, CPU run times may exceed hours." ) nufft_precision_operator = self.psf_precision_operator_from( From e15260d2c19e1cc3ce1c86af6de5807b4b201b3d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 6 Feb 2026 17:16:02 +0000 Subject: [PATCH 02/13] couple of tests --- .../inversion/inversion/imaging/abstract.py | 5 +- .../inversion/inversion/imaging/mapping.py | 10 ++- .../inversion/inversion/inversion_util.py | 13 ++- autoarray/inversion/inversion/settings.py | 7 ++ autoarray/structures/arrays/kernel_2d.py | 80 +++++++++++++------ 5 files changed, 86 insertions(+), 29 deletions(-) diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index b863c1c6..7400a0ed 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -1,6 +1,8 @@ import numpy as np from typing import Dict, List, Union, Type +from autoconf import cached_property + from autoarray.dataset.imaging.dataset import Imaging from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList @@ -136,6 +138,7 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict: if linear_func.operated_mapping_matrix_override is not None: operated_mapping_matrix = linear_func.operated_mapping_matrix_override else: + vvv operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=linear_func.mapping_matrix, mask=self.mask, @@ -200,7 +203,7 @@ def data_linear_func_matrix_dict(self): return data_linear_func_matrix_dict - @property + @cached_property def mapper_operated_mapping_matrix_dict(self) -> Dict: """ The `operated_mapping_matrix` of a `Mapper` object describes the mappings between the observed data's values diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index 7606e961..ff4af63f 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -76,7 +76,10 @@ def _data_vector_mapper(self) -> np.ndarray: param_range = mapper_param_range_list[i] operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( - mapping_matrix=mapper.mapping_matrix, mask=self.mask, xp=self._xp + mapping_matrix=mapper.mapping_matrix, + mask=self.mask, + use_mixed_precision=self.settings.use_mixed_precision, + xp=self._xp ) data_vector_mapper = ( @@ -135,7 +138,10 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: mapper_param_range_i = mapper_param_range_list[i] operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( - mapping_matrix=mapper_i.mapping_matrix, mask=self.mask, xp=self._xp + mapping_matrix=mapper_i.mapping_matrix, + mask=self.mask, + use_mixed_precision=self.settings.use_mixed_precision, + xp=self._xp ) diag = inversion_util.curvature_matrix_via_mapping_matrix_from( diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index 854c331b..6c1609d9 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -84,6 +84,10 @@ def curvature_matrix_via_mapping_matrix_from( no_regularization_index_list: Optional[List] = None, settings: SettingsInversion = SettingsInversion(), xp=np, + *, + mp_gemm: bool = True, # mixed precision matmul + gemm_dtype=None, # e.g. xp.float32 + out_dtype=None, # e.g. xp.float64 ) -> np.ndarray: """ Returns the curvature matrix `F` from a blurred mapping matrix `f` and the 1D noise-map $\sigma$ @@ -97,8 +101,13 @@ def curvature_matrix_via_mapping_matrix_from( noise_map Flattened 1D array of the noise-map used by the inversion during the fit. """ - array = mapping_matrix / noise_map[:, None] - curvature_matrix = xp.dot(array.T, array) + if gemm_dtype is None: + gemm_dtype = xp.float32 if (mp_gemm and xp is not np) else mapping_matrix.dtype + + # form A in chosen dtype (usually float32 on device) + A = (mapping_matrix / noise_map[:, None]).astype(gemm_dtype) + + curvature_matrix = xp.dot(A.T, A) # float32 GEMM if A is float32 if add_to_curvature_diag and len(no_regularization_index_list) > 0: curvature_matrix = curvature_matrix_with_added_to_diag_from( diff --git a/autoarray/inversion/inversion/settings.py b/autoarray/inversion/inversion/settings.py index 73494226..ab639de0 100644 --- a/autoarray/inversion/inversion/settings.py +++ b/autoarray/inversion/inversion/settings.py @@ -10,6 +10,7 @@ class SettingsInversion: def __init__( self, + use_mixed_precision : bool = False, use_positive_only_solver: Optional[bool] = None, positive_only_uses_p_initial: Optional[bool] = None, use_border_relocator: Optional[bool] = None, @@ -24,6 +25,12 @@ def __init__( Parameters ---------- + use_mixed_precision + If `True`, the linear algebra calculations of the inversion are performed using single precision on a + targeted subset of functions which provide significant speed up when using a GPU (x4), reduces VRAM + use and are expected to have minimal impact on the accuracy of the results. If `False`, all linear algebra + calculations are performed using double precision, which is the default and is more accurate but + slower on a GPU. use_positive_only_solver Whether to use a positive-only linear system solver, which requires that every reconstructed value is positive but is computationally much slower than the default solver (which allows for positive and diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index abc74dc3..76bbd2b5 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -533,6 +533,7 @@ def mapping_matrix_native_from( mask: "Mask2D", blurring_mapping_matrix: Optional[np.ndarray] = None, blurring_mask: Optional["Mask2D"] = None, + use_mixed_precision: bool = False, xp=np, ) -> np.ndarray: """ @@ -558,6 +559,10 @@ def mapping_matrix_native_from( Mask defining the blurring region pixels. Must be provided if `blurring_mapping_matrix` is given and `slim_to_native_blurring_tuple` is not already cached. + use_mixed_precision + If True, the mapping matrices are cast to single precision (float32) to + speed up GPU computations and reduce VRAM usage. If False, double precision + (float64) is used for maximum accuracy. Returns ------- @@ -566,33 +571,29 @@ def mapping_matrix_native_from( Contains contributions from both the main mapping matrix and, if provided, the blurring mapping matrix. """ + dtype_native = xp.float32 if use_mixed_precision else xp.float64 + n_src = mapping_matrix.shape[1] - # Allocate full native grid (ny, nx, n_src) - mapping_matrix_native = xp.zeros( - mask.shape + (n_src,), dtype=mapping_matrix.dtype - ) + mapping_matrix_native = xp.zeros(mask.shape + (n_src,), dtype=dtype_native) + + # Cast inputs to the target dtype to avoid implicit up/downcasts inside scatter + mm = mapping_matrix if mapping_matrix.dtype == dtype_native else xp.asarray(mapping_matrix, dtype=dtype_native) - # Scatter main mapping matrix into native cube if xp.__name__.startswith("jax"): - mapping_matrix_native = mapping_matrix_native.at[ - mask.slim_to_native_tuple - ].set(mapping_matrix) + mapping_matrix_native = mapping_matrix_native.at[mask.slim_to_native_tuple].set(mm) else: - mapping_matrix_native[mask.slim_to_native_tuple] = mapping_matrix - - # Optionally scatter blurring mapping matrix + mapping_matrix_native[mask.slim_to_native_tuple] = np.asarray(mm) if blurring_mapping_matrix is not None: + bm = blurring_mapping_matrix + if getattr(bm, "dtype", None) != dtype_native: + bm = xp.asarray(bm, dtype=dtype_native) if xp.__name__.startswith("jax"): - mapping_matrix_native = mapping_matrix_native.at[ - blurring_mask.slim_to_native_tuple - ].set(blurring_mapping_matrix) + mapping_matrix_native = mapping_matrix_native.at[blurring_mask.slim_to_native_tuple].set(bm) else: - mapping_matrix_native[blurring_mask.slim_to_native_tuple] = ( - blurring_mapping_matrix - ) + mapping_matrix_native[blurring_mask.slim_to_native_tuple] = np.asarray(bm) return mapping_matrix_native @@ -730,6 +731,7 @@ def convolved_mapping_matrix_from( blurring_mapping_matrix=None, blurring_mask: Optional[Mask2D] = None, jax_method="direct", + use_mixed_precision: bool = False, xp=np, ): """ @@ -770,12 +772,19 @@ def convolved_mapping_matrix_from( Mapping matrix for the blurring region, outside the mask core. jax_method : str Backend passed to real-space convolution if ``use_fft=False``. + use_mixed_precision + If `True`, the FFT is performed using single precision, which provide significant speed up when using a + GPU (x4), reduces VRAM use and is expected to have minimal impact on the accuracy of the results. If `False`, + the FFT is performed using double precision, which is the default and is more accurate but slower on a GPU. Returns ------- ndarray of shape (N_pix, N_src) Convolved mapping matrix in slim form. """ + # ------------------------------------------------------------------------- + # NumPy path unchanged + # ------------------------------------------------------------------------- if xp is np: return self.convolved_mapping_matrix_via_real_space_np_from( mapping_matrix=mapping_matrix, @@ -785,6 +794,9 @@ def convolved_mapping_matrix_from( xp=xp, ) + # ------------------------------------------------------------------------- + # Non-FFT JAX path unchanged + # ------------------------------------------------------------------------- if not self.use_fft: return self.convolved_mapping_matrix_via_real_space_from( mapping_matrix=mapping_matrix, @@ -796,34 +808,50 @@ def convolved_mapping_matrix_from( ) import jax + import jax.numpy as jnp + # ------------------------------------------------------------------------- + # Validate cached FFT shapes / state + # ------------------------------------------------------------------------- if self.fft_shape is None: - full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=mask) - raise ValueError( f"FFT convolution requires precomputed padded shapes, but `self.fft_shape` is None.\n" f"Expected mapping matrix padded to match FFT shape of PSF.\n" f"PSF fft_shape: {fft_shape}, mask shape: {mask.shape}, " f"mapping_matrix shape: {getattr(mapping_matrix, 'shape', 'unknown')}." ) - else: - fft_shape = self.fft_shape full_shape = self.full_shape mask_shape = self.mask_shape fft_psf_mapping = self.fft_psf_mapping + # ------------------------------------------------------------------------- + # Mixed precision dtypes (JAX only) + # ------------------------------------------------------------------------- + fft_real_dtype = jnp.float32 if use_mixed_precision else jnp.float64 + fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128 + + # Ensure PSF FFT dtype matches the FFT path + fft_psf_mapping = jnp.asarray(fft_psf_mapping, dtype=fft_complex_dtype) + + # ------------------------------------------------------------------------- + # Build native cube in the FFT dtype (THIS IS THE KEY) + # Requires mapping_matrix_native_from to accept dtype_native kwarg. + # ------------------------------------------------------------------------- mapping_matrix_native = self.mapping_matrix_native_from( mapping_matrix=mapping_matrix, mask=mask, blurring_mapping_matrix=blurring_mapping_matrix, blurring_mask=blurring_mask, + use_mixed_precision=use_mixed_precision, xp=xp, ) + # ------------------------------------------------------------------------- # FFT convolution + # ------------------------------------------------------------------------- fft_mapping_matrix_native = xp.fft.rfft2( mapping_matrix_native, s=fft_shape, axes=(0, 1) ) @@ -833,7 +861,9 @@ def convolved_mapping_matrix_from( axes=(0, 1), ) - # crop back + # ------------------------------------------------------------------------- + # Crop back to mask-shape + # ------------------------------------------------------------------------- start_indices = tuple( (full_size - out_size) // 2 for full_size, out_size in zip(full_shape, mask_shape) @@ -846,8 +876,10 @@ def convolved_mapping_matrix_from( out_shape_full, ) - # return slim form - return blurred_mapping_matrix_native[mask.slim_to_native_tuple] + # Return slim form + blurred_slim = blurred_mapping_matrix_native[mask.slim_to_native_tuple] + + return blurred_slim def rescaled_with_odd_dimensions_from( self, rescale_factor: float, normalize: bool = False From 456905c1ecf329d6221fca9cdd5d0ed16c9e1df6 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 7 Feb 2026 11:54:48 +0000 Subject: [PATCH 03/13] minor --- autoarray/structures/decorators/abstract.py | 1 - 1 file changed, 1 deletion(-) diff --git a/autoarray/structures/decorators/abstract.py b/autoarray/structures/decorators/abstract.py index 5e4c86c4..b01c249e 100644 --- a/autoarray/structures/decorators/abstract.py +++ b/autoarray/structures/decorators/abstract.py @@ -61,7 +61,6 @@ def __init__(self, func, obj, grid, xp=np, *args, **kwargs): def _xp(self): if self.use_jax: import jax.numpy as jnp - return jnp return np From 411ded5e4f232b7f0d4b49b339acb8296c490a41 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 7 Feb 2026 12:01:52 +0000 Subject: [PATCH 04/13] remove redunant methods from mapping --- .../inversion/inversion/imaging/abstract.py | 6 +- .../inversion/inversion/imaging/mapping.py | 91 ------------------- autoarray/inversion/inversion/settings.py | 1 + .../inversion/inversion/test_factory.py | 3 - 4 files changed, 3 insertions(+), 98 deletions(-) diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index 7400a0ed..ab0c7918 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -1,8 +1,6 @@ import numpy as np from typing import Dict, List, Union, Type -from autoconf import cached_property - from autoarray.dataset.imaging.dataset import Imaging from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList @@ -97,6 +95,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: self.psf.convolved_mapping_matrix_from( mapping_matrix=linear_obj.mapping_matrix, mask=self.mask, + use_mixed_precision=self.settings.use_mixed_precision, xp=self._xp, ) if linear_obj.operated_mapping_matrix_override is None @@ -138,7 +137,6 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict: if linear_func.operated_mapping_matrix_override is not None: operated_mapping_matrix = linear_func.operated_mapping_matrix_override else: - vvv operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=linear_func.mapping_matrix, mask=self.mask, @@ -203,7 +201,7 @@ def data_linear_func_matrix_dict(self): return data_linear_func_matrix_dict - @cached_property + @property def mapper_operated_mapping_matrix_dict(self) -> Dict: """ The `operated_mapping_matrix` of a `Mapper` object describes the mappings between the observed data's values diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index ff4af63f..00c618eb 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -53,47 +53,6 @@ def __init__( xp=xp, ) - @property - def _data_vector_mapper(self) -> np.ndarray: - """ - Returns the `data_vector` of all mappers, a 1D vector whose values are solved for by the simultaneous - linear equations constructed by this object. The object is described in full in the method `data_vector`. - - This method is used to compute part of the `data_vector` if there are also linear function list objects - in the inversion, and is separated into a separate method to enable preloading of the mapper `data_vector`. - """ - - if not self.has(cls=AbstractMapper): - return - - data_vector = np.zeros(self.total_params) - - mapper_list = self.cls_list_from(cls=AbstractMapper) - mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper) - - for i in range(len(mapper_list)): - mapper = mapper_list[i] - param_range = mapper_param_range_list[i] - - operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( - mapping_matrix=mapper.mapping_matrix, - mask=self.mask, - use_mixed_precision=self.settings.use_mixed_precision, - xp=self._xp - ) - - data_vector_mapper = ( - inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=operated_mapping_matrix, - image=self.data, - noise_map=self.noise_map, - ) - ) - - data_vector[param_range[0] : param_range[1],] = data_vector_mapper - - return data_vector - @cached_property def data_vector(self) -> np.ndarray: """ @@ -114,56 +73,6 @@ def data_vector(self) -> np.ndarray: noise_map=self.noise_map.array, ) - @property - def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: - """ - Returns the diagonal regions of the `curvature_matrix`, a 2D matrix which uses the mappings between the data - and the linear objects to construct the simultaneous linear equations. The object is described in full in - the method `curvature_matrix`. - - This method computes the diagonal entries of all mapper objects in the `curvature_matrix`. It is separate from - other calculations to enable preloading of this calculation. - """ - - if not self.has(cls=AbstractMapper): - return None - - curvature_matrix = np.zeros((self.total_params, self.total_params)) - - mapper_list = self.cls_list_from(cls=AbstractMapper) - mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper) - - for i in range(len(mapper_list)): - mapper_i = mapper_list[i] - mapper_param_range_i = mapper_param_range_list[i] - - operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( - mapping_matrix=mapper_i.mapping_matrix, - mask=self.mask, - use_mixed_precision=self.settings.use_mixed_precision, - xp=self._xp - ) - - diag = inversion_util.curvature_matrix_via_mapping_matrix_from( - mapping_matrix=operated_mapping_matrix, - noise_map=self.noise_map, - settings=self.settings, - add_to_curvature_diag=True, - no_regularization_index_list=self.no_regularization_index_list, - xp=self._xp, - ) - - curvature_matrix[ - mapper_param_range_i[0] : mapper_param_range_i[1], - mapper_param_range_i[0] : mapper_param_range_i[1], - ] = diag - - curvature_matrix = inversion_util.curvature_matrix_mirrored_from( - curvature_matrix=curvature_matrix, xp=self._xp - ) - - return curvature_matrix - @cached_property def curvature_matrix(self): """ diff --git a/autoarray/inversion/inversion/settings.py b/autoarray/inversion/inversion/settings.py index ab639de0..071e8ffc 100644 --- a/autoarray/inversion/inversion/settings.py +++ b/autoarray/inversion/inversion/settings.py @@ -48,6 +48,7 @@ def __init__( For an interferometer inversion using the linear operators method, sets the maximum number of iterations of the solver (this input does nothing for dataset data and other interferometer methods). """ + self.use_mixed_precision = use_mixed_precision self._use_positive_only_solver = use_positive_only_solver self._positive_only_uses_p_initial = positive_only_uses_p_initial self._use_border_relocator = use_border_relocator diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index d0003e66..e45e6e41 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -290,9 +290,6 @@ def test__inversion_imaging__compare_mapping_and_sparse_operator_values( settings=aa.SettingsInversion(), ) - assert inversion_sparse_operator._curvature_matrix_mapper_diag == pytest.approx( - inversion_mapping._curvature_matrix_mapper_diag, 1.0e-4 - ) assert inversion_sparse_operator.reconstruction == pytest.approx( inversion_mapping.reconstruction, 1.0e-4 ) From 92809b2fcee654014b9b3c715bf0846547466383 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 7 Feb 2026 17:22:21 +0000 Subject: [PATCH 05/13] added full use_mixed_precision path --- .../inversion/inversion/imaging/mapping.py | 1 + .../inversion/inversion/inversion_util.py | 40 ++++++---- autoarray/inversion/inversion/settings.py | 2 +- autoarray/inversion/linear_obj/func_list.py | 3 + .../pixelization/mappers/abstract.py | 4 + .../inversion/pixelization/mappers/factory.py | 6 ++ .../pixelization/mappers/mapper_util.py | 56 ++++++++----- autoarray/operators/mock/mock_psf.py | 4 +- autoarray/structures/arrays/kernel_2d.py | 79 ++++++++++++------- autoarray/structures/decorators/abstract.py | 1 + 10 files changed, 132 insertions(+), 64 deletions(-) diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index 00c618eb..a03b0a3a 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -98,6 +98,7 @@ def curvature_matrix(self): settings=self.settings, add_to_curvature_diag=True, no_regularization_index_list=self.no_regularization_index_list, + use_mixed_precision=self.settings.use_mixed_precision, xp=self._xp, ) diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index 6c1609d9..92f65b06 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -78,16 +78,13 @@ def curvature_matrix_mirrored_from(curvature_matrix: np.ndarray, xp=np) -> np.nd def curvature_matrix_via_mapping_matrix_from( - mapping_matrix: np.ndarray, - noise_map: np.ndarray, + mapping_matrix: "np.ndarray", + noise_map: "np.ndarray", add_to_curvature_diag: bool = False, no_regularization_index_list: Optional[List] = None, - settings: SettingsInversion = SettingsInversion(), + settings: "SettingsInversion" = SettingsInversion(), + use_mixed_precision: bool = False, xp=np, - *, - mp_gemm: bool = True, # mixed precision matmul - gemm_dtype=None, # e.g. xp.float32 - out_dtype=None, # e.g. xp.float64 ) -> np.ndarray: """ Returns the curvature matrix `F` from a blurred mapping matrix `f` and the 1D noise-map $\sigma$ @@ -101,15 +98,26 @@ def curvature_matrix_via_mapping_matrix_from( noise_map Flattened 1D array of the noise-map used by the inversion during the fit. """ - if gemm_dtype is None: - gemm_dtype = xp.float32 if (mp_gemm and xp is not np) else mapping_matrix.dtype - - # form A in chosen dtype (usually float32 on device) - A = (mapping_matrix / noise_map[:, None]).astype(gemm_dtype) - - curvature_matrix = xp.dot(A.T, A) # float32 GEMM if A is float32 - - if add_to_curvature_diag and len(no_regularization_index_list) > 0: + # NumPy path: keep it simple + stable + if xp is np: + A = mapping_matrix / noise_map[:, None] + curvature_matrix = xp.dot(A.T, A) + else: + # Choose compute dtype + + compute_dtype = xp.float32 if use_mixed_precision else xp.float64 + out_dtype = xp.float64 # always return float64 for downstream stability + + A = mapping_matrix + w = (1.0 / noise_map).astype(compute_dtype) + A = A * w[:, None] + curvature_matrix = xp.dot(A.T, A).astype(out_dtype) + + if ( + add_to_curvature_diag + and no_regularization_index_list + and len(no_regularization_index_list) > 0 + ): curvature_matrix = curvature_matrix_with_added_to_diag_from( curvature_matrix=curvature_matrix, value=settings.no_regularization_add_to_curvature_diag_value, diff --git a/autoarray/inversion/inversion/settings.py b/autoarray/inversion/inversion/settings.py index 071e8ffc..2324c949 100644 --- a/autoarray/inversion/inversion/settings.py +++ b/autoarray/inversion/inversion/settings.py @@ -10,7 +10,7 @@ class SettingsInversion: def __init__( self, - use_mixed_precision : bool = False, + use_mixed_precision: bool = False, use_positive_only_solver: Optional[bool] = None, positive_only_uses_p_initial: Optional[bool] = None, use_border_relocator: Optional[bool] = None, diff --git a/autoarray/inversion/linear_obj/func_list.py b/autoarray/inversion/linear_obj/func_list.py index b329a0c2..82393835 100644 --- a/autoarray/inversion/linear_obj/func_list.py +++ b/autoarray/inversion/linear_obj/func_list.py @@ -7,6 +7,7 @@ from autoarray.inversion.linear_obj.neighbors import Neighbors from autoarray.inversion.linear_obj.unique_mappings import UniqueMappings from autoarray.inversion.regularization.abstract import AbstractRegularization +from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.type import Grid1D2DLike @@ -15,6 +16,7 @@ def __init__( self, grid: Grid1D2DLike, regularization: Optional[AbstractRegularization], + settings=SettingsInversion(), xp=np, ): """ @@ -45,6 +47,7 @@ def __init__( super().__init__(regularization=regularization, xp=xp) self.grid = grid + self.settings = settings @cached_property def neighbors(self) -> Neighbors: diff --git a/autoarray/inversion/pixelization/mappers/abstract.py b/autoarray/inversion/pixelization/mappers/abstract.py index 8c8ed067..7a29c61a 100644 --- a/autoarray/inversion/pixelization/mappers/abstract.py +++ b/autoarray/inversion/pixelization/mappers/abstract.py @@ -11,6 +11,7 @@ from autoarray.inversion.pixelization.border_relocator import BorderRelocator from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids from autoarray.inversion.regularization.abstract import AbstractRegularization +from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.mesh.abstract_2d import Abstract2DMesh @@ -25,6 +26,7 @@ def __init__( mapper_grids: MapperGrids, regularization: Optional[AbstractRegularization], border_relocator: BorderRelocator, + settings: SettingsInversion = SettingsInversion(), preloads=None, xp=np, ): @@ -90,6 +92,7 @@ def __init__( self.border_relocator = border_relocator self.mapper_grids = mapper_grids self.preloads = preloads + self.settings = settings @property def params(self) -> int: @@ -265,6 +268,7 @@ def mapping_matrix(self) -> np.ndarray: total_mask_pixels=self.over_sampler.mask.pixels_in_mask, slim_index_for_sub_slim_index=self.slim_index_for_sub_slim_index, sub_fraction=self.over_sampler.sub_fraction.array, + use_mixed_precision=self.settings.use_mixed_precision, xp=self._xp, ) diff --git a/autoarray/inversion/pixelization/mappers/factory.py b/autoarray/inversion/pixelization/mappers/factory.py index 94ae0d3e..001a7d72 100644 --- a/autoarray/inversion/pixelization/mappers/factory.py +++ b/autoarray/inversion/pixelization/mappers/factory.py @@ -4,6 +4,7 @@ from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids from autoarray.inversion.pixelization.border_relocator import BorderRelocator from autoarray.inversion.regularization.abstract import AbstractRegularization +from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay @@ -13,6 +14,7 @@ def mapper_from( mapper_grids: MapperGrids, regularization: Optional[AbstractRegularization], border_relocator: Optional[BorderRelocator] = None, + settings=SettingsInversion(), preloads=None, xp=np, ): @@ -53,6 +55,8 @@ def mapper_from( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, + settings=settings, + preloads=preloads, xp=xp, ) elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular): @@ -60,6 +64,8 @@ def mapper_from( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, + settings=settings, + preloads=preloads, xp=xp, ) elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunay): diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index 4474335c..385f2190 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -548,6 +548,7 @@ def mapping_matrix_from( total_mask_pixels: int, slim_index_for_sub_slim_index: np.ndarray, sub_fraction: np.ndarray, + use_mixed_precision: bool = False, xp=np, ) -> np.ndarray: """ @@ -621,39 +622,56 @@ def mapping_matrix_from( sub_fraction The fractional area each sub-pixel takes up in an pixel. """ + M_sub, B = pix_indexes_for_sub_slim_index.shape - M = total_mask_pixels - S = pixels + M = int(total_mask_pixels) + S = int(pixels) + + # Indices always int32 + pix_idx = xp.asarray(pix_indexes_for_sub_slim_index, dtype=xp.int32) + pix_size = xp.asarray(pix_size_for_sub_slim_index, dtype=xp.int32) + slim_parent = xp.asarray(slim_index_for_sub_slim_index, dtype=xp.int32) + + # Everything else computed in float64 + w64 = xp.asarray(pix_weights_for_sub_slim_index, dtype=xp.float64) + frac64 = xp.asarray(sub_fraction, dtype=xp.float64) + + # Output dtype only (big allocation) + out_dtype = xp.float32 if use_mixed_precision else xp.float64 # 1) Flatten - flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,) - flat_w = pix_weights_for_sub_slim_index.reshape(-1) # (M_sub*B,) - flat_parent = xp.repeat(slim_index_for_sub_slim_index, B) # (M_sub*B,) - flat_count = xp.repeat(pix_size_for_sub_slim_index, B) # (M_sub*B,) + flat_pixidx = pix_idx.reshape(-1) # (M_sub*B,) + flat_w = w64.reshape(-1) # float64 + flat_parent = xp.repeat(slim_parent, B) # int32 + flat_count = xp.repeat(pix_size, B) # int32 - # 2) Build valid mask: k < pix_size[i] - k = xp.tile(xp.arange(B), M_sub) # (M_sub*B,) - valid = k < flat_count # (M_sub*B,) + # 2) valid mask: k < pix_size[i] + k = xp.tile(xp.arange(B, dtype=xp.int32), M_sub) + valid = k < flat_count - # 3) Zero out invalid weights - flat_w = flat_w * valid.astype(flat_w.dtype) + # 3) Zero out invalid weights (float64) + flat_w = flat_w * valid.astype(xp.float64) # 4) Redirect -1 indices to extra bin S OUT = S flat_pixidx = xp.where(flat_pixidx < 0, OUT, flat_pixidx) - # 5) Multiply by sub_fraction of the slim row - flat_frac = xp.take(sub_fraction, flat_parent, axis=0) # (M_sub*B,) - flat_contrib = flat_w * flat_frac # (M_sub*B,) + # 5) Multiply by sub_fraction of the slim row (float64) + flat_frac = xp.take(frac64, flat_parent, axis=0) + flat_contrib64 = flat_w * flat_frac + + # 6) Scatter into (M × (S+1)) (destination float32 or float64) + mat = xp.zeros((M, S + 1), dtype=out_dtype) + + # Cast only at the write (keeps upstream math float64) + flat_contrib_out = flat_contrib64.astype(out_dtype) - # 6) Scatter into (M × (S+1)), summing duplicates - mat = xp.zeros((M, S + 1), dtype=flat_contrib.dtype) if xp.__name__.startswith("jax"): - mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib) + mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib_out) else: - xp.add.at(mat, (flat_parent, flat_pixidx), flat_contrib) + xp.add.at(mat, (flat_parent, flat_pixidx), flat_contrib_out) - # 7) Drop the extra column and return + # 7) Drop extra column return mat[:, :S] diff --git a/autoarray/operators/mock/mock_psf.py b/autoarray/operators/mock/mock_psf.py index c5231469..7f6d2d03 100644 --- a/autoarray/operators/mock/mock_psf.py +++ b/autoarray/operators/mock/mock_psf.py @@ -5,5 +5,7 @@ class MockPSF: def __init__(self, operated_mapping_matrix=None): self.operated_mapping_matrix = operated_mapping_matrix - def convolved_mapping_matrix_from(self, mapping_matrix, mask, xp=np): + def convolved_mapping_matrix_from( + self, mapping_matrix, mask, use_mixed_precision=False, xp=np + ): return self.operated_mapping_matrix diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 76bbd2b5..53b6080b 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -578,10 +578,16 @@ def mapping_matrix_native_from( mapping_matrix_native = xp.zeros(mask.shape + (n_src,), dtype=dtype_native) # Cast inputs to the target dtype to avoid implicit up/downcasts inside scatter - mm = mapping_matrix if mapping_matrix.dtype == dtype_native else xp.asarray(mapping_matrix, dtype=dtype_native) + mm = ( + mapping_matrix + if mapping_matrix.dtype == dtype_native + else xp.asarray(mapping_matrix, dtype=dtype_native) + ) if xp.__name__.startswith("jax"): - mapping_matrix_native = mapping_matrix_native.at[mask.slim_to_native_tuple].set(mm) + mapping_matrix_native = mapping_matrix_native.at[ + mask.slim_to_native_tuple + ].set(mm) else: mapping_matrix_native[mask.slim_to_native_tuple] = np.asarray(mm) @@ -591,13 +597,24 @@ def mapping_matrix_native_from( bm = xp.asarray(bm, dtype=dtype_native) if xp.__name__.startswith("jax"): - mapping_matrix_native = mapping_matrix_native.at[blurring_mask.slim_to_native_tuple].set(bm) + mapping_matrix_native = mapping_matrix_native.at[ + blurring_mask.slim_to_native_tuple + ].set(bm) else: - mapping_matrix_native[blurring_mask.slim_to_native_tuple] = np.asarray(bm) + mapping_matrix_native[blurring_mask.slim_to_native_tuple] = np.asarray( + bm + ) return mapping_matrix_native - def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np): + def convolved_image_from( + self, + image, + blurring_image, + jax_method="direct", + use_mixed_precision: bool = False, + xp=np, + ): """ Convolve an input masked image with this PSF. @@ -650,41 +667,54 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np ) import jax + import jax.numpy as jnp + import warnings + from autoarray.structures.arrays.uniform_2d import Array2D - if self.fft_shape is None: + # FFT path dtypes (JAX only) + fft_real_dtype = jnp.float32 if use_mixed_precision else jnp.float64 + fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128 + if self.fft_shape is None: + # Shapes computed on the fly full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=image.mask) - fft_psf = xp.fft.rfft2(self.stored_native.array, s=fft_shape, axes=(0, 1)) + + # Compute PSF FFT on the fly in the chosen precision + psf_native = jnp.asarray(self.stored_native.array, dtype=fft_real_dtype) + fft_psf = xp.fft.rfft2(psf_native, s=fft_shape, axes=(0, 1)).astype( + fft_complex_dtype + ) image_shape_original = image.shape_native + # Resize/pad the images as before image = image.resized_from(new_shape=fft_shape, mask_pad_value=1) if blurring_image is not None: blurring_image = blurring_image.resized_from( new_shape=fft_shape, mask_pad_value=1 ) - else: - + # Cached shapes/state fft_shape = self.fft_shape full_shape = self.full_shape mask_shape = self.mask_shape - fft_psf = self.fft_psf - # start with native image padded with zeros - image_both_native = xp.zeros(image.mask.shape, dtype=image.dtype) + # Use cached PSF FFT but ensure it matches chosen precision. + # IMPORTANT: casting here may create an extra buffer if self.fft_psf is complex128. + # Best practice is to cache a complex64 version on the object when MP is enabled. + fft_psf = jnp.asarray(self.fft_psf, dtype=fft_complex_dtype) + + # Build combined native image in the FFT dtype + image_both_native = xp.zeros(image.mask.shape, dtype=fft_real_dtype) image_both_native = image_both_native.at[image.mask.slim_to_native_tuple].set( - image.array + jnp.asarray(image.array, dtype=fft_real_dtype) ) - # add blurring contribution if provided if blurring_image is not None: - image_both_native = image_both_native.at[ blurring_image.mask.slim_to_native_tuple - ].set(blurring_image.array) - + ].set(jnp.asarray(blurring_image.array, dtype=fft_real_dtype)) else: warnings.warn( "No blurring_image provided. Only the direct image will be convolved. " @@ -699,25 +729,21 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np fft_psf * fft_image_native, s=fft_shape, axes=(0, 1) ) - out_shape_full = mask_shape - # Crop back to mask_shape start_indices = tuple( (full_size - out_size) // 2 for full_size, out_size in zip(full_shape, mask_shape) ) - blurred_image_native = jax.lax.dynamic_slice( - blurred_image_full, start_indices, out_shape_full + blurred_image_full, start_indices, mask_shape ) - blurred_image = Array2D( - values=blurred_image_native[image.mask.slim_to_native_tuple], - mask=image.mask, - ) + # Return slim form; optionally cast for downstream stability + blurred_slim = blurred_image_native[image.mask.slim_to_native_tuple] - if self.fft_shape is None: + blurred_image = Array2D(values=blurred_slim, mask=image.mask) + if self.fft_shape is None: blurred_image = blurred_image.resized_from( new_shape=image_shape_original, mask_pad_value=0 ) @@ -830,7 +856,6 @@ def convolved_mapping_matrix_from( # ------------------------------------------------------------------------- # Mixed precision dtypes (JAX only) # ------------------------------------------------------------------------- - fft_real_dtype = jnp.float32 if use_mixed_precision else jnp.float64 fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128 # Ensure PSF FFT dtype matches the FFT path diff --git a/autoarray/structures/decorators/abstract.py b/autoarray/structures/decorators/abstract.py index b01c249e..5e4c86c4 100644 --- a/autoarray/structures/decorators/abstract.py +++ b/autoarray/structures/decorators/abstract.py @@ -61,6 +61,7 @@ def __init__(self, func, obj, grid, xp=np, *args, **kwargs): def _xp(self): if self.use_jax: import jax.numpy as jnp + return jnp return np From cfe1b8953fb45b1ce7ee89281451d328f05f5d6f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 8 Feb 2026 09:10:58 +0000 Subject: [PATCH 06/13] minor --- autoarray/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/__init__.py b/autoarray/__init__.py index 4a81a3b2..4944c874 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -8,7 +8,7 @@ from . import type from . import util from . import fixtures -from . import mock as m +from . import mock as m from .dataset import preprocess from .dataset.abstract.dataset import AbstractDataset From 5d0fab9eb01f39a2e9d9cf8d402793f76615fa46 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sun, 8 Feb 2026 09:20:56 +0000 Subject: [PATCH 07/13] Update autoarray/structures/arrays/kernel_2d.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- autoarray/structures/arrays/kernel_2d.py | 1 - 1 file changed, 1 deletion(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 53b6080b..52f3b97f 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -668,7 +668,6 @@ def convolved_image_from( import jax import jax.numpy as jnp - import warnings from autoarray.structures.arrays.uniform_2d import Array2D # FFT path dtypes (JAX only) From a50e5e69cc1868c7d8d6e68d1ef87c3f24dd2dd9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 8 Feb 2026 09:21:58 +0000 Subject: [PATCH 08/13] Initial plan From 88a2a24b77c2c3991f586261ef04b328d5784f59 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sun, 8 Feb 2026 09:22:25 +0000 Subject: [PATCH 09/13] Update autoarray/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- autoarray/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/__init__.py b/autoarray/__init__.py index 4944c874..4a81a3b2 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -8,7 +8,7 @@ from . import type from . import util from . import fixtures -from . import mock as m +from . import mock as m from .dataset import preprocess from .dataset.abstract.dataset import AbstractDataset From 9898d84aec02a248062b4607b32cf056bd2a2bc9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 8 Feb 2026 09:23:25 +0000 Subject: [PATCH 10/13] Initial plan From 4e78a8c077174ba6f210c203ec04bcbec32db9d1 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sun, 8 Feb 2026 09:23:32 +0000 Subject: [PATCH 11/13] Update autoarray/structures/arrays/kernel_2d.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- autoarray/structures/arrays/kernel_2d.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 52f3b97f..4ccc1845 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -862,7 +862,8 @@ def convolved_mapping_matrix_from( # ------------------------------------------------------------------------- # Build native cube in the FFT dtype (THIS IS THE KEY) - # Requires mapping_matrix_native_from to accept dtype_native kwarg. + # This relies on mapping_matrix_native_from honoring the use_mixed_precision + # kwarg when constructing the native mapping matrix. # ------------------------------------------------------------------------- mapping_matrix_native = self.mapping_matrix_native_from( mapping_matrix=mapping_matrix, From 3ed7e96f73792f271d982dedca248f8b1c39d9c3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 8 Feb 2026 09:26:02 +0000 Subject: [PATCH 12/13] Add settings parameter to MapperDelaunay factory call Co-authored-by: Jammy2211 <23455639+Jammy2211@users.noreply.github.com> --- autoarray/inversion/pixelization/mappers/factory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autoarray/inversion/pixelization/mappers/factory.py b/autoarray/inversion/pixelization/mappers/factory.py index 001a7d72..a7718e15 100644 --- a/autoarray/inversion/pixelization/mappers/factory.py +++ b/autoarray/inversion/pixelization/mappers/factory.py @@ -73,6 +73,7 @@ def mapper_from( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, + settings=settings, preloads=preloads, xp=xp, ) From 29b81188318b55562b418dd567b3a5c8ac48bfe7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 8 Feb 2026 09:26:05 +0000 Subject: [PATCH 13/13] Remove redundant use_mixed_precision parameter, derive from settings Co-authored-by: Jammy2211 <23455639+Jammy2211@users.noreply.github.com> --- autoarray/inversion/inversion/imaging/mapping.py | 1 - autoarray/inversion/inversion/interferometer/mapping.py | 2 ++ autoarray/inversion/inversion/inversion_util.py | 3 +-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index a03b0a3a..00c618eb 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -98,7 +98,6 @@ def curvature_matrix(self): settings=self.settings, add_to_curvature_diag=True, no_regularization_index_list=self.no_regularization_index_list, - use_mixed_precision=self.settings.use_mixed_precision, xp=self._xp, ) diff --git a/autoarray/inversion/inversion/interferometer/mapping.py b/autoarray/inversion/inversion/interferometer/mapping.py index e516bf0d..f1455b93 100644 --- a/autoarray/inversion/inversion/interferometer/mapping.py +++ b/autoarray/inversion/inversion/interferometer/mapping.py @@ -88,12 +88,14 @@ def curvature_matrix(self) -> np.ndarray: real_curvature_matrix = inversion_util.curvature_matrix_via_mapping_matrix_from( mapping_matrix=self.operated_mapping_matrix.real, noise_map=self.noise_map.real, + settings=self.settings, xp=self._xp, ) imag_curvature_matrix = inversion_util.curvature_matrix_via_mapping_matrix_from( mapping_matrix=self.operated_mapping_matrix.imag, noise_map=self.noise_map.imag, + settings=self.settings, xp=self._xp, ) diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index 92f65b06..7d2e6400 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -83,7 +83,6 @@ def curvature_matrix_via_mapping_matrix_from( add_to_curvature_diag: bool = False, no_regularization_index_list: Optional[List] = None, settings: "SettingsInversion" = SettingsInversion(), - use_mixed_precision: bool = False, xp=np, ) -> np.ndarray: """ @@ -105,7 +104,7 @@ def curvature_matrix_via_mapping_matrix_from( else: # Choose compute dtype - compute_dtype = xp.float32 if use_mixed_precision else xp.float64 + compute_dtype = xp.float32 if settings.use_mixed_precision else xp.float64 out_dtype = xp.float64 # always return float64 for downstream stability A = mapping_matrix