Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,10 @@ def apply_sparse_operator(
Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
"""

logger.info(
"IMAGING - Setting Up Sparse Operator For low Memory Pixelizations."
)

sparse_operator = (
inversion_imaging_util.ImagingSparseOperator.from_noise_map_and_psf(
data=self.data,
Expand Down
2 changes: 1 addition & 1 deletion autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def apply_sparse_operator(
if nufft_precision_operator is None:

logger.info(
"INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, CPU run times may exceed hours."
"INTERFEROMETER - Computing NUFFT Precision Operator; runtime scales with visibility count and mask resolution, CPU run times may exceed hours."
)

nufft_precision_operator = self.psf_precision_operator_from(
Expand Down
1 change: 1 addition & 0 deletions autoarray/inversion/inversion/imaging/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:
self.psf.convolved_mapping_matrix_from(
mapping_matrix=linear_obj.mapping_matrix,
mask=self.mask,
use_mixed_precision=self.settings.use_mixed_precision,
xp=self._xp,
)
Comment on lines 95 to 100
Copy link

Copilot AI Feb 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mixed precision is forwarded for operated_mapping_matrix_list, but other code paths in this class still call psf.convolved_mapping_matrix_from(...) without use_mixed_precision (e.g. linear_func_operated_mapping_matrix_dict and mapper_operated_mapping_matrix_dict). This can lead to inconsistent dtypes/performance depending on which path is exercised (notably in sparse inversion code that uses those dicts). Ensure all internal PSF-convolution call sites in this class pass use_mixed_precision=self.settings.use_mixed_precision.

Copilot uses AI. Check for mistakes.
if linear_obj.operated_mapping_matrix_override is None
Expand Down
85 changes: 0 additions & 85 deletions autoarray/inversion/inversion/imaging/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,44 +53,6 @@ def __init__(
xp=xp,
)

@property
def _data_vector_mapper(self) -> np.ndarray:
"""
Returns the `data_vector` of all mappers, a 1D vector whose values are solved for by the simultaneous
linear equations constructed by this object. The object is described in full in the method `data_vector`.

This method is used to compute part of the `data_vector` if there are also linear function list objects
in the inversion, and is separated into a separate method to enable preloading of the mapper `data_vector`.
"""

if not self.has(cls=AbstractMapper):
return

data_vector = np.zeros(self.total_params)

mapper_list = self.cls_list_from(cls=AbstractMapper)
mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)

for i in range(len(mapper_list)):
mapper = mapper_list[i]
param_range = mapper_param_range_list[i]

operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
mapping_matrix=mapper.mapping_matrix, mask=self.mask, xp=self._xp
)

data_vector_mapper = (
inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from(
blurred_mapping_matrix=operated_mapping_matrix,
image=self.data,
noise_map=self.noise_map,
)
)

data_vector[param_range[0] : param_range[1],] = data_vector_mapper

return data_vector

@cached_property
def data_vector(self) -> np.ndarray:
"""
Expand All @@ -111,53 +73,6 @@ def data_vector(self) -> np.ndarray:
noise_map=self.noise_map.array,
)

@property
def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]:
"""
Returns the diagonal regions of the `curvature_matrix`, a 2D matrix which uses the mappings between the data
and the linear objects to construct the simultaneous linear equations. The object is described in full in
the method `curvature_matrix`.

This method computes the diagonal entries of all mapper objects in the `curvature_matrix`. It is separate from
other calculations to enable preloading of this calculation.
"""

if not self.has(cls=AbstractMapper):
return None

curvature_matrix = np.zeros((self.total_params, self.total_params))

mapper_list = self.cls_list_from(cls=AbstractMapper)
mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)

for i in range(len(mapper_list)):
mapper_i = mapper_list[i]
mapper_param_range_i = mapper_param_range_list[i]

operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
mapping_matrix=mapper_i.mapping_matrix, mask=self.mask, xp=self._xp
)

diag = inversion_util.curvature_matrix_via_mapping_matrix_from(
mapping_matrix=operated_mapping_matrix,
noise_map=self.noise_map,
settings=self.settings,
add_to_curvature_diag=True,
no_regularization_index_list=self.no_regularization_index_list,
xp=self._xp,
)

curvature_matrix[
mapper_param_range_i[0] : mapper_param_range_i[1],
mapper_param_range_i[0] : mapper_param_range_i[1],
] = diag

curvature_matrix = inversion_util.curvature_matrix_mirrored_from(
curvature_matrix=curvature_matrix, xp=self._xp
)

return curvature_matrix

@cached_property
def curvature_matrix(self):
"""
Expand Down
2 changes: 2 additions & 0 deletions autoarray/inversion/inversion/interferometer/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,14 @@ def curvature_matrix(self) -> np.ndarray:
real_curvature_matrix = inversion_util.curvature_matrix_via_mapping_matrix_from(
mapping_matrix=self.operated_mapping_matrix.real,
noise_map=self.noise_map.real,
settings=self.settings,
xp=self._xp,
)

imag_curvature_matrix = inversion_util.curvature_matrix_via_mapping_matrix_from(
mapping_matrix=self.operated_mapping_matrix.imag,
noise_map=self.noise_map.imag,
settings=self.settings,
xp=self._xp,
)

Expand Down
30 changes: 23 additions & 7 deletions autoarray/inversion/inversion/inversion_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def curvature_matrix_mirrored_from(curvature_matrix: np.ndarray, xp=np) -> np.nd


def curvature_matrix_via_mapping_matrix_from(
mapping_matrix: np.ndarray,
noise_map: np.ndarray,
mapping_matrix: "np.ndarray",
noise_map: "np.ndarray",
add_to_curvature_diag: bool = False,
no_regularization_index_list: Optional[List] = None,
settings: SettingsInversion = SettingsInversion(),
settings: "SettingsInversion" = SettingsInversion(),
xp=np,
) -> np.ndarray:
Comment on lines 80 to 87
Copy link

Copilot AI Feb 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curvature_matrix_via_mapping_matrix_from now takes both settings and a separate use_mixed_precision flag. This makes it easy for call sites to pass settings but forget use_mixed_precision, silently disabling mixed precision. Consider deriving mixed-precision behavior from settings.use_mixed_precision inside this function (or make use_mixed_precision: Optional[bool]=None and default it from settings) to avoid partial propagation bugs.

Copilot uses AI. Check for mistakes.
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot open a new pull request to apply changes based on this feedback

"""
Expand All @@ -97,10 +97,26 @@ def curvature_matrix_via_mapping_matrix_from(
noise_map
Flattened 1D array of the noise-map used by the inversion during the fit.
"""
array = mapping_matrix / noise_map[:, None]
curvature_matrix = xp.dot(array.T, array)

if add_to_curvature_diag and len(no_regularization_index_list) > 0:
# NumPy path: keep it simple + stable
if xp is np:
A = mapping_matrix / noise_map[:, None]
curvature_matrix = xp.dot(A.T, A)
else:
# Choose compute dtype

compute_dtype = xp.float32 if settings.use_mixed_precision else xp.float64
out_dtype = xp.float64 # always return float64 for downstream stability

A = mapping_matrix
w = (1.0 / noise_map).astype(compute_dtype)
Comment on lines +110 to +111
Copy link

Copilot AI Feb 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the JAX/non-NumPy path, compute_dtype is selected but mapping_matrix is not explicitly cast to it (A = mapping_matrix). If mapping_matrix arrives as float64, operations will upcast and defeat mixed precision. Cast mapping_matrix (and ideally noise_map) to compute_dtype before forming A so the dot product actually runs in the intended precision.

Suggested change
A = mapping_matrix
w = (1.0 / noise_map).astype(compute_dtype)
A = mapping_matrix.astype(compute_dtype)
noise = noise_map.astype(compute_dtype)
w = 1.0 / noise

Copilot uses AI. Check for mistakes.
A = A * w[:, None]
curvature_matrix = xp.dot(A.T, A).astype(out_dtype)

if (
add_to_curvature_diag
and no_regularization_index_list
and len(no_regularization_index_list) > 0
):
curvature_matrix = curvature_matrix_with_added_to_diag_from(
curvature_matrix=curvature_matrix,
value=settings.no_regularization_add_to_curvature_diag_value,
Expand Down
8 changes: 8 additions & 0 deletions autoarray/inversion/inversion/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class SettingsInversion:
def __init__(
self,
use_mixed_precision: bool = False,
use_positive_only_solver: Optional[bool] = None,
positive_only_uses_p_initial: Optional[bool] = None,
use_border_relocator: Optional[bool] = None,
Expand All @@ -24,6 +25,12 @@ def __init__(

Parameters
----------
use_mixed_precision
If `True`, the linear algebra calculations of the inversion are performed using single precision on a
targeted subset of functions which provide significant speed up when using a GPU (x4), reduces VRAM
use and are expected to have minimal impact on the accuracy of the results. If `False`, all linear algebra
calculations are performed using double precision, which is the default and is more accurate but
slower on a GPU.
use_positive_only_solver
Whether to use a positive-only linear system solver, which requires that every reconstructed value is
positive but is computationally much slower than the default solver (which allows for positive and
Expand All @@ -41,6 +48,7 @@ def __init__(
For an interferometer inversion using the linear operators method, sets the maximum number of iterations
of the solver (this input does nothing for dataset data and other interferometer methods).
"""
self.use_mixed_precision = use_mixed_precision
self._use_positive_only_solver = use_positive_only_solver
self._positive_only_uses_p_initial = positive_only_uses_p_initial
self._use_border_relocator = use_border_relocator
Expand Down
3 changes: 3 additions & 0 deletions autoarray/inversion/linear_obj/func_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from autoarray.inversion.linear_obj.neighbors import Neighbors
from autoarray.inversion.linear_obj.unique_mappings import UniqueMappings
from autoarray.inversion.regularization.abstract import AbstractRegularization
from autoarray.inversion.inversion.settings import SettingsInversion
from autoarray.type import Grid1D2DLike


Expand All @@ -15,6 +16,7 @@ def __init__(
self,
grid: Grid1D2DLike,
regularization: Optional[AbstractRegularization],
settings=SettingsInversion(),
xp=np,
):
"""
Expand Down Expand Up @@ -45,6 +47,7 @@ def __init__(
super().__init__(regularization=regularization, xp=xp)

self.grid = grid
self.settings = settings

@cached_property
def neighbors(self) -> Neighbors:
Expand Down
4 changes: 4 additions & 0 deletions autoarray/inversion/pixelization/mappers/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from autoarray.inversion.pixelization.border_relocator import BorderRelocator
from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids
from autoarray.inversion.regularization.abstract import AbstractRegularization
from autoarray.inversion.inversion.settings import SettingsInversion
from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.structures.grids.uniform_2d import Grid2D
from autoarray.structures.mesh.abstract_2d import Abstract2DMesh
Expand All @@ -25,6 +26,7 @@ def __init__(
mapper_grids: MapperGrids,
regularization: Optional[AbstractRegularization],
border_relocator: BorderRelocator,
settings: SettingsInversion = SettingsInversion(),
preloads=None,
xp=np,
):
Expand Down Expand Up @@ -90,6 +92,7 @@ def __init__(
self.border_relocator = border_relocator
self.mapper_grids = mapper_grids
self.preloads = preloads
self.settings = settings

@property
def params(self) -> int:
Expand Down Expand Up @@ -265,6 +268,7 @@ def mapping_matrix(self) -> np.ndarray:
total_mask_pixels=self.over_sampler.mask.pixels_in_mask,
slim_index_for_sub_slim_index=self.slim_index_for_sub_slim_index,
sub_fraction=self.over_sampler.sub_fraction.array,
use_mixed_precision=self.settings.use_mixed_precision,
xp=self._xp,
)

Expand Down
7 changes: 7 additions & 0 deletions autoarray/inversion/pixelization/mappers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids
from autoarray.inversion.pixelization.border_relocator import BorderRelocator
from autoarray.inversion.regularization.abstract import AbstractRegularization
from autoarray.inversion.inversion.settings import SettingsInversion
from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular
from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform
from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay
Expand All @@ -13,6 +14,7 @@ def mapper_from(
mapper_grids: MapperGrids,
regularization: Optional[AbstractRegularization],
border_relocator: Optional[BorderRelocator] = None,
settings=SettingsInversion(),
preloads=None,
xp=np,
):
Expand Down Expand Up @@ -53,20 +55,25 @@ def mapper_from(
mapper_grids=mapper_grids,
border_relocator=border_relocator,
regularization=regularization,
settings=settings,
preloads=preloads,
xp=xp,
)
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular):
return MapperRectangular(
mapper_grids=mapper_grids,
border_relocator=border_relocator,
regularization=regularization,
settings=settings,
preloads=preloads,
Comment on lines 64 to +68
Copy link

Copilot AI Feb 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This factory now propagates settings into the Rectangular / RectangularUniform mapper constructors, but the Delaunay branch below still does not pass settings. That means caller-provided inversion settings (including use_mixed_precision) are ignored for MapperDelaunay. Update the Delaunay branch to pass settings=settings for consistent behavior across mapper types.

Copilot uses AI. Check for mistakes.
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot open a new pull request to apply changes based on this feedback

xp=xp,
)
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunay):
return MapperDelaunay(
mapper_grids=mapper_grids,
border_relocator=border_relocator,
regularization=regularization,
settings=settings,
preloads=preloads,
xp=xp,
)
56 changes: 37 additions & 19 deletions autoarray/inversion/pixelization/mappers/mapper_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ def mapping_matrix_from(
total_mask_pixels: int,
slim_index_for_sub_slim_index: np.ndarray,
sub_fraction: np.ndarray,
use_mixed_precision: bool = False,
xp=np,
) -> np.ndarray:
"""
Expand Down Expand Up @@ -621,39 +622,56 @@ def mapping_matrix_from(
sub_fraction
The fractional area each sub-pixel takes up in an pixel.
"""

M_sub, B = pix_indexes_for_sub_slim_index.shape
M = total_mask_pixels
S = pixels
M = int(total_mask_pixels)
S = int(pixels)

# Indices always int32
pix_idx = xp.asarray(pix_indexes_for_sub_slim_index, dtype=xp.int32)
pix_size = xp.asarray(pix_size_for_sub_slim_index, dtype=xp.int32)
slim_parent = xp.asarray(slim_index_for_sub_slim_index, dtype=xp.int32)

# Everything else computed in float64
w64 = xp.asarray(pix_weights_for_sub_slim_index, dtype=xp.float64)
frac64 = xp.asarray(sub_fraction, dtype=xp.float64)

# Output dtype only (big allocation)
out_dtype = xp.float32 if use_mixed_precision else xp.float64

# 1) Flatten
flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,)
flat_w = pix_weights_for_sub_slim_index.reshape(-1) # (M_sub*B,)
flat_parent = xp.repeat(slim_index_for_sub_slim_index, B) # (M_sub*B,)
flat_count = xp.repeat(pix_size_for_sub_slim_index, B) # (M_sub*B,)
flat_pixidx = pix_idx.reshape(-1) # (M_sub*B,)
flat_w = w64.reshape(-1) # float64
flat_parent = xp.repeat(slim_parent, B) # int32
flat_count = xp.repeat(pix_size, B) # int32

# 2) Build valid mask: k < pix_size[i]
k = xp.tile(xp.arange(B), M_sub) # (M_sub*B,)
valid = k < flat_count # (M_sub*B,)
# 2) valid mask: k < pix_size[i]
k = xp.tile(xp.arange(B, dtype=xp.int32), M_sub)
valid = k < flat_count

# 3) Zero out invalid weights
flat_w = flat_w * valid.astype(flat_w.dtype)
# 3) Zero out invalid weights (float64)
flat_w = flat_w * valid.astype(xp.float64)

# 4) Redirect -1 indices to extra bin S
OUT = S
flat_pixidx = xp.where(flat_pixidx < 0, OUT, flat_pixidx)

# 5) Multiply by sub_fraction of the slim row
flat_frac = xp.take(sub_fraction, flat_parent, axis=0) # (M_sub*B,)
flat_contrib = flat_w * flat_frac # (M_sub*B,)
# 5) Multiply by sub_fraction of the slim row (float64)
flat_frac = xp.take(frac64, flat_parent, axis=0)
flat_contrib64 = flat_w * flat_frac

# 6) Scatter into (M × (S+1)) (destination float32 or float64)
mat = xp.zeros((M, S + 1), dtype=out_dtype)

# Cast only at the write (keeps upstream math float64)
flat_contrib_out = flat_contrib64.astype(out_dtype)

# 6) Scatter into (M × (S+1)), summing duplicates
mat = xp.zeros((M, S + 1), dtype=flat_contrib.dtype)
if xp.__name__.startswith("jax"):
mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib)
mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib_out)
else:
xp.add.at(mat, (flat_parent, flat_pixidx), flat_contrib)
xp.add.at(mat, (flat_parent, flat_pixidx), flat_contrib_out)

# 7) Drop the extra column and return
# 7) Drop extra column
return mat[:, :S]


Expand Down
Loading
Loading