
Feature/linalg mixed precision #205

Merged
Jammy2211 merged 15 commits into main from feature/linalg_mixed_precision
Feb 8, 2026
Conversation

@Jammy2211 (Owner) commented Feb 8, 2026

This pull request introduces mixed precision support for linear algebra calculations in the inversion pipeline, allowing targeted computations to run in single precision. This gives significant speed improvements on GPUs and reduced VRAM usage, with minimal impact on accuracy.

The changes propagate the new use_mixed_precision setting throughout the codebase, update relevant functions and classes to accept and utilize this option, and refactor the mapping matrix and curvature matrix calculations to handle mixed precision correctly.

Mixed precision support and propagation:

  • Added the use_mixed_precision option to the SettingsInversion class, with documentation explaining its impact and default behavior.
  • Updated constructors and factories for pixelization mappers and function lists to accept and propagate the settings object, enabling mixed precision to be passed through the inversion pipeline.
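The propagation pattern above can be sketched as follows. This is a simplified, hypothetical sketch; the real SettingsInversion and mapper classes carry many more options, and the `mapping_matrix_dtype` property name is illustrative:

```python
import numpy as np


class SettingsInversion:
    """Simplified stand-in for the real settings class."""

    def __init__(self, use_mixed_precision: bool = False):
        # When True, targeted linear algebra outputs are stored in float32
        # (halving VRAM), while accuracy-critical math stays in float64.
        self.use_mixed_precision = use_mixed_precision


class Mapper:
    """Sketch of a mapper that receives the settings object from its factory."""

    def __init__(self, settings: SettingsInversion = SettingsInversion()):
        self.settings = settings

    @property
    def mapping_matrix_dtype(self):
        # The output dtype is selected from the propagated setting.
        return np.float32 if self.settings.use_mixed_precision else np.float64


mapper = Mapper(settings=SettingsInversion(use_mixed_precision=True))
print(mapper.mapping_matrix_dtype)  # <class 'numpy.float32'>
```

Because the whole settings object is passed down, any future precision-related option travels through the same plumbing without new constructor arguments.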

Mapping matrix and curvature matrix refactoring:

  • Refactored mapping_matrix_from and related functions to support mixed precision, including careful handling of input/output dtypes, upstream math in float64, and destination dtype selection based on use_mixed_precision.
  • Updated curvature matrix calculations to use mixed precision when enabled, including explicit dtype selection and stable float64 output for downstream processing.
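A minimal NumPy sketch of the compute-in-float64, store-in-float32 pattern described above (the signature here is illustrative, not the library's actual mapping_matrix_from API):

```python
import numpy as np


def mapping_matrix_from(pix_indexes, pix_weights, pixels, total_points,
                        use_mixed_precision=False):
    """Accumulate the mapping matrix in float64 for numerical stability,
    then cast to float32 only when mixed precision is requested."""
    # Upstream math always runs in float64, regardless of the output dtype.
    matrix = np.zeros((total_points, pixels), dtype=np.float64)
    for point, (index, weight) in enumerate(zip(pix_indexes, pix_weights)):
        matrix[point, index] += weight
    # Destination dtype is selected from the mixed-precision flag.
    out_dtype = np.float32 if use_mixed_precision else np.float64
    return matrix.astype(out_dtype)


m = mapping_matrix_from([0, 1, 1], [1.0, 0.5, 0.5], pixels=2, total_points=3,
                        use_mixed_precision=True)
print(m.dtype)  # float32
```

Accumulating in float64 and casting only at the end keeps rounding error out of the summation while still halving the memory footprint of the stored matrix.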

Codebase cleanup and logging improvements:

  • Removed unused or redundant methods from inversion/imaging/mapping.py to streamline the codebase and ensure mixed precision is handled in the main property methods.
  • Improved logging messages in imaging and interferometer dataset processing for clarity regarding sparse operator and NUFFT precision operator setup.

Mock and test support:

  • Updated the mock PSF class to accept the use_mixed_precision parameter, ensuring compatibility with the new mixed precision logic in tests.

These changes collectively enable efficient mixed precision computation in the inversion pipeline, improve code clarity, and ensure the new option is consistently propagated and utilized throughout the codebase.

Copilot AI (Contributor) left a comment
Pull request overview

This PR adds an opt-in mixed-precision mode (use_mixed_precision) to the inversion pipeline to reduce GPU memory usage and speed up selected linear algebra operations, propagating the setting through mapper/matrix construction and PSF-operated mapping-matrix paths.

Changes:

  • Introduces use_mixed_precision on SettingsInversion and propagates it through mapper creation and operated mapping-matrix generation.
  • Refactors mapping-matrix construction / scattering and Kernel2D FFT convolution paths to support float32 outputs when enabled.
  • Updates curvature-matrix computation to optionally compute in float32 (JAX path) while returning a stable dtype for downstream use; includes minor logging tweaks and test adjustments.
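The float32-compute, stable-dtype-return pattern in the last bullet can be sketched in NumPy (function and variable names here are illustrative, not the library's API; the real code also handles the JAX path):

```python
import numpy as np


def curvature_matrix_sketch(mapping_matrix, noise_map, use_mixed_precision=False):
    """Compute F = A^T N^-1 A, optionally running the dominant product in
    float32, but always returning float64 for downstream solves."""
    compute_dtype = np.float32 if use_mixed_precision else np.float64
    A = mapping_matrix.astype(compute_dtype)
    w = (1.0 / noise_map**2).astype(compute_dtype)  # inverse noise variance
    # (A.T * w) scales each column of A.T by w, i.e. A^T diag(w);
    # the matmul is the dominant cost and runs in compute_dtype.
    curvature = (A.T * w) @ A
    # Return a stable dtype so downstream linear solves are unaffected.
    return curvature.astype(np.float64)
```

Returning float64 regardless of the compute dtype means callers never need to branch on the precision setting.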

Reviewed changes

Copilot reviewed 13 out of 14 changed files in this pull request and generated 7 comments.

Summary per file:

  • autoarray/inversion/inversion/settings.py: Adds the use_mixed_precision setting and documents intended behavior.
  • autoarray/inversion/pixelization/mappers/abstract.py: Stores settings on mappers and uses it to select the mapping-matrix output dtype.
  • autoarray/inversion/pixelization/mappers/factory.py: Passes settings into some mapper constructors (but currently misses Delaunay).
  • autoarray/inversion/pixelization/mappers/mapper_util.py: Adds a mixed-precision option to mapping-matrix construction (float64 compute, float32 output).
  • autoarray/structures/arrays/kernel_2d.py: Adds mixed-precision support to FFT-based convolution for images/mapping matrices.
  • autoarray/inversion/inversion/imaging/abstract.py: Forwards mixed precision into PSF-operated mapping-matrix generation (some internal paths not yet updated).
  • autoarray/inversion/inversion/imaging/mapping.py: Wires the mixed-precision option into the curvature-matrix computation call.
  • autoarray/inversion/inversion/inversion_util.py: Adds mixed-precision controls to curvature-matrix computation (needs API/dtype fixes).
  • autoarray/operators/mock/mock_psf.py: Updates the mock PSF API to accept use_mixed_precision.
  • autoarray/dataset/imaging/dataset.py, autoarray/dataset/interferometer/dataset.py: Logging message tweaks around sparse operator / NUFFT operator setup.
  • test_autoarray/inversion/inversion/test_factory.py: Adjusts tests to reflect refactors/removals in curvature-diagonal internals.
  • autoarray/__init__.py: Minor import whitespace change (currently adds trailing whitespace).


Comment on lines 95 to 100
self.psf.convolved_mapping_matrix_from(
mapping_matrix=linear_obj.mapping_matrix,
mask=self.mask,
use_mixed_precision=self.settings.use_mixed_precision,
xp=self._xp,
)
Copilot AI commented Feb 8, 2026:

Mixed precision is forwarded for operated_mapping_matrix_list, but other code paths in this class still call psf.convolved_mapping_matrix_from(...) without use_mixed_precision (e.g. linear_func_operated_mapping_matrix_dict and mapper_operated_mapping_matrix_dict). This can lead to inconsistent dtypes/performance depending on which path is exercised (notably in sparse inversion code that uses those dicts). Ensure all internal PSF-convolution call sites in this class pass use_mixed_precision=self.settings.use_mixed_precision.

Comment on lines 80 to 88

  def curvature_matrix_via_mapping_matrix_from(
-     mapping_matrix: np.ndarray,
-     noise_map: np.ndarray,
+     mapping_matrix: "np.ndarray",
+     noise_map: "np.ndarray",
      add_to_curvature_diag: bool = False,
      no_regularization_index_list: Optional[List] = None,
-     settings: SettingsInversion = SettingsInversion(),
+     settings: "SettingsInversion" = SettingsInversion(),
+     use_mixed_precision: bool = False,
      xp=np,
  ) -> np.ndarray:
Copilot AI commented Feb 8, 2026:

curvature_matrix_via_mapping_matrix_from now takes both settings and a separate use_mixed_precision flag. This makes it easy for call sites to pass settings but forget use_mixed_precision, silently disabling mixed precision. Consider deriving mixed-precision behavior from settings.use_mixed_precision inside this function (or make use_mixed_precision: Optional[bool]=None and default it from settings) to avoid partial propagation bugs.
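The suggested default-from-settings pattern might look like this (`Settings` is a stand-in for SettingsInversion, and `resolve_mixed_precision` is a hypothetical helper, not part of the library):

```python
from typing import Optional


class Settings:
    """Stand-in for SettingsInversion in this sketch."""

    def __init__(self, use_mixed_precision: bool = False):
        self.use_mixed_precision = use_mixed_precision


def resolve_mixed_precision(
    settings: Settings, use_mixed_precision: Optional[bool] = None
) -> bool:
    # An explicit argument wins; otherwise fall back to the settings object,
    # so call sites that pass only `settings` cannot silently disable
    # mixed precision.
    if use_mixed_precision is None:
        return settings.use_mixed_precision
    return use_mixed_precision


print(resolve_mixed_precision(Settings(use_mixed_precision=True)))         # True
print(resolve_mixed_precision(Settings(use_mixed_precision=True), False))  # False
```

Defaulting the flag to `None` rather than `False` is what distinguishes "caller didn't say" from "caller explicitly opted out".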

Jammy2211 (Owner, Author) replied:
@copilot open a new pull request to apply changes based on this feedback

Comment on lines +111 to +112
A = mapping_matrix
w = (1.0 / noise_map).astype(compute_dtype)
Copilot AI commented Feb 8, 2026:

In the JAX/non-NumPy path, compute_dtype is selected but mapping_matrix is not explicitly cast to it (A = mapping_matrix). If mapping_matrix arrives as float64, operations will upcast and defeat mixed precision. Cast mapping_matrix (and ideally noise_map) to compute_dtype before forming A so the dot product actually runs in the intended precision.

Suggested change:

-     A = mapping_matrix
-     w = (1.0 / noise_map).astype(compute_dtype)
+     A = mapping_matrix.astype(compute_dtype)
+     noise = noise_map.astype(compute_dtype)
+     w = 1.0 / noise

Comment on lines 64 to +68

      mapper_grids=mapper_grids,
      border_relocator=border_relocator,
      regularization=regularization,
+     settings=settings,
      preloads=preloads,
Copilot AI commented Feb 8, 2026:
This factory now propagates settings into the Rectangular / RectangularUniform mapper constructors, but the Delaunay branch below still does not pass settings. That means caller-provided inversion settings (including use_mixed_precision) are ignored for MapperDelaunay. Update the Delaunay branch to pass settings=settings for consistent behavior across mapper types.

Jammy2211 (Owner, Author) replied:
@copilot open a new pull request to apply changes based on this feedback

Jammy2211 and others added 2 commits February 8, 2026 09:20
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

Copilot AI commented Feb 8, 2026

@Jammy2211 I've opened a new pull request, #206, to work on those changes. Once the pull request is ready, I'll request review from you.

Jammy2211 and others added 2 commits February 8, 2026 09:22
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

Copilot AI commented Feb 8, 2026

@Jammy2211 I've opened a new pull request, #207, to work on those changes. Once the pull request is ready, I'll request review from you.

Jammy2211 and others added 5 commits February 8, 2026 09:23
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Jammy2211 <23455639+Jammy2211@users.noreply.github.com>
Co-authored-by: Jammy2211 <23455639+Jammy2211@users.noreply.github.com>
Pass settings to MapperDelaunay in factory
Derive mixed precision from settings object to prevent propagation bugs
@Jammy2211 Jammy2211 merged commit 7f921f1 into main Feb 8, 2026
8 checks passed