From 9559139109500142594da7aa94cbac5d9741a57f Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sat, 17 Jan 2026 01:02:45 -0800 Subject: [PATCH 01/12] Add grid visualization to Dataset3d.show() with ncols and returnfig options --- src/quantem/core/datastructures/dataset3d.py | 60 ++++++++++++++++---- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/src/quantem/core/datastructures/dataset3d.py b/src/quantem/core/datastructures/dataset3d.py index 53d6d02c..98e7b7fd 100644 --- a/src/quantem/core/datastructures/dataset3d.py +++ b/src/quantem/core/datastructures/dataset3d.py @@ -107,24 +107,62 @@ def to_dataset2d(self): def show( self, - index: int = 0, + index: int | None = None, scalebar: ScalebarConfig | bool = True, title: str | None = None, + ncols: int = 4, + returnfig: bool = False, **kwargs, ): """ - Display a 2D slice of the 3D dataset. + Display 2D slices of the 3D dataset. Parameters ---------- - index : int - Index of the 2D slice to display (along axis 0). - scalebar: ScalebarConfig or bool - If True, displays scalebar - title: str - Title of Dataset - **kwargs: dict - Keyword arguments for show_2d + index : int | None + Index of the 2D slice to display. If None, shows all slices in a grid. + scalebar : ScalebarConfig or bool + If True, displays scalebar. + title : str | None + Title for the plot. If None and showing all, uses "Frame 0", "Frame 1", etc. + ncols : int + Maximum columns when showing all slices. Default: 4. + returnfig : bool + If True, returns (fig, axes). Default: False. + **kwargs : dict + Keyword arguments for show_2d (cmap, cbar, norm, etc.). 
+ + Examples + -------- + >>> data.show() # show all frames in grid + >>> data.show(index=0) # show single frame + >>> data.show(ncols=3) # 3 columns + >>> fig, axes = data.show(returnfig=True) # get figure for customization """ + from quantem.core.visualization import show_2d + + if index is not None: + result = self[index].show(scalebar=scalebar, title=title, **kwargs) + return result if returnfig else None + + # Show all frames in a grid + n = self.shape[0] + nrows = (n + ncols - 1) // ncols + arrays = [] + titles = [] + for row in range(nrows): + row_arrays = [] + row_titles = [] + for col in range(ncols): + i = row * ncols + col + if i < n: + row_arrays.append(self.array[i]) + row_titles.append(f"Frame {i}" if title is None else f"{title} {i}") + else: + row_arrays.append(np.zeros_like(self.array[0])) + row_titles.append("") + arrays.append(row_arrays) + titles.append(row_titles) - return self[index].show(scalebar=scalebar, title=title, **kwargs) + result = show_2d(arrays, scalebar=scalebar, title=titles, **kwargs) + return result if returnfig else None From 0cb23cb495460b1ef4a1b8b4ac09e4e969ff5ea3 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sat, 17 Jan 2026 01:02:45 -0800 Subject: [PATCH 02/12] Add NumPy-style docstrings and enhance Dataset3d.show() - Add docstring with examples to from_array() - Add docstring with examples to from_shape() - Use "Frame N" as default title (appropriate for time series/tomography) - Add suptitle parameter with proper layout adjustment --- src/quantem/core/datastructures/dataset3d.py | 94 ++++++++++++++++++-- 1 file changed, 89 insertions(+), 5 deletions(-) diff --git a/src/quantem/core/datastructures/dataset3d.py b/src/quantem/core/datastructures/dataset3d.py index 98e7b7fd..4f36ef04 100644 --- a/src/quantem/core/datastructures/dataset3d.py +++ b/src/quantem/core/datastructures/dataset3d.py @@ -68,6 +68,50 @@ def from_array( units: list[str] | tuple | list | None = None, signal_units: str = "arb. 
units", ) -> Self: + """Create a Dataset3d from a 3D array. + + Parameters + ---------- + array : NDArray | Any + 3D array with shape (n_frames, height, width) + name : str | None + Dataset name. Default: "3D dataset" + origin : NDArray | tuple | list | float | int | None + Origin for each dimension. Default: [0, 0, 0] + sampling : NDArray | tuple | list | float | int | None + Sampling for each dimension. Default: [1, 1, 1] + units : list[str] | tuple | list | None + Units for each dimension. Default: ["index", "pixels", "pixels"] + signal_units : str + Units for array values. Default: "arb. units" + + Returns + ------- + Dataset3d + + Examples + -------- + >>> import numpy as np + >>> from quantem.core.datastructures import Dataset3d + >>> arr = np.random.rand(10, 64, 64) + >>> data = Dataset3d.from_array(arr) + >>> data.shape + (10, 64, 64) + + With calibration: + + >>> data = Dataset3d.from_array( + ... arr, + ... sampling=[1, 0.1, 0.1], + ... units=["frame", "nm", "nm"], + ... ) + + Visualize: + + >>> data.show() # all frames in grid + >>> data.show(index=0) # single frame + >>> data.show(ncols=2) # 2 columns + """ array = ensure_valid_array(array, ndim=3) return cls( array=array, @@ -90,7 +134,37 @@ def from_shape( units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", ) -> Self: - """Create a new Dataset3d filled with a constant value.""" + """Create a Dataset3d filled with a constant value. + + Parameters + ---------- + shape : tuple[int, int, int] + Shape (n_frames, height, width) + name : str + Dataset name. Default: "constant 3D dataset" + fill_value : float + Value to fill array with. 
Default: 0.0 + origin : NDArray | tuple | list | float | int | None + Origin for each dimension + sampling : NDArray | tuple | list | float | int | None + Sampling for each dimension + units : list[str] | tuple | list | None + Units for each dimension + signal_units : str + Units for array values + + Returns + ------- + Dataset3d + + Examples + -------- + >>> data = Dataset3d.from_shape((10, 64, 64)) + >>> data.shape + (10, 64, 64) + >>> data.array.max() + 0.0 + """ array = np.full(shape, fill_value, dtype=np.float32) return cls.from_array( array=array, @@ -110,6 +184,7 @@ def show( index: int | None = None, scalebar: ScalebarConfig | bool = True, title: str | None = None, + suptitle: str | None = None, ncols: int = 4, returnfig: bool = False, **kwargs, @@ -124,7 +199,9 @@ def show( scalebar : ScalebarConfig or bool If True, displays scalebar. title : str | None - Title for the plot. If None and showing all, uses "Frame 0", "Frame 1", etc. + Title for the plot. If None, uses "Frame 0", "Frame 1", etc. + suptitle : str | None + Figure super title displayed above all subplots. ncols : int Maximum columns when showing all slices. Default: 4. 
returnfig : bool @@ -137,12 +214,16 @@ def show( >>> data.show() # show all frames in grid >>> data.show(index=0) # show single frame >>> data.show(ncols=3) # 3 columns + >>> data.show(suptitle="Diffraction patterns") # with super title >>> fig, axes = data.show(returnfig=True) # get figure for customization """ from quantem.core.visualization import show_2d if index is not None: - result = self[index].show(scalebar=scalebar, title=title, **kwargs) + # Handle negative index + actual_index = index if index >= 0 else self.shape[0] + index + default_title = title if title is not None else f"Frame {actual_index}" + result = self[index].show(scalebar=scalebar, title=default_title, **kwargs) return result if returnfig else None # Show all frames in a grid @@ -164,5 +245,8 @@ def show( arrays.append(row_arrays) titles.append(row_titles) - result = show_2d(arrays, scalebar=scalebar, title=titles, **kwargs) - return result if returnfig else None + fig, axes = show_2d(arrays, scalebar=scalebar, title=titles, **kwargs) + if suptitle is not None: + fig.suptitle(suptitle, fontsize=14) + fig.subplots_adjust(top=0.92) + return (fig, axes) if returnfig else None From 203d60aa9fc0a5fea30c5ddef214408da9fe3e0a Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Fri, 16 Jan 2026 22:53:02 -0800 Subject: [PATCH 03/12] Add Dataset5dstem class for 5D-STEM data with stack operations, slicing, and virtual imaging --- src/quantem/core/datastructures/__init__.py | 8 +- .../core/datastructures/dataset4dstem.py | 59 +-- src/quantem/core/datastructures/dataset5d.py | 180 ++++++++ .../core/datastructures/dataset5dstem.py | 431 ++++++++++++++++++ src/quantem/core/io/__init__.py | 1 + src/quantem/core/io/file_readers.py | 163 +++++++ src/quantem/core/utils/masks.py | 64 +++ tests/datastructures/test_dataset5dstem.py | 191 ++++++++ 8 files changed, 1039 insertions(+), 58 deletions(-) create mode 100644 src/quantem/core/datastructures/dataset5d.py create mode 100644 
src/quantem/core/datastructures/dataset5dstem.py create mode 100644 src/quantem/core/utils/masks.py create mode 100644 tests/datastructures/test_dataset5dstem.py diff --git a/src/quantem/core/datastructures/__init__.py b/src/quantem/core/datastructures/__init__.py index dfb5b47a..4f366b9a 100644 --- a/src/quantem/core/datastructures/__init__.py +++ b/src/quantem/core/datastructures/__init__.py @@ -1,7 +1,9 @@ from quantem.core.datastructures.dataset import Dataset as Dataset from quantem.core.datastructures.vector import Vector as Vector -from quantem.core.datastructures.dataset4dstem import Dataset4dstem as Dataset4dstem -from quantem.core.datastructures.dataset4d import Dataset4d as Dataset4d -from quantem.core.datastructures.dataset3d import Dataset3d as Dataset3d from quantem.core.datastructures.dataset2d import Dataset2d as Dataset2d +from quantem.core.datastructures.dataset3d import Dataset3d as Dataset3d +from quantem.core.datastructures.dataset4d import Dataset4d as Dataset4d +from quantem.core.datastructures.dataset4dstem import Dataset4dstem as Dataset4dstem +from quantem.core.datastructures.dataset5d import Dataset5d as Dataset5d +from quantem.core.datastructures.dataset5dstem import Dataset5dstem as Dataset5dstem diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index 28328636..ed146a7d 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -7,6 +7,7 @@ from quantem.core.datastructures.dataset2d import Dataset2d from quantem.core.datastructures.dataset4d import Dataset4d +from quantem.core.utils.masks import create_annular_mask, create_circle_mask from quantem.core.utils.validators import ensure_valid_array from quantem.core.visualization import show_2d from quantem.core.visualization.visualization_utils import ScalebarConfig @@ -365,6 +366,7 @@ def get_virtual_image( final_mask = mask elif mode is not None and geometry is not 
None: # Create mask from mode and geometry + dp_shape = self.array.shape[-2:] if mode == "circle": if ( len(geometry) != 2 @@ -373,14 +375,14 @@ def get_virtual_image( ): raise ValueError("For circle mode, geometry must be ((cy, cx), r)") center, radius = geometry - final_mask = self._create_circle_mask(center, radius) + final_mask = create_circle_mask(dp_shape, center, radius) elif mode == "annular": if len(geometry) != 2 or len(geometry[0]) != 2 or len(geometry[1]) != 2: raise ValueError( "For annular mode, geometry must be ((cy, cx), (r_inner, r_outer))" ) center, radii = geometry - final_mask = self._create_annular_mask(center, radii) + final_mask = create_annular_mask(dp_shape, center, radii) else: raise ValueError( f"Unknown mode '{mode}'. Supported modes are 'circle' and 'annular'" @@ -466,59 +468,6 @@ def get_virtual_image( return virtual_image_dataset - def _create_circle_mask(self, center: tuple[float, float], radius: float) -> np.ndarray: - """ - Create a circular mask for virtual image formation. - - Parameters - ---------- - center : tuple[float, float] - Center coordinates (cy, cx) of the circle - radius : float - Radius of the circle - - Returns - ------- - np.ndarray - Boolean mask with True inside the circle - """ - cy, cx = center - dp_shape = self.array.shape[-2:] # Get diffraction pattern dimensions - y, x = np.ogrid[: dp_shape[0], : dp_shape[1]] - - # Calculate distance from center - distance = np.sqrt((y - cy) ** 2 + (x - cx) ** 2) - - return distance <= radius - - def _create_annular_mask( - self, center: tuple[float, float], radii: tuple[float, float] - ) -> np.ndarray: - """ - Create an annular (ring-shaped) mask for virtual image formation. 
- - Parameters - ---------- - center : tuple[float, float] - Center coordinates (cy, cx) of the annulus - radii : tuple[float, float] - Inner and outer radii (r_inner, r_outer) of the annulus - - Returns - ------- - np.ndarray - Boolean mask with True inside the annular region - """ - cy, cx = center - r_inner, r_outer = radii - dp_shape = self.array.shape[-2:] # Get diffraction pattern dimensions - y, x = np.ogrid[: dp_shape[0], : dp_shape[1]] - - # Calculate distance from center - distance = np.sqrt((y - cy) ** 2 + (x - cx) ** 2) - - return (distance >= r_inner) & (distance <= r_outer) - def show_virtual_images(self, figsize: tuple[int, int] | None = None, **kwargs) -> tuple: """ Display all virtual images stored in the dataset using show_2d. diff --git a/src/quantem/core/datastructures/dataset5d.py b/src/quantem/core/datastructures/dataset5d.py new file mode 100644 index 00000000..a33888ea --- /dev/null +++ b/src/quantem/core/datastructures/dataset5d.py @@ -0,0 +1,180 @@ +from typing import Any, Self, Union + +import numpy as np +from numpy.typing import NDArray + +from quantem.core.datastructures.dataset import Dataset +from quantem.core.utils.validators import ensure_valid_array +from quantem.core.visualization.visualization_utils import ScalebarConfig + + +@Dataset.register_dimension(5) +class Dataset5d(Dataset): + """5D dataset class that inherits from Dataset. + + This class represents 5D stacks of data, such as time-series or tilt-series + of 4D-STEM experiments. + + The data consists of a 5D array with dimensions (stack, scan_row, scan_col, k_row, k_col). + The first dimension represents the stack axis (time, tilt, defocus, etc.), + dimensions 1-2 represent real space scanning positions, and dimensions 3-4 + represent reciprocal space diffraction patterns. + + Attributes + ---------- + None beyond base Dataset. 
+ """ + + def __init__( + self, + array: NDArray | Any, + name: str, + origin: NDArray | tuple | list | float | int, + sampling: NDArray | tuple | list | float | int, + units: list[str] | tuple | list, + signal_units: str = "arb. units", + metadata: dict = {}, + _token: object | None = None, + ): + """Initialize a 5D dataset. + + Parameters + ---------- + array : NDArray | Any + The underlying 5D array data + name : str + A descriptive name for the dataset + origin : NDArray | tuple | list | float | int + The origin coordinates for each dimension + sampling : NDArray | tuple | list | float | int + The sampling rate/spacing for each dimension + units : list[str] | tuple | list + Units for each dimension + signal_units : str, optional + Units for the array values, by default "arb. units" + metadata : dict, optional + Additional metadata, by default {} + _token : object | None, optional + Token to prevent direct instantiation, by default None + """ + super().__init__( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + metadata=metadata, + _token=_token, + ) + + @classmethod + def from_array( + cls, + array: NDArray | Any, + name: str | None = None, + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + ) -> Self: + """Create a new Dataset5d from an array. + + Parameters + ---------- + array : NDArray | Any + The underlying 5D array data + name : str | None, optional + A descriptive name for the dataset. If None, defaults to "5D dataset" + origin : NDArray | tuple | list | float | int | None, optional + The origin coordinates for each dimension. If None, defaults to zeros + sampling : NDArray | tuple | list | float | int | None, optional + The sampling rate/spacing for each dimension. 
If None, defaults to ones + units : list[str] | tuple | list | None, optional + Units for each dimension. If None, defaults to ["index", "pixels", "pixels", "pixels", "pixels"] + signal_units : str, optional + Units for the array values, by default "arb. units" + + Returns + ------- + Dataset5d + A new Dataset5d instance + """ + array = ensure_valid_array(array, ndim=5) + return cls( + array=array, + name=name if name is not None else "5D dataset", + origin=origin if origin is not None else np.zeros(5), + sampling=sampling if sampling is not None else np.ones(5), + units=units if units is not None else ["index", "pixels", "pixels", "pixels", "pixels"], + signal_units=signal_units, + _token=cls._token, + ) + + @classmethod + def from_shape( + cls, + shape: tuple[int, int, int, int, int], + name: str = "constant 5D dataset", + fill_value: float = 0.0, + origin: Union[NDArray, tuple, list, float, int] | None = None, + sampling: Union[NDArray, tuple, list, float, int] | None = None, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + ) -> Self: + """Create a new Dataset5d filled with a constant value. + + Parameters + ---------- + shape : tuple[int, int, int, int, int] + Shape of the 5D array + name : str, optional + Name for the dataset, by default "constant 5D dataset" + fill_value : float, optional + Value to fill the array with, by default 0.0 + origin : NDArray | tuple | list | float | int | None, optional + Origin coordinates for each dimension + sampling : NDArray | tuple | list | float | int | None, optional + Sampling rate for each dimension + units : list[str] | tuple | list | None, optional + Units for each dimension + signal_units : str, optional + Units for the array values, by default "arb. 
units" + + Returns + ------- + Dataset5d + A new Dataset5d instance filled with the specified value + """ + array = np.full(shape, fill_value, dtype=np.float32) + return cls.from_array( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + ) + + def show( + self, + index: tuple[int, int, int] = (0, 0, 0), + scalebar: ScalebarConfig | bool = True, + title: str | None = None, + **kwargs, + ): + """ + Display a 2D slice of the 5D dataset. + + Parameters + ---------- + index : tuple[int, int, int] + 3D index of the 2D slice to display (along axes (0, 1, 2)). + scalebar : ScalebarConfig or bool + If True, displays scalebar + title : str + Title of Dataset + **kwargs : dict + Keyword arguments for show_2d + """ + return self[index].show(scalebar=scalebar, title=title, **kwargs) diff --git a/src/quantem/core/datastructures/dataset5dstem.py b/src/quantem/core/datastructures/dataset5dstem.py new file mode 100644 index 00000000..b2c4d846 --- /dev/null +++ b/src/quantem/core/datastructures/dataset5dstem.py @@ -0,0 +1,431 @@ +"""5D-STEM dataset class for time series, tilt series, and other stacked 4D-STEM data.""" + +from typing import Iterator, Self + +import numpy as np +from numpy.typing import NDArray + +from quantem.core.datastructures.dataset3d import Dataset3d +from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.datastructures.dataset5d import Dataset5d +from quantem.core.utils.masks import create_annular_mask, create_circle_mask +from quantem.core.utils.validators import ensure_valid_array + +STACK_TYPES = ("time", "tilt", "energy", "dose", "focus", "generic") + + +class Dataset5dstem(Dataset5d): + """5D-STEM dataset with dimensions (stack, scan_row, scan_col, k_row, k_col). + + The stack axis represents time frames, tilt angles, defocus values, etc. + Dimensions 1-2 are real-space scan positions; dimensions 3-4 are reciprocal-space + diffraction patterns. 
+ + Parameters + ---------- + stack_type : str + Type of stack: "time", "tilt", "energy", "dose", "focus", or "generic". + stack_values : NDArray | None + Explicit values for the stack dimension (e.g., timestamps, tilt angles). + + Examples + -------- + >>> data = read_5dstem("path/to/file.h5") + >>> len(data) # number of frames + 10 + >>> frame = data[0] # get first frame as Dataset4dstem + >>> mean_4d = data.stack_mean() # average over stack -> Dataset4dstem + >>> for frame in data: # iterate over frames + ... process(frame) + """ + + def __init__( + self, + array: NDArray, + name: str, + origin: NDArray, + sampling: NDArray, + units: list[str], + signal_units: str = "arb. units", + metadata: dict | None = None, + stack_type: str = "generic", + stack_values: NDArray | None = None, + _token: object | None = None, + ): + metadata = metadata or {} + for key in ("r_to_q_rotation_cw_deg", "ellipticity"): + metadata.setdefault(key, None) + + super().__init__( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + metadata=metadata, + _token=_token, + ) + + if stack_type not in STACK_TYPES: + raise ValueError(f"stack_type must be one of {STACK_TYPES}, got '{stack_type}'") + + self._stack_type = stack_type + + if stack_values is not None: + stack_values = np.asarray(stack_values) + if len(stack_values) != self.shape[0]: + raise ValueError( + f"stack_values length ({len(stack_values)}) must match " + f"number of frames ({self.shape[0]})" + ) + self._stack_values = stack_values + self._virtual_images: dict[str, Dataset3d] = {} + self._virtual_detectors: dict[str, dict] = {} + + def __repr__(self) -> str: + return ( + f"Dataset5dstem(shape={self.shape}, dtype={self.array.dtype}, " + f"stack_type='{self._stack_type}')" + ) + + def __str__(self) -> str: + return ( + f"Dataset5dstem '{self.name}'\n" + f" shape: {self.shape} ({len(self)} frames)\n" + f" stack_type: '{self._stack_type}'\n" + f" scan sampling: 
{self.sampling[1:3]} {self.units[1:3]}\n" + f" k sampling: {self.sampling[3:]} {self.units[3:]}" + ) + + def __len__(self) -> int: + return self.shape[0] + + def __iter__(self) -> Iterator[Dataset4dstem]: + for i in range(len(self)): + yield self._get_frame(i) + + def __getitem__(self, idx) -> "Dataset4dstem | Dataset5dstem": + if isinstance(idx, int): + return self._get_frame(idx) + + # Handle tuple where first element is int (e.g., data[0, ...]) + if isinstance(idx, tuple) and len(idx) > 0 and isinstance(idx[0], int): + return self._get_frame(idx[0])[idx[1:]] + + # Slicing returns Dataset5dstem with preserved stack_type + if isinstance(idx, slice): + sliced_array = self.array[idx] + sliced_values = self._stack_values[idx] if self._stack_values is not None else None + return self.from_array( + array=sliced_array, + name=self.name, + origin=self.origin, + sampling=self.sampling, + units=self.units, + signal_units=self.signal_units, + stack_type=self._stack_type, + stack_values=sliced_values, + ) + + return super().__getitem__(idx) + + @property + def stack_type(self) -> str: + """Type of stack dimension: 'time', 'tilt', 'energy', 'dose', 'focus', or 'generic'.""" + return self._stack_type + + @property + def stack_values(self) -> NDArray | None: + """Explicit values for the stack dimension, or None if using indices.""" + return self._stack_values + + @property + def virtual_images(self) -> dict[str, Dataset3d]: + """Cached virtual image stacks, keyed by name.""" + return self._virtual_images + + @property + def virtual_detectors(self) -> dict[str, dict]: + """Virtual detector configurations for regenerating images.""" + return self._virtual_detectors + + # ------------------------------------------------------------------------- + # Construction + # ------------------------------------------------------------------------- + + @classmethod + def from_array( + cls, + array: NDArray, + name: str | None = None, + origin: NDArray | tuple | list | None = None, + 
sampling: NDArray | tuple | list | None = None, + units: list[str] | None = None, + signal_units: str = "arb. units", + stack_type: str = "generic", + stack_values: NDArray | None = None, + ) -> Self: + """Create Dataset5dstem from a 5D array. + + Parameters + ---------- + array : NDArray + 5D array with shape (stack, scan_row, scan_col, k_row, k_col). + name : str, optional + Dataset name. Default: "5D-STEM dataset". + origin : array-like, optional + Origin for each dimension (4 or 5 elements). Default: zeros. + sampling : array-like, optional + Sampling for each dimension (4 or 5 elements). Default: ones. + units : list[str], optional + Units for each dimension (4 or 5 elements). Default: ["pixels", ...]. + signal_units : str, optional + Units for intensity values. Default: "arb. units". + stack_type : str, optional + Type of stack dimension. Default: "generic". + stack_values : NDArray, optional + Explicit values for stack positions (e.g., times, angles). + + Returns + ------- + Dataset5dstem + """ + array = ensure_valid_array(array, ndim=5) + + # Accept 4-element inputs (scan + k dims); prepend stack defaults + def expand_to_5d(arr, default): + if arr is None: + return default + arr = np.asarray(arr) + if arr.size == 4: + return np.concatenate([[default[0]], arr]) + return arr + + origin_5d = expand_to_5d(origin, np.zeros(5)) + sampling_5d = expand_to_5d(sampling, np.ones(5)) + + if units is None: + units_5d = ["pixels"] * 5 + elif len(units) == 4: + units_5d = ["index"] + list(units) + else: + units_5d = list(units) + + return cls( + array=array, + name=name or "5D-STEM dataset", + origin=origin_5d, + sampling=sampling_5d, + units=units_5d, + signal_units=signal_units, + stack_type=stack_type, + stack_values=stack_values, + _token=cls._token, + ) + + @classmethod + def from_4dstem( + cls, + datasets: list[Dataset4dstem], + stack_type: str = "generic", + stack_values: NDArray | None = None, + name: str | None = None, + ) -> Self: + """Create Dataset5dstem by 
stacking multiple Dataset4dstem objects. + + Parameters + ---------- + datasets : list[Dataset4dstem] + List of 4D-STEM datasets to stack. Must have identical shapes. + stack_type : str, optional + Type of stack dimension. Default: "generic". + stack_values : NDArray, optional + Explicit values for stack positions. + name : str, optional + Dataset name. + + Returns + ------- + Dataset5dstem + """ + if not datasets: + raise ValueError("datasets list cannot be empty") + + first = datasets[0] + + # Validate consistency across all datasets + for i, ds in enumerate(datasets[1:], start=1): + if ds.shape != first.shape: + raise ValueError( + f"Dataset {i} shape {ds.shape} doesn't match first dataset shape {first.shape}" + ) + if not np.allclose(ds.sampling, first.sampling): + raise ValueError( + f"Dataset {i} sampling {ds.sampling} doesn't match first dataset" + ) + if ds.units != first.units: + raise ValueError( + f"Dataset {i} units {ds.units} doesn't match first dataset" + ) + + stacked = np.stack([d.array for d in datasets], axis=0) + + return cls.from_array( + array=stacked, + name=name or "5D-STEM dataset", + origin=np.concatenate([[0], first.origin]), + sampling=np.concatenate([[1], first.sampling]), + units=["index"] + list(first.units), + signal_units=first.signal_units, + stack_type=stack_type, + stack_values=stack_values, + ) + + # ------------------------------------------------------------------------- + # Stack operations + # ------------------------------------------------------------------------- + + def stack_mean(self) -> Dataset4dstem: + """Average over the stack axis. Returns Dataset4dstem.""" + return self._reduce_stack(np.mean, "mean") + + def stack_sum(self) -> Dataset4dstem: + """Sum over the stack axis. Returns Dataset4dstem.""" + return self._reduce_stack(np.sum, "sum") + + def stack_max(self) -> Dataset4dstem: + """Maximum over the stack axis. 
Returns Dataset4dstem.""" + return self._reduce_stack(np.max, "max") + + def stack_min(self) -> Dataset4dstem: + """Minimum over the stack axis. Returns Dataset4dstem.""" + return self._reduce_stack(np.min, "min") + + def _reduce_stack(self, func, suffix: str) -> Dataset4dstem: + """Apply reduction function over stack axis.""" + return Dataset4dstem.from_array( + array=func(self.array, axis=0), + name=f"{self.name}_{suffix}", + origin=self.origin[1:], + sampling=self.sampling[1:], + units=self.units[1:], + signal_units=self.signal_units, + ) + + def _get_frame(self, idx: int) -> Dataset4dstem: + """Extract a single 4D frame from the stack.""" + if idx < 0: + idx = len(self) + idx + if idx < 0 or idx >= len(self): + raise IndexError(f"Frame index {idx} out of range for {len(self)} frames") + + frame = Dataset4dstem.from_array( + array=self.array[idx], + name=f"{self.name}_frame{idx}", + origin=self.origin[1:], + sampling=self.sampling[1:], + units=self.units[1:], + signal_units=self.signal_units, + ) + + # Inherit STEM metadata + frame.metadata["r_to_q_rotation_cw_deg"] = self.metadata.get("r_to_q_rotation_cw_deg") + frame.metadata["ellipticity"] = self.metadata.get("ellipticity") + + # Inherit virtual detector definitions + for name, info in self._virtual_detectors.items(): + frame._virtual_detectors[name] = { + "mask": None, + "mode": info["mode"], + "geometry": info["geometry"], + } + + return frame + + # ------------------------------------------------------------------------- + # Virtual imaging + # ------------------------------------------------------------------------- + + def get_virtual_image( + self, + mask: np.ndarray | None = None, + mode: str | None = None, + geometry: tuple | None = None, + name: str = "virtual_image", + attach: bool = True, + ) -> Dataset3d: + """Compute virtual image stack for all frames. + + Parameters + ---------- + mask : np.ndarray, optional + Custom mask matching diffraction pattern shape. 
+ mode : str, optional + Mask mode: "circle" or "annular". + geometry : tuple, optional + For "circle": ((cy, cx), radius). + For "annular": ((cy, cx), (r_inner, r_outer)). + name : str, optional + Name for the virtual image. Default: "virtual_image". + attach : bool, optional + Store result in virtual_images dict. Default: True. + + Returns + ------- + Dataset3d + Virtual image stack with shape (n_frames, scan_row, scan_col). + """ + dp_shape = self.array.shape[-2:] + + if mask is not None: + if mask.shape != dp_shape: + raise ValueError(f"Mask shape {mask.shape} != diffraction pattern shape {dp_shape}") + final_mask = mask + elif mode and geometry: + if mode == "circle": + center, radius = geometry + final_mask = create_circle_mask(dp_shape, center, radius) + elif mode == "annular": + center, radii = geometry + final_mask = create_annular_mask(dp_shape, center, radii) + else: + raise ValueError(f"Unknown mode '{mode}'. Use 'circle' or 'annular'.") + else: + raise ValueError("Provide either mask or both mode and geometry") + + virtual_stack = np.sum(self.array * final_mask, axis=(-1, -2)) + + vi = Dataset3d.from_array( + array=virtual_stack, + name=name, + origin=self.origin[:3], + sampling=self.sampling[:3], + units=self.units[:3], + signal_units=self.signal_units, + ) + + if attach: + self._virtual_images[name] = vi + self._virtual_detectors[name] = { + "mask": final_mask.copy() if mask is not None else None, + "mode": mode, + "geometry": geometry, + } + + return vi + + # ------------------------------------------------------------------------- + # Copy + # ------------------------------------------------------------------------- + + def _copy_custom_attributes(self, new_dataset) -> None: + """Copy Dataset5dstem-specific attributes.""" + super()._copy_custom_attributes(new_dataset) + new_dataset._stack_type = self._stack_type + new_dataset._stack_values = self._stack_values.copy() if self._stack_values is not None else None + new_dataset._virtual_images = {} + 
new_dataset._virtual_detectors = { + name: {"mask": None, "mode": info["mode"], "geometry": info["geometry"]} + for name, info in self._virtual_detectors.items() + } diff --git a/src/quantem/core/io/__init__.py b/src/quantem/core/io/__init__.py index 2780eae4..c6fa56a0 100644 --- a/src/quantem/core/io/__init__.py +++ b/src/quantem/core/io/__init__.py @@ -1,5 +1,6 @@ from quantem.core.io.file_readers import read_2d as read_2d from quantem.core.io.file_readers import read_4dstem as read_4dstem +from quantem.core.io.file_readers import read_5dstem as read_5dstem from quantem.core.io.file_readers import ( read_emdfile_to_4dstem as read_emdfile_to_4dstem, ) diff --git a/src/quantem/core/io/file_readers.py b/src/quantem/core/io/file_readers.py index cb36f1de..22c140ba 100644 --- a/src/quantem/core/io/file_readers.py +++ b/src/quantem/core/io/file_readers.py @@ -1,13 +1,16 @@ import importlib +import json from os import PathLike from pathlib import Path import h5py +import numpy as np from quantem.core.datastructures import Dataset as Dataset from quantem.core.datastructures import Dataset2d as Dataset2d from quantem.core.datastructures import Dataset3d as Dataset3d from quantem.core.datastructures import Dataset4dstem as Dataset4dstem +from quantem.core.datastructures import Dataset5dstem as Dataset5dstem def read_4dstem( @@ -93,6 +96,166 @@ def read_4dstem( return dataset +def read_5dstem( + file_path: str | PathLike, + file_type: str | None = None, + stack_type: str = "auto", + **kwargs, +) -> Dataset5dstem: + """ + File reader for 5D-STEM data. + + Supports: + - Nion Swift h5 files (auto-detected from 'properties' attribute) + - rosettasciio formats with 5D data + + Parameters + ---------- + file_path : str | PathLike + Path to data + file_type : str | None, optional + The type of file reader needed. If None, auto-detect. + stack_type : str, optional + Stack type ("time", "tilt", etc.) or "auto" to detect from metadata. 
+ **kwargs : dict + Additional keyword arguments to pass to Dataset5dstem constructor. + + Returns + ------- + Dataset5dstem + """ + file_path = Path(file_path) + + # Try Nion Swift h5 format first + if file_path.suffix.lower() in [".h5", ".hdf5"]: + try: + with h5py.File(file_path, "r") as f: + if "data" in f and "properties" in f["data"].attrs: + # Nion Swift format detected + return _read_nion_swift_5dstem(file_path, stack_type, **kwargs) + except Exception: + pass # Fall through to rsciio + + # Fall back to rosettasciio + if file_type is None: + file_type = file_path.suffix.lower().lstrip(".") + + file_reader = importlib.import_module(f"rsciio.{file_type}").file_reader + data_list = file_reader(file_path) + + # Find first 5D dataset + five_d_datasets = [(i, d) for i, d in enumerate(data_list) if d["data"].ndim == 5] + + if len(five_d_datasets) == 0: + print(f"No 5D datasets found in {file_path}. Available datasets:") + for i, d in enumerate(data_list): + print(f" Dataset {i}: shape {d['data'].shape}, ndim={d['data'].ndim}") + raise ValueError("No 5D dataset found in file") + + dataset_index, imported_data = five_d_datasets[0] + + if len(data_list) > 1: + print( + f"File contains {len(data_list)} dataset(s). 
Using dataset {dataset_index} " + f"with shape {imported_data['data'].shape}" + ) + + imported_axes = imported_data["axes"] + + sampling = kwargs.pop( + "sampling", + [ax["scale"] for ax in imported_axes], + ) + origin = kwargs.pop( + "origin", + [ax["offset"] for ax in imported_axes], + ) + units = kwargs.pop( + "units", + ["pixels" if ax["units"] == "1" else ax["units"] for ax in imported_axes], + ) + + # Determine stack type + if stack_type == "auto": + stack_type = "generic" + + dataset = Dataset5dstem.from_array( + array=imported_data["data"], + sampling=sampling, + origin=origin, + units=units, + stack_type=stack_type, + **kwargs, + ) + + return dataset + + +def _read_nion_swift_5dstem( + file_path: str | PathLike, + stack_type: str = "auto", + **kwargs, +) -> Dataset5dstem: + """ + Read Nion Swift 5D-STEM h5 file. + + Nion Swift stores data with: + - f['data'] containing the array + - f['data'].attrs['properties'] containing JSON metadata + + Parameters + ---------- + file_path : str | PathLike + Path to Nion Swift h5 file + stack_type : str, optional + Stack type or "auto" to detect from metadata + + Returns + ------- + Dataset5dstem + """ + with h5py.File(file_path, "r") as f: + data = f["data"][:] + props = json.loads(f["data"].attrs["properties"]) + + if data.ndim != 5: + raise ValueError(f"Expected 5D data, got {data.ndim}D with shape {data.shape}") + + # Extract calibrations + cals = props.get("dimensional_calibrations", []) + if len(cals) == 5: + origin = np.array([c.get("offset", 0.0) for c in cals]) + sampling = np.array([c.get("scale", 1.0) for c in cals]) + units = [c.get("units", "") or "pixels" for c in cals] + else: + origin = np.zeros(5) + sampling = np.ones(5) + units = ["pixels"] * 5 + + # Determine stack type from metadata + if stack_type == "auto": + if props.get("is_sequence", False): + stack_type = "time" + else: + stack_type = "generic" + + # Get intensity calibration + intensity_cal = props.get("intensity_calibration", {}) + 
signal_units = intensity_cal.get("units", "arb. units") or "arb. units" + + dataset = Dataset5dstem.from_array( + array=data, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + stack_type=stack_type, + **kwargs, + ) + + return dataset + + def read_2d( file_path: str | PathLike, file_type: str | None = None, diff --git a/src/quantem/core/utils/masks.py b/src/quantem/core/utils/masks.py new file mode 100644 index 00000000..0b884a6b --- /dev/null +++ b/src/quantem/core/utils/masks.py @@ -0,0 +1,64 @@ +import numpy as np + + +def create_circle_mask( + shape: tuple[int, int], + center: tuple[float, float], + radius: float, +) -> np.ndarray: + """ + Create a circular mask for virtual image formation. + + Parameters + ---------- + shape : tuple[int, int] + Shape of the mask (rows, cols) + center : tuple[float, float] + Center coordinates (cy, cx) of the circle + radius : float + Radius of the circle + + Returns + ------- + np.ndarray + Boolean mask with True inside the circle + """ + cy, cx = center + y, x = np.ogrid[: shape[0], : shape[1]] + + # Calculate distance from center + distance = np.sqrt((y - cy) ** 2 + (x - cx) ** 2) + + return distance <= radius + + +def create_annular_mask( + shape: tuple[int, int], + center: tuple[float, float], + radii: tuple[float, float], +) -> np.ndarray: + """ + Create an annular (ring-shaped) mask for virtual image formation. 
+ + Parameters + ---------- + shape : tuple[int, int] + Shape of the mask (rows, cols) + center : tuple[float, float] + Center coordinates (cy, cx) of the annulus + radii : tuple[float, float] + Inner and outer radii (r_inner, r_outer) of the annulus + + Returns + ------- + np.ndarray + Boolean mask with True inside the annular region + """ + cy, cx = center + r_inner, r_outer = radii + y, x = np.ogrid[: shape[0], : shape[1]] + + # Calculate distance from center + distance = np.sqrt((y - cy) ** 2 + (x - cx) ** 2) + + return (distance >= r_inner) & (distance <= r_outer) diff --git a/tests/datastructures/test_dataset5dstem.py b/tests/datastructures/test_dataset5dstem.py new file mode 100644 index 00000000..ef85e17a --- /dev/null +++ b/tests/datastructures/test_dataset5dstem.py @@ -0,0 +1,191 @@ +"""Tests for Dataset5dstem class.""" + +import numpy as np +import pytest + +from quantem.core.datastructures.dataset5dstem import Dataset5dstem +from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.datastructures.dataset3d import Dataset3d + + +@pytest.fixture +def sample_dataset(): + """Create a sample 5D-STEM dataset.""" + array = np.random.rand(3, 5, 5, 10, 10) + return Dataset5dstem.from_array(array=array, stack_type="time") + + +class TestDataset5dstem: + """Core Dataset5dstem tests.""" + + def test_from_array(self): + """Test creating Dataset5dstem from array.""" + array = np.random.rand(3, 5, 5, 10, 10) + data = Dataset5dstem.from_array(array=array) + + assert data.shape == (3, 5, 5, 10, 10) + assert data.stack_type == "generic" + assert len(data) == 3 + + def test_from_4dstem(self): + """Test creating Dataset5dstem from list of Dataset4dstem.""" + datasets = [ + Dataset4dstem.from_array(array=np.random.rand(5, 5, 10, 10)) + for _ in range(3) + ] + data = Dataset5dstem.from_4dstem(datasets, stack_type="tilt") + + assert data.shape == (3, 5, 5, 10, 10) + assert data.stack_type == "tilt" + + def test_indexing(self, sample_dataset): + 
"""Test data[i] returns Dataset4dstem.""" + frame = sample_dataset[0] + + assert isinstance(frame, Dataset4dstem) + assert frame.shape == (5, 5, 10, 10) + assert np.array_equal(frame.array, sample_dataset.array[0]) + + def test_iteration(self, sample_dataset): + """Test iteration over frames.""" + frames = list(sample_dataset) + + assert len(frames) == 3 + assert all(isinstance(f, Dataset4dstem) for f in frames) + + def test_stack_mean(self, sample_dataset): + """Test stack_mean returns Dataset4dstem.""" + mean = sample_dataset.stack_mean() + + assert isinstance(mean, Dataset4dstem) + assert mean.shape == (5, 5, 10, 10) + assert np.allclose(mean.array, np.mean(sample_dataset.array, axis=0)) + + def test_stack_min(self, sample_dataset): + """Test stack_min returns Dataset4dstem.""" + minimum = sample_dataset.stack_min() + + assert isinstance(minimum, Dataset4dstem) + assert minimum.shape == (5, 5, 10, 10) + assert np.allclose(minimum.array, np.min(sample_dataset.array, axis=0)) + + def test_stack_sum(self, sample_dataset): + """Test stack_sum returns Dataset4dstem.""" + total = sample_dataset.stack_sum() + + assert isinstance(total, Dataset4dstem) + assert total.shape == (5, 5, 10, 10) + + def test_stack_max(self, sample_dataset): + """Test stack_max returns Dataset4dstem.""" + maximum = sample_dataset.stack_max() + + assert isinstance(maximum, Dataset4dstem) + assert maximum.shape == (5, 5, 10, 10) + + def test_slicing(self, sample_dataset): + """Test data[1:3] returns Dataset5dstem with correct data.""" + sliced = sample_dataset[1:3] + assert isinstance(sliced, Dataset5dstem) + assert sliced.stack_type == "time" + assert np.array_equal(sliced.array, sample_dataset.array[1:3]) + + def test_slicing_ellipsis(self, sample_dataset): + """Test data[1:3, ...] returns Dataset5dstem with correct data.""" + sliced = sample_dataset[1:3, ...] 
+ assert isinstance(sliced, Dataset5dstem) + assert np.array_equal(sliced.array, sample_dataset.array[1:3, ...]) + + def test_slicing_scan_position(self, sample_dataset): + """Test data[:, 2, 2] returns Dataset3d with correct data.""" + sliced = sample_dataset[:, 2, 2] + assert isinstance(sliced, Dataset3d) + assert np.array_equal(sliced.array, sample_dataset.array[:, 2, 2]) + + def test_slicing_k_roi(self, sample_dataset): + """Test data[:, :, :, 2:8, 2:8] returns Dataset5dstem with correct data.""" + sliced = sample_dataset[:, :, :, 2:8, 2:8] + assert isinstance(sliced, Dataset5dstem) + assert np.array_equal(sliced.array, sample_dataset.array[:, :, :, 2:8, 2:8]) + + def test_slicing_frame_ellipsis(self, sample_dataset): + """Test data[0, ...] same as data[0] with correct data.""" + sliced = sample_dataset[0, ...] + assert isinstance(sliced, Dataset4dstem) + assert np.array_equal(sliced.array, sample_dataset.array[0, ...]) + + def test_slicing_last_axis(self, sample_dataset): + """Test data[..., 0] slices last axis with correct data.""" + sliced = sample_dataset[..., 0] + assert np.array_equal(sliced.array, sample_dataset.array[..., 0]) + + def test_slicing_scan_k_roi(self, sample_dataset): + """Test data[:, 2, 2, 2:8, 2:8] with correct data.""" + sliced = sample_dataset[:, 2, 2, 2:8, 2:8] + assert np.array_equal(sliced.array, sample_dataset.array[:, 2, 2, 2:8, 2:8]) + + def test_slicing_substack_k_roi(self, sample_dataset): + """Test data[1:3, :, :, 2:8, 2:8] with correct data.""" + sliced = sample_dataset[1:3, :, :, 2:8, 2:8] + assert isinstance(sliced, Dataset5dstem) + assert np.array_equal(sliced.array, sample_dataset.array[1:3, :, :, 2:8, 2:8]) + + def test_bin_all(self): + """Test bin(2) bins all dimensions including stack.""" + array = np.random.rand(4, 6, 6, 10, 10) + data = Dataset5dstem.from_array(array=array) + + binned = data.bin(2) + + assert binned.shape == (2, 3, 3, 5, 5) # all dims halved + + def test_bin_preserve_stack(self): + """Test bin with 
axes preserves stack.""" + array = np.random.rand(4, 6, 6, 10, 10) + data = Dataset5dstem.from_array(array=array) + + binned = data.bin(2, axes=(1, 2, 3, 4)) + + assert binned.shape == (4, 3, 3, 5, 5) # stack preserved + + def test_get_virtual_image(self, sample_dataset): + """Test virtual image creation.""" + vi = sample_dataset.get_virtual_image( + mode="circle", + geometry=((5, 5), 3), + name="bf", + ) + + assert isinstance(vi, Dataset3d) + assert vi.shape == (3, 5, 5) + assert "bf" in sample_dataset.virtual_images + + def test_repr(self, sample_dataset): + """Test __repr__ shows Dataset5dstem.""" + r = repr(sample_dataset) + + assert "Dataset5dstem" in r + assert "stack_type='time'" in r + + def test_str(self, sample_dataset): + """Test __str__ shows formatted output.""" + s = str(sample_dataset) + + assert "Dataset5dstem" in s + assert "3 frames" in s + assert "stack_type: 'time'" in s + + def test_stack_values_validation(self): + """Test stack_values length must match number of frames.""" + array = np.random.rand(3, 5, 5, 10, 10) + + with pytest.raises(ValueError, match="stack_values length"): + Dataset5dstem.from_array(array=array, stack_values=np.array([1, 2])) # wrong length + + def test_from_4dstem_validation(self): + """Test from_4dstem validates consistent shapes.""" + ds1 = Dataset4dstem.from_array(array=np.random.rand(5, 5, 10, 10)) + ds2 = Dataset4dstem.from_array(array=np.random.rand(6, 6, 10, 10)) # different shape + + with pytest.raises(ValueError, match="shape"): + Dataset5dstem.from_4dstem([ds1, ds2]) From fee7e5e7d6f13d7965b7ed3912058ec76136331a Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Fri, 16 Jan 2026 23:16:57 -0800 Subject: [PATCH 04/12] Improve Dataset5dstem: ArrayLike inputs, np.integer support, torch GPU error, slicing metadata fix --- src/quantem/core/datastructures/dataset5d.py | 6 +- .../core/datastructures/dataset5dstem.py | 122 +++++++++++------- src/quantem/core/utils/validators.py | 8 ++ 
tests/datastructures/test_dataset5dstem.py | 28 ++-- 4 files changed, 103 insertions(+), 61 deletions(-) diff --git a/src/quantem/core/datastructures/dataset5d.py b/src/quantem/core/datastructures/dataset5d.py index a33888ea..62bc6554 100644 --- a/src/quantem/core/datastructures/dataset5d.py +++ b/src/quantem/core/datastructures/dataset5d.py @@ -1,4 +1,4 @@ -from typing import Any, Self, Union +from typing import Any, Self import numpy as np from numpy.typing import NDArray @@ -117,8 +117,8 @@ def from_shape( shape: tuple[int, int, int, int, int], name: str = "constant 5D dataset", fill_value: float = 0.0, - origin: Union[NDArray, tuple, list, float, int] | None = None, - sampling: Union[NDArray, tuple, list, float, int] | None = None, + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", ) -> Self: diff --git a/src/quantem/core/datastructures/dataset5dstem.py b/src/quantem/core/datastructures/dataset5dstem.py index b2c4d846..74fa7e2d 100644 --- a/src/quantem/core/datastructures/dataset5dstem.py +++ b/src/quantem/core/datastructures/dataset5dstem.py @@ -3,7 +3,7 @@ from typing import Iterator, Self import numpy as np -from numpy.typing import NDArray +from numpy.typing import ArrayLike, NDArray from quantem.core.datastructures.dataset3d import Dataset3d from quantem.core.datastructures.dataset4dstem import Dataset4dstem @@ -41,15 +41,15 @@ class Dataset5dstem(Dataset5d): def __init__( self, - array: NDArray, + array: ArrayLike, name: str, - origin: NDArray, - sampling: NDArray, + origin: ArrayLike, + sampling: ArrayLike, units: list[str], signal_units: str = "arb. 
units", metadata: dict | None = None, stack_type: str = "generic", - stack_values: NDArray | None = None, + stack_values: ArrayLike | None = None, _token: object | None = None, ): metadata = metadata or {} @@ -106,29 +106,66 @@ def __iter__(self) -> Iterator[Dataset4dstem]: yield self._get_frame(i) def __getitem__(self, idx) -> "Dataset4dstem | Dataset5dstem": - if isinstance(idx, int): - return self._get_frame(idx) + # Handle integer indexing (including numpy integers) + if isinstance(idx, (int, np.integer)): + return self._get_frame(int(idx)) # Handle tuple where first element is int (e.g., data[0, ...]) - if isinstance(idx, tuple) and len(idx) > 0 and isinstance(idx[0], int): - return self._get_frame(idx[0])[idx[1:]] + if isinstance(idx, tuple) and len(idx) > 0 and isinstance(idx[0], (int, np.integer)): + return self._get_frame(int(idx[0]))[idx[1:]] + + # Reject advanced indexing on stack axis (lists, arrays, boolean masks) + if isinstance(idx, (list, np.ndarray)): + raise TypeError( + "Advanced indexing with lists/arrays on stack axis is not supported. " + "Use integer indexing or slices instead." + ) + if isinstance(idx, tuple) and len(idx) > 0 and isinstance(idx[0], (list, np.ndarray)): + raise TypeError( + "Advanced indexing with lists/arrays on stack axis is not supported. " + "Use integer indexing or slices instead." 
+ ) - # Slicing returns Dataset5dstem with preserved stack_type - if isinstance(idx, slice): - sliced_array = self.array[idx] - sliced_values = self._stack_values[idx] if self._stack_values is not None else None + # Get result from base class slicing + result = super().__getitem__(idx) + + # If result is still 5D, wrap back into Dataset5dstem with preserved metadata + if result.array.ndim == 5: + # Figure out how stack_values should be sliced + sliced_values = self._slice_stack_values(idx) return self.from_array( - array=sliced_array, - name=self.name, - origin=self.origin, - sampling=self.sampling, - units=self.units, - signal_units=self.signal_units, + array=result.array, + name=result.name, + origin=result.origin, + sampling=result.sampling, + units=result.units, + signal_units=result.signal_units, stack_type=self._stack_type, stack_values=sliced_values, ) - return super().__getitem__(idx) + return result + + def _slice_stack_values(self, idx): + """Slice stack_values based on how the stack axis is indexed.""" + if self._stack_values is None: + return None + + # Simple slice on first axis + if isinstance(idx, slice): + return self._stack_values[idx] + + # Tuple indexing - check first element + if isinstance(idx, tuple) and len(idx) > 0: + first = idx[0] + if isinstance(first, slice): + return self._stack_values[first] + if first is Ellipsis: + # Ellipsis at start means stack axis not sliced + return self._stack_values + + # Default: preserve all stack_values + return self._stack_values @property def stack_type(self) -> str: @@ -157,34 +194,34 @@ def virtual_detectors(self) -> dict[str, dict]: @classmethod def from_array( cls, - array: NDArray, + array: ArrayLike, name: str | None = None, - origin: NDArray | tuple | list | None = None, - sampling: NDArray | tuple | list | None = None, + origin: ArrayLike | None = None, + sampling: ArrayLike | None = None, units: list[str] | None = None, signal_units: str = "arb. 
units", stack_type: str = "generic", - stack_values: NDArray | None = None, + stack_values: ArrayLike | None = None, ) -> Self: """Create Dataset5dstem from a 5D array. Parameters ---------- - array : NDArray + array : array-like 5D array with shape (stack, scan_row, scan_col, k_row, k_col). name : str, optional Dataset name. Default: "5D-STEM dataset". origin : array-like, optional - Origin for each dimension (4 or 5 elements). Default: zeros. + Origin for each dimension (5 elements). Default: zeros. sampling : array-like, optional - Sampling for each dimension (4 or 5 elements). Default: ones. + Sampling for each dimension (5 elements). Default: ones. units : list[str], optional - Units for each dimension (4 or 5 elements). Default: ["pixels", ...]. + Units for each dimension (5 elements). Default: ["pixels", ...]. signal_units : str, optional Units for intensity values. Default: "arb. units". stack_type : str, optional Type of stack dimension. Default: "generic". - stack_values : NDArray, optional + stack_values : array-like, optional Explicit values for stack positions (e.g., times, angles). 
Returns @@ -193,24 +230,9 @@ def from_array( """ array = ensure_valid_array(array, ndim=5) - # Accept 4-element inputs (scan + k dims); prepend stack defaults - def expand_to_5d(arr, default): - if arr is None: - return default - arr = np.asarray(arr) - if arr.size == 4: - return np.concatenate([[default[0]], arr]) - return arr - - origin_5d = expand_to_5d(origin, np.zeros(5)) - sampling_5d = expand_to_5d(sampling, np.ones(5)) - - if units is None: - units_5d = ["pixels"] * 5 - elif len(units) == 4: - units_5d = ["index"] + list(units) - else: - units_5d = list(units) + origin_5d = np.zeros(5) if origin is None else np.asarray(origin) + sampling_5d = np.ones(5) if sampling is None else np.asarray(sampling) + units_5d = ["pixels"] * 5 if units is None else list(units) return cls( array=array, @@ -302,6 +324,10 @@ def stack_min(self) -> Dataset4dstem: """Minimum over the stack axis. Returns Dataset4dstem.""" return self._reduce_stack(np.min, "min") + def stack_std(self) -> Dataset4dstem: + """Standard deviation over the stack axis. 
Returns Dataset4dstem.""" + return self._reduce_stack(np.std, "std") + def _reduce_stack(self, func, suffix: str) -> Dataset4dstem: """Apply reduction function over stack axis.""" return Dataset4dstem.from_array( @@ -349,7 +375,7 @@ def _get_frame(self, idx: int) -> Dataset4dstem: def get_virtual_image( self, - mask: np.ndarray | None = None, + mask: ArrayLike | None = None, mode: str | None = None, geometry: tuple | None = None, name: str = "virtual_image", diff --git a/src/quantem/core/utils/validators.py b/src/quantem/core/utils/validators.py index 02f8a935..fe57110a 100644 --- a/src/quantem/core/utils/validators.py +++ b/src/quantem/core/utils/validators.py @@ -49,6 +49,14 @@ def ensure_valid_array( TypeError If the input could not be converted to a NumPy array """ + # Check for torch GPU tensors and provide helpful error + if config.get("has_torch"): + if isinstance(array, torch.Tensor) and array.is_cuda: + raise TypeError( + f"Cannot convert torch GPU tensor (device={array.device}) to numpy array. 
" + "Use .cpu() to move to CPU first: Dataset.from_array(tensor.cpu())" + ) + is_cupy = False if config.get("has_cupy"): if isinstance(array, cp.ndarray): diff --git a/tests/datastructures/test_dataset5dstem.py b/tests/datastructures/test_dataset5dstem.py index ef85e17a..52d6a0ef 100644 --- a/tests/datastructures/test_dataset5dstem.py +++ b/tests/datastructures/test_dataset5dstem.py @@ -40,7 +40,7 @@ def test_from_4dstem(self): def test_indexing(self, sample_dataset): """Test data[i] returns Dataset4dstem.""" - frame = sample_dataset[0] + frame = sample_dataset[0] # -> Dataset4dstem assert isinstance(frame, Dataset4dstem) assert frame.shape == (5, 5, 10, 10) @@ -83,50 +83,58 @@ def test_stack_max(self, sample_dataset): assert isinstance(maximum, Dataset4dstem) assert maximum.shape == (5, 5, 10, 10) + def test_stack_std(self, sample_dataset): + """Test stack_std returns Dataset4dstem.""" + std = sample_dataset.stack_std() + + assert isinstance(std, Dataset4dstem) + assert std.shape == (5, 5, 10, 10) + assert np.allclose(std.array, np.std(sample_dataset.array, axis=0)) + def test_slicing(self, sample_dataset): """Test data[1:3] returns Dataset5dstem with correct data.""" - sliced = sample_dataset[1:3] + sliced = sample_dataset[1:3] # -> Dataset5dstem assert isinstance(sliced, Dataset5dstem) assert sliced.stack_type == "time" assert np.array_equal(sliced.array, sample_dataset.array[1:3]) def test_slicing_ellipsis(self, sample_dataset): """Test data[1:3, ...] returns Dataset5dstem with correct data.""" - sliced = sample_dataset[1:3, ...] + sliced = sample_dataset[1:3, ...] 
# -> Dataset5dstem assert isinstance(sliced, Dataset5dstem) assert np.array_equal(sliced.array, sample_dataset.array[1:3, ...]) def test_slicing_scan_position(self, sample_dataset): """Test data[:, 2, 2] returns Dataset3d with correct data.""" - sliced = sample_dataset[:, 2, 2] + sliced = sample_dataset[:, 2, 2] # -> Dataset3d assert isinstance(sliced, Dataset3d) assert np.array_equal(sliced.array, sample_dataset.array[:, 2, 2]) def test_slicing_k_roi(self, sample_dataset): """Test data[:, :, :, 2:8, 2:8] returns Dataset5dstem with correct data.""" - sliced = sample_dataset[:, :, :, 2:8, 2:8] + sliced = sample_dataset[:, :, :, 2:8, 2:8] # -> Dataset5dstem assert isinstance(sliced, Dataset5dstem) assert np.array_equal(sliced.array, sample_dataset.array[:, :, :, 2:8, 2:8]) def test_slicing_frame_ellipsis(self, sample_dataset): """Test data[0, ...] same as data[0] with correct data.""" - sliced = sample_dataset[0, ...] + sliced = sample_dataset[0, ...] # -> Dataset4dstem assert isinstance(sliced, Dataset4dstem) assert np.array_equal(sliced.array, sample_dataset.array[0, ...]) def test_slicing_last_axis(self, sample_dataset): """Test data[..., 0] slices last axis with correct data.""" - sliced = sample_dataset[..., 0] + sliced = sample_dataset[..., 0] # -> Dataset4d assert np.array_equal(sliced.array, sample_dataset.array[..., 0]) def test_slicing_scan_k_roi(self, sample_dataset): """Test data[:, 2, 2, 2:8, 2:8] with correct data.""" - sliced = sample_dataset[:, 2, 2, 2:8, 2:8] + sliced = sample_dataset[:, 2, 2, 2:8, 2:8] # -> Dataset3d assert np.array_equal(sliced.array, sample_dataset.array[:, 2, 2, 2:8, 2:8]) def test_slicing_substack_k_roi(self, sample_dataset): """Test data[1:3, :, :, 2:8, 2:8] with correct data.""" - sliced = sample_dataset[1:3, :, :, 2:8, 2:8] + sliced = sample_dataset[1:3, :, :, 2:8, 2:8] # -> Dataset5dstem assert isinstance(sliced, Dataset5dstem) assert np.array_equal(sliced.array, sample_dataset.array[1:3, :, :, 2:8, 2:8]) @@ -150,7 +158,7 @@ 
def test_bin_preserve_stack(self): def test_get_virtual_image(self, sample_dataset): """Test virtual image creation.""" - vi = sample_dataset.get_virtual_image( + vi = sample_dataset.get_virtual_image( # -> Dataset3d mode="circle", geometry=((5, 5), 3), name="bf", From 5fa22b314e4a4991355d76172704b59ebb9f833e Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Fri, 16 Jan 2026 23:30:25 -0800 Subject: [PATCH 05/12] Add indexing examples to Dataset5dstem docstring and notebook with distinct sizes --- .../core/datastructures/dataset5dstem.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/quantem/core/datastructures/dataset5dstem.py b/src/quantem/core/datastructures/dataset5dstem.py index 74fa7e2d..d4a2a363 100644 --- a/src/quantem/core/datastructures/dataset5dstem.py +++ b/src/quantem/core/datastructures/dataset5dstem.py @@ -30,13 +30,26 @@ class Dataset5dstem(Dataset5d): Examples -------- - >>> data = read_5dstem("path/to/file.h5") - >>> len(data) # number of frames - 10 - >>> frame = data[0] # get first frame as Dataset4dstem - >>> mean_4d = data.stack_mean() # average over stack -> Dataset4dstem - >>> for frame in data: # iterate over frames - ... 
process(frame) + >>> data = read_5dstem("path/to/file.h5") # shape (4, 6, 7, 3, 5) + >>> len(data) # number of frames -> 4 + + Indexing (integer removes dimension, slice keeps it): + + >>> data[2] # -> Dataset4dstem (6, 7, 3, 5) one frame + >>> data[1:3] # -> Dataset5dstem (2, 6, 7, 3, 5) substack + >>> data[:, 4, 1] # -> Dataset3d (4, 3, 5) one scan position, all frames + >>> data[:, 1:5, 2:6] # -> Dataset5dstem (4, 4, 4, 3, 5) scan region crop + >>> data[..., 0:2, 1:4] # -> Dataset5dstem (4, 6, 7, 2, 3) k-space crop + + Stack operations (reduce over stack axis): + + >>> data.stack_mean() # -> Dataset4dstem + >>> data.stack_std() # -> Dataset4dstem + + Virtual imaging: + + >>> vi = data.get_virtual_image(mode="circle", geometry=((1, 2), 1)) + >>> vi.shape # -> (4, 6, 7) = (frames, scan_row, scan_col) """ def __init__( From e50f72221c8020f14d013449b24b5eb6750add2e Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sat, 17 Jan 2026 00:21:20 -0800 Subject: [PATCH 06/12] Consolidate 5D-STEM tests and add mock Nion Swift fixture --- tests/conftest.py | 61 +++++++ tests/datastructures/test_dataset5dstem.py | 202 +++++---------------- tests/io/test_read_5dstem.py | 44 +++++ 3 files changed, 148 insertions(+), 159 deletions(-) create mode 100644 tests/io/test_read_5dstem.py diff --git a/tests/conftest.py b/tests/conftest.py index 3178a8fd..7cdcf3c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,67 @@ +import json +from pathlib import Path + +import h5py +import numpy as np import pytest +def create_mock_nion_swift_h5(path: Path, shape=(4, 6, 7, 3, 5), is_sequence=True): + """Create a mock Nion Swift h5 file with realistic metadata structure. + + Parameters + ---------- + path : Path + Output file path. + shape : tuple + Data shape (frames, scan_row, scan_col, k_row, k_col). + is_sequence : bool + Whether to set is_sequence=True in properties. 
+ """ + properties = { + "type": "data-item", + "uuid": "test-uuid", + "is_sequence": is_sequence, + "intensity_calibration": {"offset": 0.0, "scale": 1.0, "units": "counts" if is_sequence else ""}, + "dimensional_calibrations": [ + {"offset": 0.0, "scale": 1.0, "units": ""}, # frames + {"offset": -1.0, "scale": 0.5, "units": "nm"}, # scan_row + {"offset": -1.5, "scale": 0.5, "units": "nm"}, # scan_col + {"offset": -0.036, "scale": 0.006, "units": "rad"}, # k_row + {"offset": -0.036, "scale": 0.006, "units": "rad"}, # k_col + ], + "collection_dimension_count": 2, + "datum_dimension_count": 2, + "metadata": { + "instrument": { + "high_tension": 60000.0, # 60 keV + "defocus": 0.0, + "ImageScanned": { + "probe_ha": 0.035, # 35 mrad half-angle + "C10": 0.0, + "C30": 0.0, + }, + }, + "scan": { + "scan_size": [shape[1], shape[2]], + }, + }, + } + + with h5py.File(path, "w") as f: + data = np.random.rand(*shape).astype(np.float32) + dset = f.create_dataset("data", data=data) + dset.attrs["properties"] = json.dumps(properties) + + +@pytest.fixture +def mock_nion_5dstem_file(tmp_path): + """Create a temporary mock Nion Swift 5D-STEM h5 file.""" + path = tmp_path / "mock_nion_5dstem.h5" + create_mock_nion_swift_h5(path, shape=(4, 6, 7, 3, 5)) + return path + + def pytest_addoption(parser): parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") diff --git a/tests/datastructures/test_dataset5dstem.py b/tests/datastructures/test_dataset5dstem.py index 52d6a0ef..6a62a09a 100644 --- a/tests/datastructures/test_dataset5dstem.py +++ b/tests/datastructures/test_dataset5dstem.py @@ -1,8 +1,6 @@ """Tests for Dataset5dstem class.""" - import numpy as np import pytest - from quantem.core.datastructures.dataset5dstem import Dataset5dstem from quantem.core.datastructures.dataset4dstem import Dataset4dstem from quantem.core.datastructures.dataset3d import Dataset3d @@ -10,8 +8,8 @@ @pytest.fixture def sample_dataset(): - """Create a sample 5D-STEM 
dataset.""" - array = np.random.rand(3, 5, 5, 10, 10) + """Create a sample 5D-STEM dataset with distinct sizes for clarity.""" + array = np.random.rand(4, 6, 7, 3, 5) # (frames, scan_row, scan_col, k_row, k_col) return Dataset5dstem.from_array(array=array, stack_type="time") @@ -20,180 +18,66 @@ class TestDataset5dstem: def test_from_array(self): """Test creating Dataset5dstem from array.""" - array = np.random.rand(3, 5, 5, 10, 10) - data = Dataset5dstem.from_array(array=array) - - assert data.shape == (3, 5, 5, 10, 10) - assert data.stack_type == "generic" - assert len(data) == 3 + array = np.random.rand(4, 6, 7, 3, 5) + data = Dataset5dstem.from_array(array=array, stack_type="tilt") + assert data.shape == (4, 6, 7, 3, 5) + assert data.stack_type == "tilt" + assert len(data) == 4 def test_from_4dstem(self): """Test creating Dataset5dstem from list of Dataset4dstem.""" - datasets = [ - Dataset4dstem.from_array(array=np.random.rand(5, 5, 10, 10)) - for _ in range(3) - ] + datasets = [Dataset4dstem.from_array(np.random.rand(6, 7, 3, 5)) for _ in range(4)] data = Dataset5dstem.from_4dstem(datasets, stack_type="tilt") - - assert data.shape == (3, 5, 5, 10, 10) + assert data.shape == (4, 6, 7, 3, 5) assert data.stack_type == "tilt" - def test_indexing(self, sample_dataset): - """Test data[i] returns Dataset4dstem.""" + def test_indexing_and_iteration(self, sample_dataset): + """Test data[i] and iteration return Dataset4dstem frames.""" frame = sample_dataset[0] # -> Dataset4dstem - assert isinstance(frame, Dataset4dstem) - assert frame.shape == (5, 5, 10, 10) - assert np.array_equal(frame.array, sample_dataset.array[0]) - - def test_iteration(self, sample_dataset): - """Test iteration over frames.""" + assert frame.shape == (6, 7, 3, 5) frames = list(sample_dataset) - - assert len(frames) == 3 + assert len(frames) == 4 assert all(isinstance(f, Dataset4dstem) for f in frames) - def test_stack_mean(self, sample_dataset): - """Test stack_mean returns Dataset4dstem.""" - 
mean = sample_dataset.stack_mean() - + def test_stack_reductions(self, sample_dataset): + """Test stack reduction methods return Dataset4dstem with correct values.""" + mean = sample_dataset.stack_mean() # -> Dataset4dstem assert isinstance(mean, Dataset4dstem) - assert mean.shape == (5, 5, 10, 10) - assert np.allclose(mean.array, np.mean(sample_dataset.array, axis=0)) - - def test_stack_min(self, sample_dataset): - """Test stack_min returns Dataset4dstem.""" - minimum = sample_dataset.stack_min() - - assert isinstance(minimum, Dataset4dstem) - assert minimum.shape == (5, 5, 10, 10) - assert np.allclose(minimum.array, np.min(sample_dataset.array, axis=0)) - - def test_stack_sum(self, sample_dataset): - """Test stack_sum returns Dataset4dstem.""" - total = sample_dataset.stack_sum() - - assert isinstance(total, Dataset4dstem) - assert total.shape == (5, 5, 10, 10) - - def test_stack_max(self, sample_dataset): - """Test stack_max returns Dataset4dstem.""" - maximum = sample_dataset.stack_max() - - assert isinstance(maximum, Dataset4dstem) - assert maximum.shape == (5, 5, 10, 10) - - def test_stack_std(self, sample_dataset): - """Test stack_std returns Dataset4dstem.""" - std = sample_dataset.stack_std() - - assert isinstance(std, Dataset4dstem) - assert std.shape == (5, 5, 10, 10) - assert np.allclose(std.array, np.std(sample_dataset.array, axis=0)) + assert mean.shape == (6, 7, 3, 5) + arr = sample_dataset.array + assert np.allclose(sample_dataset.stack_mean().array, np.mean(arr, axis=0)) + assert np.allclose(sample_dataset.stack_sum().array, np.sum(arr, axis=0)) + assert np.allclose(sample_dataset.stack_max().array, np.max(arr, axis=0)) + assert np.allclose(sample_dataset.stack_min().array, np.min(arr, axis=0)) + assert np.allclose(sample_dataset.stack_std().array, np.std(arr, axis=0)) def test_slicing(self, sample_dataset): - """Test data[1:3] returns Dataset5dstem with correct data.""" - sliced = sample_dataset[1:3] # -> Dataset5dstem - assert isinstance(sliced, 
Dataset5dstem) - assert sliced.stack_type == "time" - assert np.array_equal(sliced.array, sample_dataset.array[1:3]) - - def test_slicing_ellipsis(self, sample_dataset): - """Test data[1:3, ...] returns Dataset5dstem with correct data.""" - sliced = sample_dataset[1:3, ...] # -> Dataset5dstem - assert isinstance(sliced, Dataset5dstem) - assert np.array_equal(sliced.array, sample_dataset.array[1:3, ...]) - - def test_slicing_scan_position(self, sample_dataset): - """Test data[:, 2, 2] returns Dataset3d with correct data.""" - sliced = sample_dataset[:, 2, 2] # -> Dataset3d - assert isinstance(sliced, Dataset3d) - assert np.array_equal(sliced.array, sample_dataset.array[:, 2, 2]) - - def test_slicing_k_roi(self, sample_dataset): - """Test data[:, :, :, 2:8, 2:8] returns Dataset5dstem with correct data.""" - sliced = sample_dataset[:, :, :, 2:8, 2:8] # -> Dataset5dstem - assert isinstance(sliced, Dataset5dstem) - assert np.array_equal(sliced.array, sample_dataset.array[:, :, :, 2:8, 2:8]) - - def test_slicing_frame_ellipsis(self, sample_dataset): - """Test data[0, ...] same as data[0] with correct data.""" - sliced = sample_dataset[0, ...] 
# -> Dataset4dstem - assert isinstance(sliced, Dataset4dstem) - assert np.array_equal(sliced.array, sample_dataset.array[0, ...]) - - def test_slicing_last_axis(self, sample_dataset): - """Test data[..., 0] slices last axis with correct data.""" - sliced = sample_dataset[..., 0] # -> Dataset4d - assert np.array_equal(sliced.array, sample_dataset.array[..., 0]) - - def test_slicing_scan_k_roi(self, sample_dataset): - """Test data[:, 2, 2, 2:8, 2:8] with correct data.""" - sliced = sample_dataset[:, 2, 2, 2:8, 2:8] # -> Dataset3d - assert np.array_equal(sliced.array, sample_dataset.array[:, 2, 2, 2:8, 2:8]) - - def test_slicing_substack_k_roi(self, sample_dataset): - """Test data[1:3, :, :, 2:8, 2:8] with correct data.""" - sliced = sample_dataset[1:3, :, :, 2:8, 2:8] # -> Dataset5dstem - assert isinstance(sliced, Dataset5dstem) - assert np.array_equal(sliced.array, sample_dataset.array[1:3, :, :, 2:8, 2:8]) - - def test_bin_all(self): - """Test bin(2) bins all dimensions including stack.""" - array = np.random.rand(4, 6, 6, 10, 10) - data = Dataset5dstem.from_array(array=array) - - binned = data.bin(2) - - assert binned.shape == (2, 3, 3, 5, 5) # all dims halved - - def test_bin_preserve_stack(self): - """Test bin with axes preserves stack.""" - array = np.random.rand(4, 6, 6, 10, 10) - data = Dataset5dstem.from_array(array=array) - - binned = data.bin(2, axes=(1, 2, 3, 4)) - - assert binned.shape == (4, 3, 3, 5, 5) # stack preserved + """Test common slicing patterns.""" + substack = sample_dataset[1:3] # -> Dataset5dstem (2, 6, 7, 3, 5) + assert isinstance(substack, Dataset5dstem) + assert substack.shape == (2, 6, 7, 3, 5) + assert substack.stack_type == "time" # metadata preserved + position = sample_dataset[:, 2, 3] # -> Dataset3d (4, 3, 5) + assert isinstance(position, Dataset3d) + assert position.shape == (4, 3, 5) + cropped = sample_dataset[:, :, :, 1:3, 1:4] # -> Dataset5dstem (4, 6, 7, 2, 3) + assert isinstance(cropped, Dataset5dstem) + assert cropped.shape 
== (4, 6, 7, 2, 3) def test_get_virtual_image(self, sample_dataset): - """Test virtual image creation.""" - vi = sample_dataset.get_virtual_image( # -> Dataset3d - mode="circle", - geometry=((5, 5), 3), - name="bf", - ) - + """Test virtual image creation returns Dataset3d stack.""" + vi = sample_dataset.get_virtual_image(mode="circle", geometry=((1, 2), 1), name="bf") assert isinstance(vi, Dataset3d) - assert vi.shape == (3, 5, 5) + assert vi.shape == (4, 6, 7) assert "bf" in sample_dataset.virtual_images - def test_repr(self, sample_dataset): - """Test __repr__ shows Dataset5dstem.""" - r = repr(sample_dataset) - - assert "Dataset5dstem" in r - assert "stack_type='time'" in r - - def test_str(self, sample_dataset): - """Test __str__ shows formatted output.""" - s = str(sample_dataset) - - assert "Dataset5dstem" in s - assert "3 frames" in s - assert "stack_type: 'time'" in s - - def test_stack_values_validation(self): - """Test stack_values length must match number of frames.""" - array = np.random.rand(3, 5, 5, 10, 10) - + def test_validation(self): + """Test validation errors.""" + array = np.random.rand(4, 6, 7, 3, 5) with pytest.raises(ValueError, match="stack_values length"): - Dataset5dstem.from_array(array=array, stack_values=np.array([1, 2])) # wrong length - - def test_from_4dstem_validation(self): - """Test from_4dstem validates consistent shapes.""" - ds1 = Dataset4dstem.from_array(array=np.random.rand(5, 5, 10, 10)) - ds2 = Dataset4dstem.from_array(array=np.random.rand(6, 6, 10, 10)) # different shape - + Dataset5dstem.from_array(array=array, stack_values=np.array([1, 2])) + ds1 = Dataset4dstem.from_array(np.random.rand(6, 7, 3, 5)) + ds2 = Dataset4dstem.from_array(np.random.rand(8, 9, 3, 5)) with pytest.raises(ValueError, match="shape"): Dataset5dstem.from_4dstem([ds1, ds2]) diff --git a/tests/io/test_read_5dstem.py b/tests/io/test_read_5dstem.py new file mode 100644 index 00000000..05e4cfc9 --- /dev/null +++ b/tests/io/test_read_5dstem.py @@ -0,0 
+1,44 @@ +"""Tests for read_5dstem with mocked Nion Swift metadata.""" +import json +import numpy as np +import pytest +import h5py +from quantem.core.io import read_5dstem +from quantem.core.datastructures import Dataset5dstem + + +class TestRead5dstemNionSwift: + """Tests for reading Nion Swift 5D-STEM data.""" + + def test_read_nion_5dstem(self, mock_nion_5dstem_file): + """Test reading Nion Swift file extracts shape, type, and calibrations.""" + data = read_5dstem(mock_nion_5dstem_file) + assert isinstance(data, Dataset5dstem) + assert data.shape == (4, 6, 7, 3, 5) + assert data.stack_type == "time" # is_sequence=True + assert np.allclose(data.sampling, [1.0, 0.5, 0.5, 0.006, 0.006]) + assert np.allclose(data.origin[:3], [0.0, -1.0, -1.5]) + assert data.units[1:3] == ["nm", "nm"] + assert data.units[3:] == ["rad", "rad"] + assert data.signal_units == "counts" + + def test_read_override_stack_type(self, mock_nion_5dstem_file): + """Test stack_type can be overridden.""" + data = read_5dstem(mock_nion_5dstem_file, stack_type="tilt") + assert data.stack_type == "tilt" + + def test_read_no_sequence(self, tmp_path): + """Test is_sequence=False -> stack_type='generic', empty units -> 'arb. units'.""" + path = tmp_path / "no_seq.h5" + properties = { + "is_sequence": False, + "dimensional_calibrations": [{"offset": 0, "scale": 1, "units": ""}] * 5, + "intensity_calibration": {"units": ""}, + "metadata": {}, + } + with h5py.File(path, "w") as f: + dset = f.create_dataset("data", data=np.zeros((2, 3, 3, 4, 4), dtype=np.float32)) + dset.attrs["properties"] = json.dumps(properties) + data = read_5dstem(path) + assert data.stack_type == "generic" + assert data.signal_units == "arb. 
units" From 9738799bd3cf26b95bc1b65b34a3f2782f884a2e Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sat, 17 Jan 2026 00:21:33 -0800 Subject: [PATCH 07/12] Add NumPy-style docstring examples to read_5dstem, from_array, from_4dstem --- .../core/datastructures/dataset5dstem.py | 39 +++++++++++++++++++ src/quantem/core/io/file_readers.py | 33 ++++++++++++++-- 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/src/quantem/core/datastructures/dataset5dstem.py b/src/quantem/core/datastructures/dataset5dstem.py index d4a2a363..8401a6cf 100644 --- a/src/quantem/core/datastructures/dataset5dstem.py +++ b/src/quantem/core/datastructures/dataset5dstem.py @@ -240,6 +240,26 @@ def from_array( Returns ------- Dataset5dstem + + Examples + -------- + Basic usage: + + >>> import numpy as np + >>> arr = np.random.rand(10, 256, 256, 64, 64) + >>> data = Dataset5dstem.from_array(arr, stack_type="time") + >>> data.shape + (10, 256, 256, 64, 64) + + With calibrations: + + >>> data = Dataset5dstem.from_array( + ... arr, + ... stack_type="tilt", + ... stack_values=np.linspace(-60, 60, 10), # tilt angles in degrees + ... sampling=[1, 0.5, 0.5, 0.01, 0.01], + ... units=["deg", "nm", "nm", "1/nm", "1/nm"], + ... ) """ array = ensure_valid_array(array, ndim=5) @@ -283,6 +303,25 @@ def from_4dstem( Returns ------- Dataset5dstem + + Examples + -------- + Stack multiple 4D-STEM datasets into a tilt series: + + >>> from quantem.core.io import read_4dstem + >>> frames = [read_4dstem(f"tilt_{i:02d}.h5") for i in range(10)] + >>> tilt_series = Dataset5dstem.from_4dstem( + ... frames, + ... stack_type="tilt", + ... stack_values=np.linspace(-60, 60, 10), + ... 
) + >>> tilt_series.shape + (10, 256, 256, 128, 128) + + Stack synthetic data: + + >>> datasets = [Dataset4dstem.from_array(np.random.rand(64, 64, 32, 32)) for _ in range(5)] + >>> data = Dataset5dstem.from_4dstem(datasets, stack_type="time") """ if not datasets: raise ValueError("datasets list cannot be empty") diff --git a/src/quantem/core/io/file_readers.py b/src/quantem/core/io/file_readers.py index 22c140ba..65454118 100644 --- a/src/quantem/core/io/file_readers.py +++ b/src/quantem/core/io/file_readers.py @@ -112,17 +112,42 @@ def read_5dstem( Parameters ---------- file_path : str | PathLike - Path to data + Path to data file (.h5, .hdf5, or rosettasciio-supported formats). file_type : str | None, optional - The type of file reader needed. If None, auto-detect. + File reader type (e.g., "hdf5", "emd"). If None, auto-detects from extension. stack_type : str, optional - Stack type ("sequence", "tilt", etc.) or "auto" to detect from metadata. + Type of stack dimension. Options: "time", "tilt", "energy", "dose", "focus", "generic". + Default "auto" detects from metadata (Nion Swift is_sequence=True -> "time"). **kwargs : dict - Additional keyword arguments to pass to Dataset5dstem constructor. + Additional keyword arguments passed to Dataset5dstem constructor. Returns ------- Dataset5dstem + 5D dataset with shape (stack, scan_row, scan_col, k_row, k_col). 
+ + Examples + -------- + Load Nion Swift time series (auto-detects stack_type from is_sequence): + + >>> data = read_5dstem("time_series.h5") + >>> data.shape + (10, 512, 512, 12, 12) + >>> data.stack_type + 'time' + + Load as tilt series (override auto-detection): + + >>> data = read_5dstem("tilt_series.h5", stack_type="tilt") + >>> data.stack_type + 'tilt' + + Access calibrations extracted from file metadata: + + >>> data.sampling # [stack, scan_row, scan_col, k_row, k_col] + array([1.0, 0.5, 0.5, 0.006, 0.006]) + >>> data.units + ['pixels', 'nm', 'nm', 'rad', 'rad'] """ file_path = Path(file_path) From babfa9a8d3ce2f02aabdc7480bed4954d93f7d39 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sat, 17 Jan 2026 00:27:18 -0800 Subject: [PATCH 08/12] Add *.ipynb to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index d87d85c8..4dbaef1f 100644 --- a/.gitignore +++ b/.gitignore @@ -191,3 +191,4 @@ ipynb-playground/ # widget (JS build artifacts) node_modules/ widget/src/quantem/widget/static/ +*.ipynb From 550714ad4bf2c0dbb47a78e377aa8891ac15617c Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sun, 18 Jan 2026 11:09:15 -0800 Subject: [PATCH 09/12] Add Dataset5dstem virtual image management and show() error - Add show() that raises NotImplementedError with helpful message - Add show_virtual_images(), regenerate_virtual_images(), update_virtual_detector(), clear_virtual_images(), clear_all_virtual_data() - Add from_file() classmethod - Support auto-fit center per frame with geometry=(None, radius) - Support per-frame geometry with list of geometries - Add __str__ and __repr__ to Dataset4dstem for consistency - Add tests for all new features --- .../core/datastructures/dataset4dstem.py | 11 + .../core/datastructures/dataset5dstem.py | 277 ++++++++++++++---- tests/datastructures/test_dataset5dstem.py | 55 ++++ 3 files changed, 294 insertions(+), 49 deletions(-) diff --git 
a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index ed146a7d..d215b029 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -91,6 +91,17 @@ def __init__( self._virtual_images = {} self._virtual_detectors = {} # Store detector information for regeneration + def __repr__(self) -> str: + return f"Dataset4dstem(shape={self.shape}, dtype={self.array.dtype})" + + def __str__(self) -> str: + return ( + f"Dataset4dstem '{self.name}'\n" + f" shape: {self.shape}\n" + f" scan sampling: {self.sampling[:2]} {self.units[:2]}\n" + f" k sampling: {self.sampling[2:]} {self.units[2:]}" + ) + @classmethod def from_file(cls, file_path: str, file_type: str) -> "Dataset4dstem": """ diff --git a/src/quantem/core/datastructures/dataset5dstem.py b/src/quantem/core/datastructures/dataset5dstem.py index 8401a6cf..d67f2676 100644 --- a/src/quantem/core/datastructures/dataset5dstem.py +++ b/src/quantem/core/datastructures/dataset5dstem.py @@ -8,8 +8,10 @@ from quantem.core.datastructures.dataset3d import Dataset3d from quantem.core.datastructures.dataset4dstem import Dataset4dstem from quantem.core.datastructures.dataset5d import Dataset5d +from quantem.core.utils.diffractive_imaging_utils import fit_probe_circle from quantem.core.utils.masks import create_annular_mask, create_circle_mask from quantem.core.utils.validators import ensure_valid_array +from quantem.core.visualization import show_2d STACK_TYPES = ("time", "tilt", "energy", "dose", "focus", "generic") @@ -144,8 +146,6 @@ def __getitem__(self, idx) -> "Dataset4dstem | Dataset5dstem": # If result is still 5D, wrap back into Dataset5dstem with preserved metadata if result.array.ndim == 5: - # Figure out how stack_values should be sliced - sliced_values = self._slice_stack_values(idx) return self.from_array( array=result.array, name=result.name, @@ -154,32 +154,11 @@ def __getitem__(self, idx) -> "Dataset4dstem | 
Dataset5dstem": units=result.units, signal_units=result.signal_units, stack_type=self._stack_type, - stack_values=sliced_values, + stack_values=None, # Don't try to slice stack_values ) return result - def _slice_stack_values(self, idx): - """Slice stack_values based on how the stack axis is indexed.""" - if self._stack_values is None: - return None - - # Simple slice on first axis - if isinstance(idx, slice): - return self._stack_values[idx] - - # Tuple indexing - check first element - if isinstance(idx, tuple) and len(idx) > 0: - first = idx[0] - if isinstance(first, slice): - return self._stack_values[first] - if first is Ellipsis: - # Ellipsis at start means stack axis not sliced - return self._stack_values - - # Default: preserve all stack_values - return self._stack_values - @property def stack_type(self) -> str: """Type of stack dimension: 'time', 'tilt', 'energy', 'dose', 'focus', or 'generic'.""" @@ -200,6 +179,32 @@ def virtual_detectors(self) -> dict[str, dict]: """Virtual detector configurations for regenerating images.""" return self._virtual_detectors + @classmethod + def from_file(cls, file_path: str, file_type: str | None = None, **kwargs) -> "Dataset5dstem": + """Load Dataset5dstem from a file. + + Parameters + ---------- + file_path : str + Path to data file. + file_type : str | None + File type hint. If None, auto-detect from extension. + **kwargs + Additional arguments passed to read_5dstem (e.g., stack_type). 
+ + Returns + ------- + Dataset5dstem + + Examples + -------- + >>> data = Dataset5dstem.from_file("path/to/data.h5") + >>> data = Dataset5dstem.from_file("path/to/data.h5", stack_type="tilt") + """ + from quantem.core.io.file_readers import read_5dstem + + return read_5dstem(file_path, file_type=file_type, **kwargs) + # ------------------------------------------------------------------------- # Construction # ------------------------------------------------------------------------- @@ -411,14 +416,6 @@ def _get_frame(self, idx: int) -> Dataset4dstem: frame.metadata["r_to_q_rotation_cw_deg"] = self.metadata.get("r_to_q_rotation_cw_deg") frame.metadata["ellipticity"] = self.metadata.get("ellipticity") - # Inherit virtual detector definitions - for name, info in self._virtual_detectors.items(): - frame._virtual_detectors[name] = { - "mask": None, - "mode": info["mode"], - "geometry": info["geometry"], - } - return frame # ------------------------------------------------------------------------- @@ -429,7 +426,7 @@ def get_virtual_image( self, mask: ArrayLike | None = None, mode: str | None = None, - geometry: tuple | None = None, + geometry: tuple | list | None = None, name: str = "virtual_image", attach: bool = True, ) -> Dataset3d: @@ -438,12 +435,22 @@ def get_virtual_image( Parameters ---------- mask : np.ndarray, optional - Custom mask matching diffraction pattern shape. + Custom mask matching diffraction pattern shape (k_row, k_col). mode : str, optional Mask mode: "circle" or "annular". - geometry : tuple, optional - For "circle": ((cy, cx), radius). - For "annular": ((cy, cx), (r_inner, r_outer)). + geometry : tuple or list, optional + Detector geometry in pixels. Format depends on mode: + + For "circle" mode: + - ``(None, radius)`` : Auto-fit center per frame. + - ``((cy, cx), radius)`` : Fixed center for all frames. + - ``[((cy0, cx0), r0), ...]`` : Per-frame geometry list. 
+ + For "annular" mode: + - ``(None, (r_inner, r_outer))`` : Auto-fit center per frame. + - ``((cy, cx), (r_inner, r_outer))`` : Fixed center. + - ``[((cy0, cx0), (r0_in, r0_out)), ...]`` : Per-frame list. + name : str, optional Name for the virtual image. Default: "virtual_image". attach : bool, optional @@ -453,27 +460,66 @@ def get_virtual_image( ------- Dataset3d Virtual image stack with shape (n_frames, scan_row, scan_col). + + Notes + ----- + All geometry values are in pixels. To convert from mrad: + + >>> radius_px = radius_mrad / data.sampling[-1] + + Examples + -------- + Auto-fit center per frame: + + >>> bf = data.get_virtual_image(mode="circle", geometry=(None, 20)) + >>> adf = data.get_virtual_image(mode="annular", geometry=(None, (30, 80))) + + Fixed center for all frames: + + >>> k_center = (data.shape[-2] // 2, data.shape[-1] // 2) + >>> bf = data.get_virtual_image(mode="circle", geometry=(k_center, 20)) + + Per-frame geometry: + + >>> geometries = [(center, 20) for center in fitted_centers] + >>> bf = data.get_virtual_image(mode="circle", geometry=geometries) """ dp_shape = self.array.shape[-2:] + n_frames = len(self) if mask is not None: if mask.shape != dp_shape: raise ValueError(f"Mask shape {mask.shape} != diffraction pattern shape {dp_shape}") - final_mask = mask - elif mode and geometry: - if mode == "circle": - center, radius = geometry - final_mask = create_circle_mask(dp_shape, center, radius) - elif mode == "annular": - center, radii = geometry - final_mask = create_annular_mask(dp_shape, center, radii) + virtual_stack = np.sum(self.array * mask, axis=(-1, -2)) + elif mode and geometry is not None: + # Per-frame geometry list + if isinstance(geometry, list): + if len(geometry) != n_frames: + raise ValueError( + f"geometry list length ({len(geometry)}) must match " + f"number of frames ({n_frames})" + ) + virtual_stack = self._compute_per_frame_virtual(mode, geometry, dp_shape) else: - raise ValueError(f"Unknown mode '{mode}'. 
Use 'circle' or 'annular'.") + # Single geometry tuple: (center_or_none, radius_or_radii) + center, radius_or_radii = geometry + if center is None: + # Auto-fit center per frame + geometries = self._auto_fit_centers(mode, radius_or_radii) + virtual_stack = self._compute_per_frame_virtual(mode, geometries, dp_shape) + geometry = geometries # Store fitted geometries + else: + # Fixed center for all frames + if mode == "circle": + final_mask = create_circle_mask(dp_shape, center, radius_or_radii) + elif mode == "annular": + final_mask = create_annular_mask(dp_shape, center, radius_or_radii) + else: + raise ValueError(f"Unknown mode '{mode}'. Use 'circle' or 'annular'.") + virtual_stack = np.sum(self.array * final_mask, axis=(-1, -2)) else: raise ValueError("Provide either mask or both mode and geometry") - virtual_stack = np.sum(self.array * final_mask, axis=(-1, -2)) - vi = Dataset3d.from_array( array=virtual_stack, name=name, @@ -486,13 +532,146 @@ def get_virtual_image( if attach: self._virtual_images[name] = vi self._virtual_detectors[name] = { - "mask": final_mask.copy() if mask is not None else None, + "mask": mask.copy() if mask is not None else None, "mode": mode, "geometry": geometry, } return vi + def _auto_fit_centers(self, mode: str, radius_or_radii) -> list: + """Fit probe center for each frame and return list of geometries.""" + geometries = [] + for i in range(len(self)): + dp_mean = np.mean(self.array[i], axis=(0, 1)) + cy, cx, _ = fit_probe_circle(dp_mean, show=False) + geometries.append(((cy, cx), radius_or_radii)) + return geometries + + def _compute_per_frame_virtual(self, mode: str, geometries: list, dp_shape: tuple) -> np.ndarray: + """Compute virtual images with per-frame geometry.""" + virtual_stack = np.zeros((len(self), self.shape[1], self.shape[2]), dtype=self.array.dtype) + for i, geom in enumerate(geometries): + center, radius_or_radii = geom + if mode == "circle": + mask = create_circle_mask(dp_shape, center, radius_or_radii) + elif mode 
== "annular": + mask = create_annular_mask(dp_shape, center, radius_or_radii) + else: + raise ValueError(f"Unknown mode '{mode}'") + virtual_stack[i] = np.sum(self.array[i] * mask, axis=(-1, -2)) + return virtual_stack + + def show_virtual_images(self, figsize: tuple[int, int] | None = None, **kwargs) -> tuple: + """Display all virtual images stored in the dataset. + + Parameters + ---------- + figsize : tuple[int, int] | None + Figure size. If None, auto-calculated. + **kwargs + Arguments passed to show_2d (cmap, norm, cbar, etc.) + + Returns + ------- + tuple + (fig, axs) from matplotlib. + """ + if not self.virtual_images: + print("No virtual images. Create with get_virtual_image().") + return None, None + + # Each virtual image is Dataset3d - show first frame of each + arrays = [vi.array[0] for vi in self.virtual_images.values()] + titles = [f"{name} (frame 0)" for name in self.virtual_images.keys()] + + n = len(arrays) + if figsize is None: + figsize = (4 * min(n, 4), 4 * ((n + 3) // 4)) + + return show_2d(arrays, title=titles, figax_size=figsize, **kwargs) + + def regenerate_virtual_images(self) -> None: + """Regenerate virtual images from stored detector information.""" + if not self._virtual_detectors: + return + + self._virtual_images.clear() + + for name, info in self._virtual_detectors.items(): + try: + if info["mode"] is not None and info["geometry"] is not None: + self.get_virtual_image( + mode=info["mode"], + geometry=info["geometry"], + name=name, + attach=True, + ) + else: + print(f"Warning: Cannot regenerate '{name}' - insufficient detector info.") + except Exception as e: + print(f"Warning: Failed to regenerate '{name}': {e}") + + def update_virtual_detector( + self, + name: str, + mask: np.ndarray | None = None, + mode: str | None = None, + geometry: tuple | None = None, + ) -> None: + """Update virtual detector and regenerate the corresponding image. + + Parameters + ---------- + name : str + Name of virtual detector to update. 
+ mask : np.ndarray | None + New mask (must match DP dimensions). + mode : str | None + New mode ("circle" or "annular"). + geometry : tuple | None + New geometry. + """ + if name not in self._virtual_detectors: + raise ValueError(f"Detector '{name}' not found. Available: {list(self._virtual_detectors.keys())}") + + self._virtual_detectors[name]["mask"] = mask.copy() if mask is not None else None + self._virtual_detectors[name]["mode"] = mode + self._virtual_detectors[name]["geometry"] = geometry + + self.get_virtual_image(mask=mask, mode=mode, geometry=geometry, name=name, attach=True) + + def clear_virtual_images(self) -> None: + """Clear virtual images while keeping detector information.""" + self._virtual_images.clear() + + def clear_all_virtual_data(self) -> None: + """Clear both virtual images and detector information.""" + self._virtual_images.clear() + self._virtual_detectors.clear() + + def show(self, *args, **kwargs): + """Not implemented for 5D data. + + Raises + ------ + NotImplementedError + Always raised. Use alternative methods for visualization. + + See Also + -------- + show_virtual_images : Display virtual images. + data[i].show() : Show a single 4D frame. + data[i].dp_mean.show() : Show mean diffraction pattern. + """ + raise NotImplementedError( + "show() is not meaningful for 5D data. 
Use:\n" + " - data[i].show() for a single frame\n" + " - data[i].dp_mean.show() for mean DP\n" + " - data.show_virtual_images() for virtual images\n" + " - data.get_virtual_image(...).show() for virtual image stack" + ) + # ------------------------------------------------------------------------- # Copy # ------------------------------------------------------------------------- diff --git a/tests/datastructures/test_dataset5dstem.py b/tests/datastructures/test_dataset5dstem.py index 6a62a09a..bcfd018b 100644 --- a/tests/datastructures/test_dataset5dstem.py +++ b/tests/datastructures/test_dataset5dstem.py @@ -81,3 +81,58 @@ def test_validation(self): ds2 = Dataset4dstem.from_array(np.random.rand(8, 9, 3, 5)) with pytest.raises(ValueError, match="shape"): Dataset5dstem.from_4dstem([ds1, ds2]) + + def test_virtual_image_management(self, sample_dataset): + """Test virtual image clear, regenerate, update methods.""" + # Create virtual images + sample_dataset.get_virtual_image(mode="circle", geometry=((1, 2), 1), name="bf") + sample_dataset.get_virtual_image(mode="annular", geometry=((1, 2), (1, 2)), name="adf") + assert len(sample_dataset.virtual_images) == 2 + + # Clear images only + sample_dataset.clear_virtual_images() + assert len(sample_dataset.virtual_images) == 0 + assert len(sample_dataset.virtual_detectors) == 2 + + # Regenerate + sample_dataset.regenerate_virtual_images() + assert len(sample_dataset.virtual_images) == 2 + + # Update detector + sample_dataset.update_virtual_detector("bf", mode="circle", geometry=((1, 2), 2)) + assert sample_dataset.virtual_detectors["bf"]["geometry"] == ((1, 2), 2) + + # Clear all + sample_dataset.clear_all_virtual_data() + assert len(sample_dataset.virtual_images) == 0 + assert len(sample_dataset.virtual_detectors) == 0 + + def test_virtual_image_per_frame_geometry(self, sample_dataset): + """Test virtual image with per-frame geometry.""" + geometries = [((1, 2), 1) for _ in range(4)] + vi = 
sample_dataset.get_virtual_image(mode="circle", geometry=geometries, name="bf_perframe") + assert vi.shape == (4, 6, 7) + assert len(sample_dataset.virtual_detectors["bf_perframe"]["geometry"]) == 4 + + def test_virtual_image_auto_fit_center(self): + """Test virtual image with auto-fit center (requires realistic DP).""" + # Create data with a clear probe pattern + n_frames, scan_r, scan_c, k_r, k_c = 3, 4, 4, 32, 32 + array = np.zeros((n_frames, scan_r, scan_c, k_r, k_c), dtype=np.float32) + # Add circular probe at center + cy, cx = k_r // 2, k_c // 2 + y, x = np.ogrid[:k_r, :k_c] + mask = (y - cy) ** 2 + (x - cx) ** 2 <= 8 ** 2 + array[:, :, :, mask] = 1.0 + + data = Dataset5dstem.from_array(array, stack_type="time") + vi = data.get_virtual_image(mode="circle", geometry=(None, 5), name="bf_auto") + assert vi.shape == (3, 4, 4) + assert isinstance(data.virtual_detectors["bf_auto"]["geometry"], list) + assert len(data.virtual_detectors["bf_auto"]["geometry"]) == 3 + + def test_show_raises_error(self, sample_dataset): + """Test that show() raises NotImplementedError with helpful message.""" + with pytest.raises(NotImplementedError, match="not meaningful for 5D"): + sample_dataset.show() + From c8904be7812a129d63f20c9919e356e7f3c2e34b Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sun, 18 Jan 2026 11:29:05 -0800 Subject: [PATCH 10/12] Remove unnecessary blank lines in dataset5dstem.py --- src/quantem/core/datastructures/dataset5dstem.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/quantem/core/datastructures/dataset5dstem.py b/src/quantem/core/datastructures/dataset5dstem.py index d67f2676..2ce9e83c 100644 --- a/src/quantem/core/datastructures/dataset5dstem.py +++ b/src/quantem/core/datastructures/dataset5dstem.py @@ -81,12 +81,9 @@ def __init__( metadata=metadata, _token=_token, ) - if stack_type not in STACK_TYPES: raise ValueError(f"stack_type must be one of {STACK_TYPES}, got '{stack_type}'") - self._stack_type = stack_type - if 
stack_values is not None: stack_values = np.asarray(stack_values) if len(stack_values) != self.shape[0]: @@ -124,11 +121,9 @@ def __getitem__(self, idx) -> "Dataset4dstem | Dataset5dstem": # Handle integer indexing (including numpy integers) if isinstance(idx, (int, np.integer)): return self._get_frame(int(idx)) - # Handle tuple where first element is int (e.g., data[0, ...]) if isinstance(idx, tuple) and len(idx) > 0 and isinstance(idx[0], (int, np.integer)): return self._get_frame(int(idx[0]))[idx[1:]] - # Reject advanced indexing on stack axis (lists, arrays, boolean masks) if isinstance(idx, (list, np.ndarray)): raise TypeError( @@ -140,10 +135,8 @@ def __getitem__(self, idx) -> "Dataset4dstem | Dataset5dstem": "Advanced indexing with lists/arrays on stack axis is not supported. " "Use integer indexing or slices instead." ) - # Get result from base class slicing result = super().__getitem__(idx) - # If result is still 5D, wrap back into Dataset5dstem with preserved metadata if result.array.ndim == 5: return self.from_array( @@ -208,7 +201,6 @@ def from_file(cls, file_path: str, file_type: str | None = None, **kwargs) -> "D # ------------------------------------------------------------------------- # Construction # ------------------------------------------------------------------------- - @classmethod def from_array( cls, @@ -364,7 +356,6 @@ def from_4dstem( # ------------------------------------------------------------------------- # Stack operations # ------------------------------------------------------------------------- - def stack_mean(self) -> Dataset4dstem: """Average over the stack axis. 
Returns Dataset4dstem.""" return self._reduce_stack(np.mean, "mean") @@ -421,7 +412,6 @@ def _get_frame(self, idx: int) -> Dataset4dstem: # ------------------------------------------------------------------------- # Virtual imaging # ------------------------------------------------------------------------- - def get_virtual_image( self, mask: ArrayLike | None = None, @@ -675,7 +665,6 @@ def show(self, *args, **kwargs): # ------------------------------------------------------------------------- # Copy # ------------------------------------------------------------------------- - def _copy_custom_attributes(self, new_dataset) -> None: """Copy Dataset5dstem-specific attributes.""" super()._copy_custom_attributes(new_dataset) From 55bb494c9d8a30d15b9e5a1d780de5fc9331ce15 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sun, 18 Jan 2026 12:20:26 -0800 Subject: [PATCH 11/12] Add full auto-detection for virtual imaging and update dimension naming - Add geometry=None support for full auto-detection (center + radius) - Update Dataset4dstem docstring: scan_y/x -> scan_row/col, dp_y/x -> k_row/col - Add test for full auto-detection virtual imaging --- .../core/datastructures/dataset4dstem.py | 2 +- .../core/datastructures/dataset5dstem.py | 21 +++++++++++++----- tests/datastructures/test_dataset5dstem.py | 22 +++++++++++++++++++ 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index d215b029..80d42dba 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -17,7 +17,7 @@ class Dataset4dstem(Dataset4d): """A 4D-STEM dataset class that inherits from Dataset4d. This class represents a 4D scanning transmission electron microscopy (STEM) dataset, - where the data consists of a 4D array with dimensions (scan_y, scan_x, dp_y, dp_x). 
+ where the data consists of a 4D array with dimensions (scan_row, scan_col, k_row, k_col). The first two dimensions represent real space scanning positions, while the latter two dimensions represent reciprocal space diffraction patterns. diff --git a/src/quantem/core/datastructures/dataset5dstem.py b/src/quantem/core/datastructures/dataset5dstem.py index 2ce9e83c..13739f75 100644 --- a/src/quantem/core/datastructures/dataset5dstem.py +++ b/src/quantem/core/datastructures/dataset5dstem.py @@ -481,9 +481,14 @@ def get_virtual_image( if mask.shape != dp_shape: raise ValueError(f"Mask shape {mask.shape} != diffraction pattern shape {dp_shape}") virtual_stack = np.sum(self.array * mask, axis=(-1, -2)) - elif mode and geometry is not None: + elif mode is not None: + # Full auto-detection: geometry=None + if geometry is None: + geometries = self._auto_fit_centers(mode, None) + virtual_stack = self._compute_per_frame_virtual(mode, geometries, dp_shape) + geometry = geometries # Per-frame geometry list - if isinstance(geometry, list): + elif isinstance(geometry, list): if len(geometry) != n_frames: raise ValueError( f"geometry list length ({len(geometry)}) must match " @@ -494,7 +499,7 @@ def get_virtual_image( # Single geometry tuple: (center_or_none, radius_or_radii) center, radius_or_radii = geometry if center is None: - # Auto-fit center per frame + # Auto-fit center per frame (radius may also be None for auto) geometries = self._auto_fit_centers(mode, radius_or_radii) virtual_stack = self._compute_per_frame_virtual(mode, geometries, dp_shape) geometry = geometries # Store fitted geometries @@ -530,12 +535,16 @@ def get_virtual_image( return vi def _auto_fit_centers(self, mode: str, radius_or_radii) -> list: - """Fit probe center for each frame and return list of geometries.""" + """Fit probe center for each frame and return list of geometries. + + If radius_or_radii is None, also auto-detect the radius. 
+ """ geometries = [] for i in range(len(self)): dp_mean = np.mean(self.array[i], axis=(0, 1)) - cy, cx, _ = fit_probe_circle(dp_mean, show=False) - geometries.append(((cy, cx), radius_or_radii)) + cy, cx, detected_radius = fit_probe_circle(dp_mean, show=False) + radius = detected_radius if radius_or_radii is None else radius_or_radii + geometries.append(((cy, cx), radius)) return geometries def _compute_per_frame_virtual(self, mode: str, geometries: list, dp_shape: tuple) -> np.ndarray: diff --git a/tests/datastructures/test_dataset5dstem.py b/tests/datastructures/test_dataset5dstem.py index bcfd018b..e5ffcec8 100644 --- a/tests/datastructures/test_dataset5dstem.py +++ b/tests/datastructures/test_dataset5dstem.py @@ -131,6 +131,28 @@ def test_virtual_image_auto_fit_center(self): assert isinstance(data.virtual_detectors["bf_auto"]["geometry"], list) assert len(data.virtual_detectors["bf_auto"]["geometry"]) == 3 + def test_virtual_image_full_auto(self): + """Test virtual image with full auto-detection (center and radius).""" + # Create data with a clear probe pattern + n_frames, scan_r, scan_c, k_r, k_c = 3, 4, 4, 32, 32 + array = np.zeros((n_frames, scan_r, scan_c, k_r, k_c), dtype=np.float32) + # Add circular probe at center with radius 8 + cy, cx = k_r // 2, k_c // 2 + y, x = np.ogrid[:k_r, :k_c] + mask = (y - cy) ** 2 + (x - cx) ** 2 <= 8 ** 2 + array[:, :, :, mask] = 1.0 + + data = Dataset5dstem.from_array(array, stack_type="time") + # geometry=None triggers full auto-detection + vi = data.get_virtual_image(mode="circle", geometry=None, name="bf_full_auto") + assert vi.shape == (3, 4, 4) + geoms = data.virtual_detectors["bf_full_auto"]["geometry"] + assert isinstance(geoms, list) + assert len(geoms) == 3 + # Check that radius was auto-detected (should be close to 8) + for center, radius in geoms: + assert 6 < radius < 10 # Allow some tolerance + def test_show_raises_error(self, sample_dataset): """Test that show() raises NotImplementedError with helpful 
message.""" with pytest.raises(NotImplementedError, match="not meaningful for 5D"): From b0be6912f49b846dc75536ef8fb8d0f696052f5b Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sun, 18 Jan 2026 12:22:46 -0800 Subject: [PATCH 12/12] Remove unnecessary blank lines in file_readers.py --- src/quantem/core/io/file_readers.py | 33 ++++++----------------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/src/quantem/core/io/file_readers.py b/src/quantem/core/io/file_readers.py index 65454118..29762548 100644 --- a/src/quantem/core/io/file_readers.py +++ b/src/quantem/core/io/file_readers.py @@ -156,7 +156,6 @@ def read_5dstem( try: with h5py.File(file_path, "r") as f: if "data" in f and "properties" in f["data"].attrs: - # Nion Swift format detected return _read_nion_swift_5dstem(file_path, stack_type, **kwargs) except Exception: pass # Fall through to rsciio @@ -164,47 +163,35 @@ def read_5dstem( # Fall back to rosettasciio if file_type is None: file_type = file_path.suffix.lower().lstrip(".") - file_reader = importlib.import_module(f"rsciio.{file_type}").file_reader data_list = file_reader(file_path) # Find first 5D dataset five_d_datasets = [(i, d) for i, d in enumerate(data_list) if d["data"].ndim == 5] - if len(five_d_datasets) == 0: print(f"No 5D datasets found in {file_path}. Available datasets:") for i, d in enumerate(data_list): print(f" Dataset {i}: shape {d['data'].shape}, ndim={d['data'].ndim}") raise ValueError("No 5D dataset found in file") - dataset_index, imported_data = five_d_datasets[0] - if len(data_list) > 1: print( f"File contains {len(data_list)} dataset(s). 
Using dataset {dataset_index} " f"with shape {imported_data['data'].shape}" ) + # Extract calibrations imported_axes = imported_data["axes"] - - sampling = kwargs.pop( - "sampling", - [ax["scale"] for ax in imported_axes], - ) - origin = kwargs.pop( - "origin", - [ax["offset"] for ax in imported_axes], - ) + sampling = kwargs.pop("sampling", [ax["scale"] for ax in imported_axes]) + origin = kwargs.pop("origin", [ax["offset"] for ax in imported_axes]) units = kwargs.pop( "units", ["pixels" if ax["units"] == "1" else ax["units"] for ax in imported_axes], ) - - # Determine stack type if stack_type == "auto": stack_type = "generic" - dataset = Dataset5dstem.from_array( + return Dataset5dstem.from_array( array=imported_data["data"], sampling=sampling, origin=origin, @@ -213,8 +200,6 @@ def read_5dstem( **kwargs, ) - return dataset - def _read_nion_swift_5dstem( file_path: str | PathLike, @@ -242,7 +227,6 @@ def _read_nion_swift_5dstem( with h5py.File(file_path, "r") as f: data = f["data"][:] props = json.loads(f["data"].attrs["properties"]) - if data.ndim != 5: raise ValueError(f"Expected 5D data, got {data.ndim}D with shape {data.shape}") @@ -259,16 +243,13 @@ def _read_nion_swift_5dstem( # Determine stack type from metadata if stack_type == "auto": - if props.get("is_sequence", False): - stack_type = "time" - else: - stack_type = "generic" + stack_type = "time" if props.get("is_sequence", False) else "generic" # Get intensity calibration intensity_cal = props.get("intensity_calibration", {}) signal_units = intensity_cal.get("units", "arb. units") or "arb. units" - dataset = Dataset5dstem.from_array( + return Dataset5dstem.from_array( array=data, origin=origin, sampling=sampling, @@ -278,8 +259,6 @@ def _read_nion_swift_5dstem( **kwargs, ) - return dataset - def read_2d( file_path: str | PathLike,