-
Notifications
You must be signed in to change notification settings - Fork 7
Feature/linalg mixed precision #205
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d4d1994
e15260d
456905c
411ded5
92809b2
cfe1b89
5d0fab9
a50e5e6
88a2a24
9898d84
4e78a8c
3ed7e96
29b8118
6b20761
a78b864
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -78,11 +78,11 @@ def curvature_matrix_mirrored_from(curvature_matrix: np.ndarray, xp=np) -> np.nd | |||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def curvature_matrix_via_mapping_matrix_from( | ||||||||||||
| mapping_matrix: np.ndarray, | ||||||||||||
| noise_map: np.ndarray, | ||||||||||||
| mapping_matrix: "np.ndarray", | ||||||||||||
| noise_map: "np.ndarray", | ||||||||||||
| add_to_curvature_diag: bool = False, | ||||||||||||
| no_regularization_index_list: Optional[List] = None, | ||||||||||||
| settings: SettingsInversion = SettingsInversion(), | ||||||||||||
| settings: "SettingsInversion" = SettingsInversion(), | ||||||||||||
| xp=np, | ||||||||||||
| ) -> np.ndarray: | ||||||||||||
|
Comment on lines
80
to
87
|
||||||||||||
| """ | ||||||||||||
|
|
@@ -97,10 +97,26 @@ def curvature_matrix_via_mapping_matrix_from( | |||||||||||
| noise_map | ||||||||||||
| Flattened 1D array of the noise-map used by the inversion during the fit. | ||||||||||||
| """ | ||||||||||||
| array = mapping_matrix / noise_map[:, None] | ||||||||||||
| curvature_matrix = xp.dot(array.T, array) | ||||||||||||
|
|
||||||||||||
| if add_to_curvature_diag and len(no_regularization_index_list) > 0: | ||||||||||||
| # NumPy path: keep it simple + stable | ||||||||||||
| if xp is np: | ||||||||||||
| A = mapping_matrix / noise_map[:, None] | ||||||||||||
| curvature_matrix = xp.dot(A.T, A) | ||||||||||||
| else: | ||||||||||||
| # Choose compute dtype | ||||||||||||
|
|
||||||||||||
| compute_dtype = xp.float32 if settings.use_mixed_precision else xp.float64 | ||||||||||||
| out_dtype = xp.float64 # always return float64 for downstream stability | ||||||||||||
|
|
||||||||||||
| A = mapping_matrix | ||||||||||||
| w = (1.0 / noise_map).astype(compute_dtype) | ||||||||||||
|
Comment on lines
+110
to
+111
|
||||||||||||
| A = mapping_matrix | |
| w = (1.0 / noise_map).astype(compute_dtype) | |
| A = mapping_matrix.astype(compute_dtype) | |
| noise = noise_map.astype(compute_dtype) | |
| w = 1.0 / noise |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids | ||
| from autoarray.inversion.pixelization.border_relocator import BorderRelocator | ||
| from autoarray.inversion.regularization.abstract import AbstractRegularization | ||
| from autoarray.inversion.inversion.settings import SettingsInversion | ||
| from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular | ||
| from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform | ||
| from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay | ||
|
|
@@ -13,6 +14,7 @@ def mapper_from( | |
| mapper_grids: MapperGrids, | ||
| regularization: Optional[AbstractRegularization], | ||
| border_relocator: Optional[BorderRelocator] = None, | ||
| settings=SettingsInversion(), | ||
| preloads=None, | ||
| xp=np, | ||
| ): | ||
|
|
@@ -53,20 +55,25 @@ def mapper_from( | |
| mapper_grids=mapper_grids, | ||
| border_relocator=border_relocator, | ||
| regularization=regularization, | ||
| settings=settings, | ||
| preloads=preloads, | ||
| xp=xp, | ||
| ) | ||
| elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular): | ||
| return MapperRectangular( | ||
| mapper_grids=mapper_grids, | ||
| border_relocator=border_relocator, | ||
| regularization=regularization, | ||
| settings=settings, | ||
| preloads=preloads, | ||
|
Comment on lines
64
to
+68
|
||
| xp=xp, | ||
| ) | ||
| elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunay): | ||
| return MapperDelaunay( | ||
| mapper_grids=mapper_grids, | ||
| border_relocator=border_relocator, | ||
| regularization=regularization, | ||
| settings=settings, | ||
| preloads=preloads, | ||
| xp=xp, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mixed precision is forwarded for
operated_mapping_matrix_list, but other code paths in this class still callpsf.convolved_mapping_matrix_from(...)withoutuse_mixed_precision(e.g.linear_func_operated_mapping_matrix_dictandmapper_operated_mapping_matrix_dict). This can lead to inconsistent dtypes/performance depending on which path is exercised (notably in sparse inversion code that uses those dicts). Ensure all internal PSF-convolution call sites in this class passuse_mixed_precision=self.settings.use_mixed_precision.