diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 401990c9..8ef91049 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -504,6 +504,10 @@ def apply_sparse_operator( Whether to use JAX to compute W-Tilde. This requires JAX to be installed. """ + logger.info( + "IMAGING - Setting Up Sparse Operator For low Memory Pixelizations." + ) + sparse_operator = ( inversion_imaging_util.ImagingSparseOperator.from_noise_map_and_psf( data=self.data, diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index 922aff96..7e72ddcd 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -198,7 +198,7 @@ def apply_sparse_operator( if nufft_precision_operator is None: logger.info( - "INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, CPU run times may exceed hours." + "INTERFEROMETER - Computing NUFFT Precision Operator; runtime scales with visibility count and mask resolution, CPU run times may exceed hours." ) nufft_precision_operator = self.psf_precision_operator_from( diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index b863c1c6..ab0c7918 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -95,6 +95,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: self.psf.convolved_mapping_matrix_from( mapping_matrix=linear_obj.mapping_matrix, mask=self.mask, + use_mixed_precision=self.settings.use_mixed_precision, xp=self._xp, ) if linear_obj.operated_mapping_matrix_override is None diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index 7606e961..00c618eb 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -53,44 +53,6 @@ def __init__( xp=xp, ) - @property - def _data_vector_mapper(self) -> np.ndarray: - """ - Returns the `data_vector` of all mappers, a 1D vector whose values are solved for by the simultaneous - linear equations constructed by this object. The object is described in full in the method `data_vector`. - - This method is used to compute part of the `data_vector` if there are also linear function list objects - in the inversion, and is separated into a separate method to enable preloading of the mapper `data_vector`. - """ - - if not self.has(cls=AbstractMapper): - return - - data_vector = np.zeros(self.total_params) - - mapper_list = self.cls_list_from(cls=AbstractMapper) - mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper) - - for i in range(len(mapper_list)): - mapper = mapper_list[i] - param_range = mapper_param_range_list[i] - - operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( - mapping_matrix=mapper.mapping_matrix, mask=self.mask, xp=self._xp - ) - - data_vector_mapper = ( - inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=operated_mapping_matrix, - image=self.data, - noise_map=self.noise_map, - ) - ) - - data_vector[param_range[0] : param_range[1],] = data_vector_mapper - - return data_vector - @cached_property def data_vector(self) -> np.ndarray: """ @@ -111,53 +73,6 @@ def data_vector(self) -> np.ndarray: noise_map=self.noise_map.array, ) - @property - def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: - """ - Returns the diagonal regions of the `curvature_matrix`, a 2D matrix which uses the mappings between the data - and the linear objects to construct the simultaneous linear equations. The object is described in full in - the method `curvature_matrix`. - - This method computes the diagonal entries of all mapper objects in the `curvature_matrix`. It is separate from - other calculations to enable preloading of this calculation. - """ - - if not self.has(cls=AbstractMapper): - return None - - curvature_matrix = np.zeros((self.total_params, self.total_params)) - - mapper_list = self.cls_list_from(cls=AbstractMapper) - mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper) - - for i in range(len(mapper_list)): - mapper_i = mapper_list[i] - mapper_param_range_i = mapper_param_range_list[i] - - operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( - mapping_matrix=mapper_i.mapping_matrix, mask=self.mask, xp=self._xp - ) - - diag = inversion_util.curvature_matrix_via_mapping_matrix_from( - mapping_matrix=operated_mapping_matrix, - noise_map=self.noise_map, - settings=self.settings, - add_to_curvature_diag=True, - no_regularization_index_list=self.no_regularization_index_list, - xp=self._xp, - ) - - curvature_matrix[ - mapper_param_range_i[0] : mapper_param_range_i[1], - mapper_param_range_i[0] : mapper_param_range_i[1], - ] = diag - - curvature_matrix = inversion_util.curvature_matrix_mirrored_from( - curvature_matrix=curvature_matrix, xp=self._xp - ) - - return curvature_matrix - @cached_property def curvature_matrix(self): """ diff --git a/autoarray/inversion/inversion/interferometer/mapping.py b/autoarray/inversion/inversion/interferometer/mapping.py index e516bf0d..f1455b93 100644 --- a/autoarray/inversion/inversion/interferometer/mapping.py +++ b/autoarray/inversion/inversion/interferometer/mapping.py @@ -88,12 +88,14 @@ def curvature_matrix(self) -> np.ndarray: real_curvature_matrix = inversion_util.curvature_matrix_via_mapping_matrix_from( mapping_matrix=self.operated_mapping_matrix.real, noise_map=self.noise_map.real, + settings=self.settings, xp=self._xp, ) imag_curvature_matrix = inversion_util.curvature_matrix_via_mapping_matrix_from( mapping_matrix=self.operated_mapping_matrix.imag, noise_map=self.noise_map.imag, + settings=self.settings, xp=self._xp, ) diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index 854c331b..7d2e6400 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -78,11 +78,11 @@ def curvature_matrix_mirrored_from(curvature_matrix: np.ndarray, xp=np) -> np.nd def curvature_matrix_via_mapping_matrix_from( - mapping_matrix: np.ndarray, - noise_map: np.ndarray, + mapping_matrix: "np.ndarray", + noise_map: "np.ndarray", add_to_curvature_diag: bool = False, no_regularization_index_list: Optional[List] = None, - settings: SettingsInversion = SettingsInversion(), + settings: "SettingsInversion" = SettingsInversion(), xp=np, ) -> np.ndarray: """ @@ -97,10 +97,26 @@ def curvature_matrix_via_mapping_matrix_from( noise_map Flattened 1D array of the noise-map used by the inversion during the fit. """ - array = mapping_matrix / noise_map[:, None] - curvature_matrix = xp.dot(array.T, array) - - if add_to_curvature_diag and len(no_regularization_index_list) > 0: + # NumPy path: keep it simple + stable + if xp is np: + A = mapping_matrix / noise_map[:, None] + curvature_matrix = xp.dot(A.T, A) + else: + # Choose compute dtype + + compute_dtype = xp.float32 if settings.use_mixed_precision else xp.float64 + out_dtype = xp.float64 # always return float64 for downstream stability + + A = mapping_matrix + w = (1.0 / noise_map).astype(compute_dtype) + A = A * w[:, None] + curvature_matrix = xp.dot(A.T, A).astype(out_dtype) + + if ( + add_to_curvature_diag + and no_regularization_index_list + and len(no_regularization_index_list) > 0 + ): curvature_matrix = curvature_matrix_with_added_to_diag_from( curvature_matrix=curvature_matrix, value=settings.no_regularization_add_to_curvature_diag_value, diff --git a/autoarray/inversion/inversion/settings.py b/autoarray/inversion/inversion/settings.py index 73494226..2324c949 100644 --- a/autoarray/inversion/inversion/settings.py +++ b/autoarray/inversion/inversion/settings.py @@ -10,6 +10,7 @@ class SettingsInversion: def __init__( self, + use_mixed_precision: bool = False, use_positive_only_solver: Optional[bool] = None, positive_only_uses_p_initial: Optional[bool] = None, use_border_relocator: Optional[bool] = None, @@ -24,6 +25,12 @@ def __init__( Parameters ---------- + use_mixed_precision + If `True`, the linear algebra calculations of the inversion are performed using single precision on a + targeted subset of functions which provide significant speed up when using a GPU (x4), reduces VRAM + use and are expected to have minimal impact on the accuracy of the results. If `False`, all linear algebra + calculations are performed using double precision, which is the default and is more accurate but + slower on a GPU. use_positive_only_solver Whether to use a positive-only linear system solver, which requires that every reconstructed value is positive but is computationally much slower than the default solver (which allows for positive and @@ -41,6 +48,7 @@ def __init__( For an interferometer inversion using the linear operators method, sets the maximum number of iterations of the solver (this input does nothing for dataset data and other interferometer methods). """ + self.use_mixed_precision = use_mixed_precision self._use_positive_only_solver = use_positive_only_solver self._positive_only_uses_p_initial = positive_only_uses_p_initial self._use_border_relocator = use_border_relocator diff --git a/autoarray/inversion/linear_obj/func_list.py b/autoarray/inversion/linear_obj/func_list.py index b329a0c2..82393835 100644 --- a/autoarray/inversion/linear_obj/func_list.py +++ b/autoarray/inversion/linear_obj/func_list.py @@ -7,6 +7,7 @@ from autoarray.inversion.linear_obj.neighbors import Neighbors from autoarray.inversion.linear_obj.unique_mappings import UniqueMappings from autoarray.inversion.regularization.abstract import AbstractRegularization +from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.type import Grid1D2DLike @@ -15,6 +16,7 @@ def __init__( self, grid: Grid1D2DLike, regularization: Optional[AbstractRegularization], + settings=SettingsInversion(), xp=np, ): """ @@ -45,6 +47,7 @@ def __init__( super().__init__(regularization=regularization, xp=xp) self.grid = grid + self.settings = settings @cached_property def neighbors(self) -> Neighbors: diff --git a/autoarray/inversion/pixelization/mappers/abstract.py b/autoarray/inversion/pixelization/mappers/abstract.py index 8c8ed067..7a29c61a 100644 --- a/autoarray/inversion/pixelization/mappers/abstract.py +++ b/autoarray/inversion/pixelization/mappers/abstract.py @@ -11,6 +11,7 @@ from autoarray.inversion.pixelization.border_relocator import BorderRelocator from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids from autoarray.inversion.regularization.abstract import AbstractRegularization +from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.mesh.abstract_2d import Abstract2DMesh @@ -25,6 +26,7 @@ def __init__( mapper_grids: MapperGrids, regularization: Optional[AbstractRegularization], border_relocator: BorderRelocator, + settings: SettingsInversion = SettingsInversion(), preloads=None, xp=np, ): @@ -90,6 +92,7 @@ def __init__( self.border_relocator = border_relocator self.mapper_grids = mapper_grids self.preloads = preloads + self.settings = settings @property def params(self) -> int: @@ -265,6 +268,7 @@ def mapping_matrix(self) -> np.ndarray: total_mask_pixels=self.over_sampler.mask.pixels_in_mask, slim_index_for_sub_slim_index=self.slim_index_for_sub_slim_index, sub_fraction=self.over_sampler.sub_fraction.array, + use_mixed_precision=self.settings.use_mixed_precision, xp=self._xp, ) diff --git a/autoarray/inversion/pixelization/mappers/factory.py b/autoarray/inversion/pixelization/mappers/factory.py index 94ae0d3e..a7718e15 100644 --- a/autoarray/inversion/pixelization/mappers/factory.py +++ b/autoarray/inversion/pixelization/mappers/factory.py @@ -4,6 +4,7 @@ from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids from autoarray.inversion.pixelization.border_relocator import BorderRelocator from autoarray.inversion.regularization.abstract import AbstractRegularization +from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay @@ -13,6 +14,7 @@ def mapper_from( mapper_grids: MapperGrids, regularization: Optional[AbstractRegularization], border_relocator: Optional[BorderRelocator] = None, + settings=SettingsInversion(), preloads=None, xp=np, ): @@ -53,6 +55,8 @@ def mapper_from( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, + settings=settings, + preloads=preloads, xp=xp, ) elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular): @@ -60,6 +64,8 @@ def mapper_from( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, + settings=settings, + preloads=preloads, xp=xp, ) elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunay): @@ -67,6 +73,7 @@ def mapper_from( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, + settings=settings, preloads=preloads, xp=xp, ) diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index 4474335c..385f2190 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -548,6 +548,7 @@ def mapping_matrix_from( total_mask_pixels: int, slim_index_for_sub_slim_index: np.ndarray, sub_fraction: np.ndarray, + use_mixed_precision: bool = False, xp=np, ) -> np.ndarray: """ @@ -621,39 +622,56 @@ def mapping_matrix_from( sub_fraction The fractional area each sub-pixel takes up in an pixel. """ + M_sub, B = pix_indexes_for_sub_slim_index.shape - M = total_mask_pixels - S = pixels + M = int(total_mask_pixels) + S = int(pixels) + + # Indices always int32 + pix_idx = xp.asarray(pix_indexes_for_sub_slim_index, dtype=xp.int32) + pix_size = xp.asarray(pix_size_for_sub_slim_index, dtype=xp.int32) + slim_parent = xp.asarray(slim_index_for_sub_slim_index, dtype=xp.int32) + + # Everything else computed in float64 + w64 = xp.asarray(pix_weights_for_sub_slim_index, dtype=xp.float64) + frac64 = xp.asarray(sub_fraction, dtype=xp.float64) + + # Output dtype only (big allocation) + out_dtype = xp.float32 if use_mixed_precision else xp.float64 # 1) Flatten - flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,) - flat_w = pix_weights_for_sub_slim_index.reshape(-1) # (M_sub*B,) - flat_parent = xp.repeat(slim_index_for_sub_slim_index, B) # (M_sub*B,) - flat_count = xp.repeat(pix_size_for_sub_slim_index, B) # (M_sub*B,) + flat_pixidx = pix_idx.reshape(-1) # (M_sub*B,) + flat_w = w64.reshape(-1) # float64 + flat_parent = xp.repeat(slim_parent, B) # int32 + flat_count = xp.repeat(pix_size, B) # int32 - # 2) Build valid mask: k < pix_size[i] - k = xp.tile(xp.arange(B), M_sub) # (M_sub*B,) - valid = k < flat_count # (M_sub*B,) + # 2) valid mask: k < pix_size[i] + k = xp.tile(xp.arange(B, dtype=xp.int32), M_sub) + valid = k < flat_count - # 3) Zero out invalid weights - flat_w = flat_w * valid.astype(flat_w.dtype) + # 3) Zero out invalid weights (float64) + flat_w = flat_w * valid.astype(xp.float64) # 4) Redirect -1 indices to extra bin S OUT = S flat_pixidx = xp.where(flat_pixidx < 0, OUT, flat_pixidx) - # 5) Multiply by sub_fraction of the slim row - flat_frac = xp.take(sub_fraction, flat_parent, axis=0) # (M_sub*B,) - flat_contrib = flat_w * flat_frac # (M_sub*B,) + # 5) Multiply by sub_fraction of the slim row (float64) + flat_frac = xp.take(frac64, flat_parent, axis=0) + flat_contrib64 = flat_w * flat_frac + + # 6) Scatter into (M × (S+1)) (destination float32 or float64) + mat = xp.zeros((M, S + 1), dtype=out_dtype) + + # Cast only at the write (keeps upstream math float64) + flat_contrib_out = flat_contrib64.astype(out_dtype) - # 6) Scatter into (M × (S+1)), summing duplicates - mat = xp.zeros((M, S + 1), dtype=flat_contrib.dtype) if xp.__name__.startswith("jax"): - mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib) + mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib_out) else: - xp.add.at(mat, (flat_parent, flat_pixidx), flat_contrib) + xp.add.at(mat, (flat_parent, flat_pixidx), flat_contrib_out) - # 7) Drop the extra column and return + # 7) Drop extra column return mat[:, :S] diff --git a/autoarray/operators/mock/mock_psf.py b/autoarray/operators/mock/mock_psf.py index c5231469..7f6d2d03 100644 --- a/autoarray/operators/mock/mock_psf.py +++ b/autoarray/operators/mock/mock_psf.py @@ -5,5 +5,7 @@ class MockPSF: def __init__(self, operated_mapping_matrix=None): self.operated_mapping_matrix = operated_mapping_matrix - def convolved_mapping_matrix_from(self, mapping_matrix, mask, xp=np): + def convolved_mapping_matrix_from( + self, mapping_matrix, mask, use_mixed_precision=False, xp=np + ): return self.operated_mapping_matrix diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index abc74dc3..4ccc1845 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -533,6 +533,7 @@ def mapping_matrix_native_from( mask: "Mask2D", blurring_mapping_matrix: Optional[np.ndarray] = None, blurring_mask: Optional["Mask2D"] = None, + use_mixed_precision: bool = False, xp=np, ) -> np.ndarray: """ @@ -558,6 +559,10 @@ def mapping_matrix_native_from( Mask defining the blurring region pixels. Must be provided if `blurring_mapping_matrix` is given and `slim_to_native_blurring_tuple` is not already cached. + use_mixed_precision + If True, the mapping matrices are cast to single precision (float32) to + speed up GPU computations and reduce VRAM usage. If False, double precision + (float64) is used for maximum accuracy. Returns ------- @@ -566,37 +571,50 @@ def mapping_matrix_native_from( Contains contributions from both the main mapping matrix and, if provided, the blurring mapping matrix. """ + dtype_native = xp.float32 if use_mixed_precision else xp.float64 + n_src = mapping_matrix.shape[1] - # Allocate full native grid (ny, nx, n_src) - mapping_matrix_native = xp.zeros( - mask.shape + (n_src,), dtype=mapping_matrix.dtype + mapping_matrix_native = xp.zeros(mask.shape + (n_src,), dtype=dtype_native) + + # Cast inputs to the target dtype to avoid implicit up/downcasts inside scatter + mm = ( + mapping_matrix + if mapping_matrix.dtype == dtype_native + else xp.asarray(mapping_matrix, dtype=dtype_native) ) - # Scatter main mapping matrix into native cube if xp.__name__.startswith("jax"): mapping_matrix_native = mapping_matrix_native.at[ mask.slim_to_native_tuple - ].set(mapping_matrix) + ].set(mm) else: - mapping_matrix_native[mask.slim_to_native_tuple] = mapping_matrix - - # Optionally scatter blurring mapping matrix + mapping_matrix_native[mask.slim_to_native_tuple] = np.asarray(mm) if blurring_mapping_matrix is not None: + bm = blurring_mapping_matrix + if getattr(bm, "dtype", None) != dtype_native: + bm = xp.asarray(bm, dtype=dtype_native) if xp.__name__.startswith("jax"): mapping_matrix_native = mapping_matrix_native.at[ blurring_mask.slim_to_native_tuple - ].set(blurring_mapping_matrix) + ].set(bm) else: - mapping_matrix_native[blurring_mask.slim_to_native_tuple] = ( - blurring_mapping_matrix + mapping_matrix_native[blurring_mask.slim_to_native_tuple] = np.asarray( + bm ) return mapping_matrix_native - def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np): + def convolved_image_from( + self, + image, + blurring_image, + jax_method="direct", + use_mixed_precision: bool = False, + xp=np, + ): """ Convolve an input masked image with this PSF. @@ -649,41 +667,53 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np ) import jax + import jax.numpy as jnp + from autoarray.structures.arrays.uniform_2d import Array2D - if self.fft_shape is None: + # FFT path dtypes (JAX only) + fft_real_dtype = jnp.float32 if use_mixed_precision else jnp.float64 + fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128 + if self.fft_shape is None: + # Shapes computed on the fly full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=image.mask) - fft_psf = xp.fft.rfft2(self.stored_native.array, s=fft_shape, axes=(0, 1)) + + # Compute PSF FFT on the fly in the chosen precision + psf_native = jnp.asarray(self.stored_native.array, dtype=fft_real_dtype) + fft_psf = xp.fft.rfft2(psf_native, s=fft_shape, axes=(0, 1)).astype( + fft_complex_dtype + ) image_shape_original = image.shape_native + # Resize/pad the images as before image = image.resized_from(new_shape=fft_shape, mask_pad_value=1) if blurring_image is not None: blurring_image = blurring_image.resized_from( new_shape=fft_shape, mask_pad_value=1 ) - else: - + # Cached shapes/state fft_shape = self.fft_shape full_shape = self.full_shape mask_shape = self.mask_shape - fft_psf = self.fft_psf - # start with native image padded with zeros - image_both_native = xp.zeros(image.mask.shape, dtype=image.dtype) + # Use cached PSF FFT but ensure it matches chosen precision. + # IMPORTANT: casting here may create an extra buffer if self.fft_psf is complex128. + # Best practice is to cache a complex64 version on the object when MP is enabled. + fft_psf = jnp.asarray(self.fft_psf, dtype=fft_complex_dtype) + + # Build combined native image in the FFT dtype + image_both_native = xp.zeros(image.mask.shape, dtype=fft_real_dtype) image_both_native = image_both_native.at[image.mask.slim_to_native_tuple].set( - image.array + jnp.asarray(image.array, dtype=fft_real_dtype) ) - # add blurring contribution if provided if blurring_image is not None: - image_both_native = image_both_native.at[ blurring_image.mask.slim_to_native_tuple - ].set(blurring_image.array) - + ].set(jnp.asarray(blurring_image.array, dtype=fft_real_dtype)) else: warnings.warn( "No blurring_image provided. Only the direct image will be convolved. " @@ -698,25 +728,21 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np fft_psf * fft_image_native, s=fft_shape, axes=(0, 1) ) - out_shape_full = mask_shape - # Crop back to mask_shape start_indices = tuple( (full_size - out_size) // 2 for full_size, out_size in zip(full_shape, mask_shape) ) - blurred_image_native = jax.lax.dynamic_slice( - blurred_image_full, start_indices, out_shape_full + blurred_image_full, start_indices, mask_shape ) - blurred_image = Array2D( - values=blurred_image_native[image.mask.slim_to_native_tuple], - mask=image.mask, - ) + # Return slim form; optionally cast for downstream stability + blurred_slim = blurred_image_native[image.mask.slim_to_native_tuple] - if self.fft_shape is None: + blurred_image = Array2D(values=blurred_slim, mask=image.mask) + if self.fft_shape is None: blurred_image = blurred_image.resized_from( new_shape=image_shape_original, mask_pad_value=0 ) @@ -730,6 +756,7 @@ def convolved_mapping_matrix_from( blurring_mapping_matrix=None, blurring_mask: Optional[Mask2D] = None, jax_method="direct", + use_mixed_precision: bool = False, xp=np, ): """ @@ -770,12 +797,19 @@ def convolved_mapping_matrix_from( Mapping matrix for the blurring region, outside the mask core. jax_method : str Backend passed to real-space convolution if ``use_fft=False``. + use_mixed_precision + If `True`, the FFT is performed using single precision, which provide significant speed up when using a + GPU (x4), reduces VRAM use and is expected to have minimal impact on the accuracy of the results. If `False`, + the FFT is performed using double precision, which is the default and is more accurate but slower on a GPU. Returns ------- ndarray of shape (N_pix, N_src) Convolved mapping matrix in slim form. """ + # ------------------------------------------------------------------------- + # NumPy path unchanged + # ------------------------------------------------------------------------- if xp is np: return self.convolved_mapping_matrix_via_real_space_np_from( mapping_matrix=mapping_matrix, @@ -785,6 +819,9 @@ def convolved_mapping_matrix_from( xp=xp, ) + # ------------------------------------------------------------------------- + # Non-FFT JAX path unchanged + # ------------------------------------------------------------------------- if not self.use_fft: return self.convolved_mapping_matrix_via_real_space_from( mapping_matrix=mapping_matrix, @@ -796,34 +833,50 @@ def convolved_mapping_matrix_from( ) import jax + import jax.numpy as jnp + # ------------------------------------------------------------------------- + # Validate cached FFT shapes / state + # ------------------------------------------------------------------------- if self.fft_shape is None: - full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=mask) - raise ValueError( f"FFT convolution requires precomputed padded shapes, but `self.fft_shape` is None.\n" f"Expected mapping matrix padded to match FFT shape of PSF.\n" f"PSF fft_shape: {fft_shape}, mask shape: {mask.shape}, " f"mapping_matrix shape: {getattr(mapping_matrix, 'shape', 'unknown')}." ) - else: - fft_shape = self.fft_shape full_shape = self.full_shape mask_shape = self.mask_shape fft_psf_mapping = self.fft_psf_mapping + # ------------------------------------------------------------------------- + # Mixed precision dtypes (JAX only) + # ------------------------------------------------------------------------- + fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128 + + # Ensure PSF FFT dtype matches the FFT path + fft_psf_mapping = jnp.asarray(fft_psf_mapping, dtype=fft_complex_dtype) + + # ------------------------------------------------------------------------- + # Build native cube in the FFT dtype (THIS IS THE KEY) + # This relies on mapping_matrix_native_from honoring the use_mixed_precision + # kwarg when constructing the native mapping matrix. + # ------------------------------------------------------------------------- mapping_matrix_native = self.mapping_matrix_native_from( mapping_matrix=mapping_matrix, mask=mask, blurring_mapping_matrix=blurring_mapping_matrix, blurring_mask=blurring_mask, + use_mixed_precision=use_mixed_precision, xp=xp, ) + # ------------------------------------------------------------------------- # FFT convolution + # ------------------------------------------------------------------------- fft_mapping_matrix_native = xp.fft.rfft2( mapping_matrix_native, s=fft_shape, axes=(0, 1) ) @@ -833,7 +886,9 @@ def convolved_mapping_matrix_from( axes=(0, 1), ) - # crop back + # ------------------------------------------------------------------------- + # Crop back to mask-shape + # ------------------------------------------------------------------------- start_indices = tuple( (full_size - out_size) // 2 for full_size, out_size in zip(full_shape, mask_shape) @@ -846,8 +901,10 @@ def convolved_mapping_matrix_from( out_shape_full, ) - # return slim form - return blurred_mapping_matrix_native[mask.slim_to_native_tuple] + # Return slim form + blurred_slim = blurred_mapping_matrix_native[mask.slim_to_native_tuple] + + return blurred_slim def rescaled_with_odd_dimensions_from( self, rescale_factor: float, normalize: bool = False diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index d0003e66..e45e6e41 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -290,9 +290,6 @@ def test__inversion_imaging__compare_mapping_and_sparse_operator_values( settings=aa.SettingsInversion(), ) - assert inversion_sparse_operator._curvature_matrix_mapper_diag == pytest.approx( - inversion_mapping._curvature_matrix_mapper_diag, 1.0e-4 - ) assert inversion_sparse_operator.reconstruction == pytest.approx( inversion_mapping.reconstruction, 1.0e-4 )