diff --git a/.gitignore b/.gitignore index d87d85c8..4dbaef1f 100644 --- a/.gitignore +++ b/.gitignore @@ -191,3 +191,4 @@ ipynb-playground/ # widget (JS build artifacts) node_modules/ widget/src/quantem/widget/static/ +*.ipynb diff --git a/src/quantem/core/datastructures/__init__.py b/src/quantem/core/datastructures/__init__.py index dfb5b47a..4f366b9a 100644 --- a/src/quantem/core/datastructures/__init__.py +++ b/src/quantem/core/datastructures/__init__.py @@ -1,7 +1,9 @@ from quantem.core.datastructures.dataset import Dataset as Dataset from quantem.core.datastructures.vector import Vector as Vector -from quantem.core.datastructures.dataset4dstem import Dataset4dstem as Dataset4dstem -from quantem.core.datastructures.dataset4d import Dataset4d as Dataset4d -from quantem.core.datastructures.dataset3d import Dataset3d as Dataset3d from quantem.core.datastructures.dataset2d import Dataset2d as Dataset2d +from quantem.core.datastructures.dataset3d import Dataset3d as Dataset3d +from quantem.core.datastructures.dataset4d import Dataset4d as Dataset4d +from quantem.core.datastructures.dataset4dstem import Dataset4dstem as Dataset4dstem +from quantem.core.datastructures.dataset5d import Dataset5d as Dataset5d +from quantem.core.datastructures.dataset5dstem import Dataset5dstem as Dataset5dstem diff --git a/src/quantem/core/datastructures/dataset3d.py b/src/quantem/core/datastructures/dataset3d.py index 53d6d02c..4f36ef04 100644 --- a/src/quantem/core/datastructures/dataset3d.py +++ b/src/quantem/core/datastructures/dataset3d.py @@ -68,6 +68,50 @@ def from_array( units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", ) -> Self: + """Create a Dataset3d from a 3D array. + + Parameters + ---------- + array : NDArray | Any + 3D array with shape (n_frames, height, width) + name : str | None + Dataset name. Default: "3D dataset" + origin : NDArray | tuple | list | float | int | None + Origin for each dimension. Default: [0, 0, 0] + sampling : NDArray | tuple | list | float | int | None + Sampling for each dimension. Default: [1, 1, 1] + units : list[str] | tuple | list | None + Units for each dimension. Default: ["index", "pixels", "pixels"] + signal_units : str + Units for array values. Default: "arb. units" + + Returns + ------- + Dataset3d + + Examples + -------- + >>> import numpy as np + >>> from quantem.core.datastructures import Dataset3d + >>> arr = np.random.rand(10, 64, 64) + >>> data = Dataset3d.from_array(arr) + >>> data.shape + (10, 64, 64) + + With calibration: + + >>> data = Dataset3d.from_array( + ... arr, + ... sampling=[1, 0.1, 0.1], + ... units=["frame", "nm", "nm"], + ... ) + + Visualize: + + >>> data.show() # all frames in grid + >>> data.show(index=0) # single frame + >>> data.show(ncols=2) # 2 columns + """ array = ensure_valid_array(array, ndim=3) return cls( array=array, @@ -90,7 +134,37 @@ def from_shape( units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", ) -> Self: - """Create a new Dataset3d filled with a constant value.""" + """Create a Dataset3d filled with a constant value. + + Parameters + ---------- + shape : tuple[int, int, int] + Shape (n_frames, height, width) + name : str + Dataset name. Default: "constant 3D dataset" + fill_value : float + Value to fill array with. Default: 0.0 + origin : NDArray | tuple | list | float | int | None + Origin for each dimension + sampling : NDArray | tuple | list | float | int | None + Sampling for each dimension + units : list[str] | tuple | list | None + Units for each dimension + signal_units : str + Units for array values + + Returns + ------- + Dataset3d + + Examples + -------- + >>> data = Dataset3d.from_shape((10, 64, 64)) + >>> data.shape + (10, 64, 64) + >>> data.array.max() + 0.0 + """ array = np.full(shape, fill_value, dtype=np.float32) return cls.from_array( array=array, @@ -107,24 +181,72 @@ def to_dataset2d(self): def show( self, - index: int = 0, + index: int | None = None, scalebar: ScalebarConfig | bool = True, title: str | None = None, + suptitle: str | None = None, + ncols: int = 4, + returnfig: bool = False, **kwargs, ): """ - Display a 2D slice of the 3D dataset. + Display 2D slices of the 3D dataset. Parameters ---------- - index : int - Index of the 2D slice to display (along axis 0). - scalebar: ScalebarConfig or bool - If True, displays scalebar - title: str - Title of Dataset - **kwargs: dict - Keyword arguments for show_2d + index : int | None + Index of the 2D slice to display. If None, shows all slices in a grid. + scalebar : ScalebarConfig or bool + If True, displays scalebar. + title : str | None + Title for the plot. If None, uses "Frame 0", "Frame 1", etc. + suptitle : str | None + Figure super title displayed above all subplots. + ncols : int + Maximum columns when showing all slices. Default: 4. + returnfig : bool + If True, returns (fig, axes). Default: False. + **kwargs : dict + Keyword arguments for show_2d (cmap, cbar, norm, etc.). + + Examples + -------- + >>> data.show() # show all frames in grid + >>> data.show(index=0) # show single frame + >>> data.show(ncols=3) # 3 columns + >>> data.show(suptitle="Diffraction patterns") # with super title + >>> fig, axes = data.show(returnfig=True) # get figure for customization """ + from quantem.core.visualization import show_2d + + if index is not None: + # Handle negative index + actual_index = index if index >= 0 else self.shape[0] + index + default_title = title if title is not None else f"Frame {actual_index}" + result = self[index].show(scalebar=scalebar, title=default_title, **kwargs) + return result if returnfig else None + + # Show all frames in a grid + n = self.shape[0] + nrows = (n + ncols - 1) // ncols + arrays = [] + titles = [] + for row in range(nrows): + row_arrays = [] + row_titles = [] + for col in range(ncols): + i = row * ncols + col + if i < n: + row_arrays.append(self.array[i]) + row_titles.append(f"Frame {i}" if title is None else f"{title} {i}") + else: + row_arrays.append(np.zeros_like(self.array[0])) + row_titles.append("") + arrays.append(row_arrays) + titles.append(row_titles) - return self[index].show(scalebar=scalebar, title=title, **kwargs) + fig, axes = show_2d(arrays, scalebar=scalebar, title=titles, **kwargs) + if suptitle is not None: + fig.suptitle(suptitle, fontsize=14) + fig.subplots_adjust(top=0.92) + return (fig, axes) if returnfig else None diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index 28328636..80d42dba 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -7,6 +7,7 @@ from quantem.core.datastructures.dataset2d import Dataset2d from quantem.core.datastructures.dataset4d import Dataset4d +from quantem.core.utils.masks import create_annular_mask, create_circle_mask from quantem.core.utils.validators import ensure_valid_array from quantem.core.visualization import show_2d from quantem.core.visualization.visualization_utils import ScalebarConfig @@ -16,7 +17,7 @@ class Dataset4dstem(Dataset4d): """A 4D-STEM dataset class that inherits from Dataset4d. This class represents a 4D scanning transmission electron microscopy (STEM) dataset, - where the data consists of a 4D array with dimensions (scan_y, scan_x, dp_y, dp_x). + where the data consists of a 4D array with dimensions (scan_row, scan_col, k_row, k_col). The first two dimensions represent real space scanning positions, while the latter two dimensions represent reciprocal space diffraction patterns. @@ -90,6 +91,17 @@ def __init__( self._virtual_images = {} self._virtual_detectors = {} # Store detector information for regeneration + def __repr__(self) -> str: + return f"Dataset4dstem(shape={self.shape}, dtype={self.array.dtype})" + + def __str__(self) -> str: + return ( + f"Dataset4dstem '{self.name}'\n" + f" shape: {self.shape}\n" + f" scan sampling: {self.sampling[:2]} {self.units[:2]}\n" + f" k sampling: {self.sampling[2:]} {self.units[2:]}" + ) + @classmethod def from_file(cls, file_path: str, file_type: str) -> "Dataset4dstem": """ @@ -365,6 +377,7 @@ def get_virtual_image( final_mask = mask elif mode is not None and geometry is not None: # Create mask from mode and geometry + dp_shape = self.array.shape[-2:] if mode == "circle": if ( len(geometry) != 2 @@ -373,14 +386,14 @@ def get_virtual_image( ): raise ValueError("For circle mode, geometry must be ((cy, cx), r)") center, radius = geometry - final_mask = self._create_circle_mask(center, radius) + final_mask = create_circle_mask(dp_shape, center, radius) elif mode == "annular": if len(geometry) != 2 or len(geometry[0]) != 2 or len(geometry[1]) != 2: raise ValueError( "For annular mode, geometry must be ((cy, cx), (r_inner, r_outer))" ) center, radii = geometry - final_mask = self._create_annular_mask(center, radii) + final_mask = create_annular_mask(dp_shape, center, radii) else: raise ValueError( f"Unknown mode '{mode}'. Supported modes are 'circle' and 'annular'" @@ -466,59 +479,6 @@ def get_virtual_image( return virtual_image_dataset - def _create_circle_mask(self, center: tuple[float, float], radius: float) -> np.ndarray: - """ - Create a circular mask for virtual image formation. - - Parameters - ---------- - center : tuple[float, float] - Center coordinates (cy, cx) of the circle - radius : float - Radius of the circle - - Returns - ------- - np.ndarray - Boolean mask with True inside the circle - """ - cy, cx = center - dp_shape = self.array.shape[-2:] # Get diffraction pattern dimensions - y, x = np.ogrid[: dp_shape[0], : dp_shape[1]] - - # Calculate distance from center - distance = np.sqrt((y - cy) ** 2 + (x - cx) ** 2) - - return distance <= radius - - def _create_annular_mask( - self, center: tuple[float, float], radii: tuple[float, float] - ) -> np.ndarray: - """ - Create an annular (ring-shaped) mask for virtual image formation. - - Parameters - ---------- - center : tuple[float, float] - Center coordinates (cy, cx) of the annulus - radii : tuple[float, float] - Inner and outer radii (r_inner, r_outer) of the annulus - - Returns - ------- - np.ndarray - Boolean mask with True inside the annular region - """ - cy, cx = center - r_inner, r_outer = radii - dp_shape = self.array.shape[-2:] # Get diffraction pattern dimensions - y, x = np.ogrid[: dp_shape[0], : dp_shape[1]] - - # Calculate distance from center - distance = np.sqrt((y - cy) ** 2 + (x - cx) ** 2) - - return (distance >= r_inner) & (distance <= r_outer) - def show_virtual_images(self, figsize: tuple[int, int] | None = None, **kwargs) -> tuple: """ Display all virtual images stored in the dataset using show_2d. diff --git a/src/quantem/core/datastructures/dataset5d.py b/src/quantem/core/datastructures/dataset5d.py new file mode 100644 index 00000000..62bc6554 --- /dev/null +++ b/src/quantem/core/datastructures/dataset5d.py @@ -0,0 +1,180 @@ +from typing import Any, Self + +import numpy as np +from numpy.typing import NDArray + +from quantem.core.datastructures.dataset import Dataset +from quantem.core.utils.validators import ensure_valid_array +from quantem.core.visualization.visualization_utils import ScalebarConfig + + +@Dataset.register_dimension(5) +class Dataset5d(Dataset): + """5D dataset class that inherits from Dataset. + + This class represents 5D stacks of data, such as time-series or tilt-series + of 4D-STEM experiments. + + The data consists of a 5D array with dimensions (stack, scan_row, scan_col, k_row, k_col). + The first dimension represents the stack axis (time, tilt, defocus, etc.), + dimensions 1-2 represent real space scanning positions, and dimensions 3-4 + represent reciprocal space diffraction patterns. + + Attributes + ---------- + None beyond base Dataset. + """ + + def __init__( + self, + array: NDArray | Any, + name: str, + origin: NDArray | tuple | list | float | int, + sampling: NDArray | tuple | list | float | int, + units: list[str] | tuple | list, + signal_units: str = "arb. units", + metadata: dict = {}, + _token: object | None = None, + ): + """Initialize a 5D dataset. + + Parameters + ---------- + array : NDArray | Any + The underlying 5D array data + name : str + A descriptive name for the dataset + origin : NDArray | tuple | list | float | int + The origin coordinates for each dimension + sampling : NDArray | tuple | list | float | int + The sampling rate/spacing for each dimension + units : list[str] | tuple | list + Units for each dimension + signal_units : str, optional + Units for the array values, by default "arb. units" + metadata : dict, optional + Additional metadata, by default {} + _token : object | None, optional + Token to prevent direct instantiation, by default None + """ + super().__init__( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + metadata=metadata, + _token=_token, + ) + + @classmethod + def from_array( + cls, + array: NDArray | Any, + name: str | None = None, + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + ) -> Self: + """Create a new Dataset5d from an array. + + Parameters + ---------- + array : NDArray | Any + The underlying 5D array data + name : str | None, optional + A descriptive name for the dataset. If None, defaults to "5D dataset" + origin : NDArray | tuple | list | float | int | None, optional + The origin coordinates for each dimension. If None, defaults to zeros + sampling : NDArray | tuple | list | float | int | None, optional + The sampling rate/spacing for each dimension. If None, defaults to ones + units : list[str] | tuple | list | None, optional + Units for each dimension. If None, defaults to ["index", "pixels", "pixels", "pixels", "pixels"] + signal_units : str, optional + Units for the array values, by default "arb. units" + + Returns + ------- + Dataset5d + A new Dataset5d instance + """ + array = ensure_valid_array(array, ndim=5) + return cls( + array=array, + name=name if name is not None else "5D dataset", + origin=origin if origin is not None else np.zeros(5), + sampling=sampling if sampling is not None else np.ones(5), + units=units if units is not None else ["index", "pixels", "pixels", "pixels", "pixels"], + signal_units=signal_units, + _token=cls._token, + ) + + @classmethod + def from_shape( + cls, + shape: tuple[int, int, int, int, int], + name: str = "constant 5D dataset", + fill_value: float = 0.0, + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + ) -> Self: + """Create a new Dataset5d filled with a constant value. + + Parameters + ---------- + shape : tuple[int, int, int, int, int] + Shape of the 5D array + name : str, optional + Name for the dataset, by default "constant 5D dataset" + fill_value : float, optional + Value to fill the array with, by default 0.0 + origin : NDArray | tuple | list | float | int | None, optional + Origin coordinates for each dimension + sampling : NDArray | tuple | list | float | int | None, optional + Sampling rate for each dimension + units : list[str] | tuple | list | None, optional + Units for each dimension + signal_units : str, optional + Units for the array values, by default "arb. units" + + Returns + ------- + Dataset5d + A new Dataset5d instance filled with the specified value + """ + array = np.full(shape, fill_value, dtype=np.float32) + return cls.from_array( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + ) + + def show( + self, + index: tuple[int, int, int] = (0, 0, 0), + scalebar: ScalebarConfig | bool = True, + title: str | None = None, + **kwargs, + ): + """ + Display a 2D slice of the 5D dataset. + + Parameters + ---------- + index : tuple[int, int, int] + 3D index of the 2D slice to display (along axes (0, 1, 2)). + scalebar : ScalebarConfig or bool + If True, displays scalebar + title : str + Title of Dataset + **kwargs : dict + Keyword arguments for show_2d + """ + return self[index].show(scalebar=scalebar, title=title, **kwargs) diff --git a/src/quantem/core/datastructures/dataset5dstem.py b/src/quantem/core/datastructures/dataset5dstem.py new file mode 100644 index 00000000..13739f75 --- /dev/null +++ b/src/quantem/core/datastructures/dataset5dstem.py @@ -0,0 +1,686 @@ +"""5D-STEM dataset class for time series, tilt series, and other stacked 4D-STEM data.""" + +from typing import Iterator, Self + +import numpy as np +from numpy.typing import ArrayLike, NDArray + +from quantem.core.datastructures.dataset3d import Dataset3d +from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.datastructures.dataset5d import Dataset5d +from quantem.core.utils.diffractive_imaging_utils import fit_probe_circle +from quantem.core.utils.masks import create_annular_mask, create_circle_mask +from quantem.core.utils.validators import ensure_valid_array +from quantem.core.visualization import show_2d + +STACK_TYPES = ("time", "tilt", "energy", "dose", "focus", "generic") + + +class Dataset5dstem(Dataset5d): + """5D-STEM dataset with dimensions (stack, scan_row, scan_col, k_row, k_col). + + The stack axis represents time frames, tilt angles, defocus values, etc. + Dimensions 1-2 are real-space scan positions; dimensions 3-4 are reciprocal-space + diffraction patterns. + + Parameters + ---------- + stack_type : str + Type of stack: "time", "tilt", "energy", "dose", "focus", or "generic". + stack_values : NDArray | None + Explicit values for the stack dimension (e.g., timestamps, tilt angles). + + Examples + -------- + >>> data = read_5dstem("path/to/file.h5") # shape (4, 6, 7, 3, 5) + >>> len(data) # number of frames -> 4 + + Indexing (integer removes dimension, slice keeps it): + + >>> data[2] # -> Dataset4dstem (6, 7, 3, 5) one frame + >>> data[1:3] # -> Dataset5dstem (2, 6, 7, 3, 5) substack + >>> data[:, 4, 1] # -> Dataset3d (4, 3, 5) one scan position, all frames + >>> data[:, 1:5, 2:6] # -> Dataset5dstem (4, 4, 4, 3, 5) scan region crop + >>> data[..., 0:2, 1:4] # -> Dataset5dstem (4, 6, 7, 2, 3) k-space crop + + Stack operations (reduce over stack axis): + + >>> data.stack_mean() # -> Dataset4dstem + >>> data.stack_std() # -> Dataset4dstem + + Virtual imaging: + + >>> vi = data.get_virtual_image(mode="circle", geometry=((1, 2), 1)) + >>> vi.shape # -> (4, 6, 7) = (frames, scan_row, scan_col) + """ + + def __init__( + self, + array: ArrayLike, + name: str, + origin: ArrayLike, + sampling: ArrayLike, + units: list[str], + signal_units: str = "arb. units", + metadata: dict | None = None, + stack_type: str = "generic", + stack_values: ArrayLike | None = None, + _token: object | None = None, + ): + metadata = metadata or {} + for key in ("r_to_q_rotation_cw_deg", "ellipticity"): + metadata.setdefault(key, None) + + super().__init__( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + metadata=metadata, + _token=_token, + ) + if stack_type not in STACK_TYPES: + raise ValueError(f"stack_type must be one of {STACK_TYPES}, got '{stack_type}'") + self._stack_type = stack_type + if stack_values is not None: + stack_values = np.asarray(stack_values) + if len(stack_values) != self.shape[0]: + raise ValueError( + f"stack_values length ({len(stack_values)}) must match " + f"number of frames ({self.shape[0]})" + ) + self._stack_values = stack_values + self._virtual_images: dict[str, Dataset3d] = {} + self._virtual_detectors: dict[str, dict] = {} + + def __repr__(self) -> str: + return ( + f"Dataset5dstem(shape={self.shape}, dtype={self.array.dtype}, " + f"stack_type='{self._stack_type}')" + ) + + def __str__(self) -> str: + return ( + f"Dataset5dstem '{self.name}'\n" + f" shape: {self.shape} ({len(self)} frames)\n" + f" stack_type: '{self._stack_type}'\n" + f" scan sampling: {self.sampling[1:3]} {self.units[1:3]}\n" + f" k sampling: {self.sampling[3:]} {self.units[3:]}" + ) + + def __len__(self) -> int: + return self.shape[0] + + def __iter__(self) -> Iterator[Dataset4dstem]: + for i in range(len(self)): + yield self._get_frame(i) + + def __getitem__(self, idx) -> "Dataset4dstem | Dataset5dstem": + # Handle integer indexing (including numpy integers) + if isinstance(idx, (int, np.integer)): + return self._get_frame(int(idx)) + # Handle tuple where first element is int (e.g., data[0, ...]) + if isinstance(idx, tuple) and len(idx) > 0 and isinstance(idx[0], (int, np.integer)): + return self._get_frame(int(idx[0]))[idx[1:]] + # Reject advanced indexing on stack axis (lists, arrays, boolean masks) + if isinstance(idx, (list, np.ndarray)): + raise TypeError( + "Advanced indexing with lists/arrays on stack axis is not supported. " + "Use integer indexing or slices instead." + ) + if isinstance(idx, tuple) and len(idx) > 0 and isinstance(idx[0], (list, np.ndarray)): + raise TypeError( + "Advanced indexing with lists/arrays on stack axis is not supported. " + "Use integer indexing or slices instead." + ) + # Get result from base class slicing + result = super().__getitem__(idx) + # If result is still 5D, wrap back into Dataset5dstem with preserved metadata + if result.array.ndim == 5: + return self.from_array( + array=result.array, + name=result.name, + origin=result.origin, + sampling=result.sampling, + units=result.units, + signal_units=result.signal_units, + stack_type=self._stack_type, + stack_values=None, # Don't try to slice stack_values + ) + + return result + + @property + def stack_type(self) -> str: + """Type of stack dimension: 'time', 'tilt', 'energy', 'dose', 'focus', or 'generic'.""" + return self._stack_type + + @property + def stack_values(self) -> NDArray | None: + """Explicit values for the stack dimension, or None if using indices.""" + return self._stack_values + + @property + def virtual_images(self) -> dict[str, Dataset3d]: + """Cached virtual image stacks, keyed by name.""" + return self._virtual_images + + @property + def virtual_detectors(self) -> dict[str, dict]: + """Virtual detector configurations for regenerating images.""" + return self._virtual_detectors + + @classmethod + def from_file(cls, file_path: str, file_type: str | None = None, **kwargs) -> "Dataset5dstem": + """Load Dataset5dstem from a file. + + Parameters + ---------- + file_path : str + Path to data file. + file_type : str | None + File type hint. If None, auto-detect from extension. + **kwargs + Additional arguments passed to read_5dstem (e.g., stack_type). + + Returns + ------- + Dataset5dstem + + Examples + -------- + >>> data = Dataset5dstem.from_file("path/to/data.h5") + >>> data = Dataset5dstem.from_file("path/to/data.h5", stack_type="tilt") + """ + from quantem.core.io.file_readers import read_5dstem + + return read_5dstem(file_path, file_type=file_type, **kwargs) + + # ------------------------------------------------------------------------- + # Construction + # ------------------------------------------------------------------------- + @classmethod + def from_array( + cls, + array: ArrayLike, + name: str | None = None, + origin: ArrayLike | None = None, + sampling: ArrayLike | None = None, + units: list[str] | None = None, + signal_units: str = "arb. units", + stack_type: str = "generic", + stack_values: ArrayLike | None = None, + ) -> Self: + """Create Dataset5dstem from a 5D array. + + Parameters + ---------- + array : array-like + 5D array with shape (stack, scan_row, scan_col, k_row, k_col). + name : str, optional + Dataset name. Default: "5D-STEM dataset". + origin : array-like, optional + Origin for each dimension (5 elements). Default: zeros. + sampling : array-like, optional + Sampling for each dimension (5 elements). Default: ones. + units : list[str], optional + Units for each dimension (5 elements). Default: ["pixels", ...]. + signal_units : str, optional + Units for intensity values. Default: "arb. units". + stack_type : str, optional + Type of stack dimension. Default: "generic". + stack_values : array-like, optional + Explicit values for stack positions (e.g., times, angles). + + Returns + ------- + Dataset5dstem + + Examples + -------- + Basic usage: + + >>> import numpy as np + >>> arr = np.random.rand(10, 256, 256, 64, 64) + >>> data = Dataset5dstem.from_array(arr, stack_type="time") + >>> data.shape + (10, 256, 256, 64, 64) + + With calibrations: + + >>> data = Dataset5dstem.from_array( + ... arr, + ... stack_type="tilt", + ... stack_values=np.linspace(-60, 60, 10), # tilt angles in degrees + ... sampling=[1, 0.5, 0.5, 0.01, 0.01], + ... units=["deg", "nm", "nm", "1/nm", "1/nm"], + ... ) + """ + array = ensure_valid_array(array, ndim=5) + + origin_5d = np.zeros(5) if origin is None else np.asarray(origin) + sampling_5d = np.ones(5) if sampling is None else np.asarray(sampling) + units_5d = ["pixels"] * 5 if units is None else list(units) + + return cls( + array=array, + name=name or "5D-STEM dataset", + origin=origin_5d, + sampling=sampling_5d, + units=units_5d, + signal_units=signal_units, + stack_type=stack_type, + stack_values=stack_values, + _token=cls._token, + ) + + @classmethod + def from_4dstem( + cls, + datasets: list[Dataset4dstem], + stack_type: str = "generic", + stack_values: NDArray | None = None, + name: str | None = None, + ) -> Self: + """Create Dataset5dstem by stacking multiple Dataset4dstem objects. + + Parameters + ---------- + datasets : list[Dataset4dstem] + List of 4D-STEM datasets to stack. Must have identical shapes. + stack_type : str, optional + Type of stack dimension. Default: "generic". + stack_values : NDArray, optional + Explicit values for stack positions. + name : str, optional + Dataset name. + + Returns + ------- + Dataset5dstem + + Examples + -------- + Stack multiple 4D-STEM datasets into a tilt series: + + >>> from quantem.core.io import read_4dstem + >>> frames = [read_4dstem(f"tilt_{i:02d}.h5") for i in range(10)] + >>> tilt_series = Dataset5dstem.from_4dstem( + ... frames, + ... stack_type="tilt", + ... stack_values=np.linspace(-60, 60, 10), + ... ) + >>> tilt_series.shape + (10, 256, 256, 128, 128) + + Stack synthetic data: + + >>> datasets = [Dataset4dstem.from_array(np.random.rand(64, 64, 32, 32)) for _ in range(5)] + >>> data = Dataset5dstem.from_4dstem(datasets, stack_type="time") + """ + if not datasets: + raise ValueError("datasets list cannot be empty") + + first = datasets[0] + + # Validate consistency across all datasets + for i, ds in enumerate(datasets[1:], start=1): + if ds.shape != first.shape: + raise ValueError( + f"Dataset {i} shape {ds.shape} doesn't match first dataset shape {first.shape}" + ) + if not np.allclose(ds.sampling, first.sampling): + raise ValueError( + f"Dataset {i} sampling {ds.sampling} doesn't match first dataset" + ) + if ds.units != first.units: + raise ValueError( + f"Dataset {i} units {ds.units} doesn't match first dataset" + ) + + stacked = np.stack([d.array for d in datasets], axis=0) + + return cls.from_array( + array=stacked, + name=name or "5D-STEM dataset", + origin=np.concatenate([[0], first.origin]), + sampling=np.concatenate([[1], first.sampling]), + units=["index"] + list(first.units), + signal_units=first.signal_units, + stack_type=stack_type, + stack_values=stack_values, + ) + + # ------------------------------------------------------------------------- + # Stack operations + # ------------------------------------------------------------------------- + def stack_mean(self) -> Dataset4dstem: + """Average over the stack axis. Returns Dataset4dstem.""" + return self._reduce_stack(np.mean, "mean") + + def stack_sum(self) -> Dataset4dstem: + """Sum over the stack axis. Returns Dataset4dstem.""" + return self._reduce_stack(np.sum, "sum") + + def stack_max(self) -> Dataset4dstem: + """Maximum over the stack axis. Returns Dataset4dstem.""" + return self._reduce_stack(np.max, "max") + + def stack_min(self) -> Dataset4dstem: + """Minimum over the stack axis. Returns Dataset4dstem.""" + return self._reduce_stack(np.min, "min") + + def stack_std(self) -> Dataset4dstem: + """Standard deviation over the stack axis. Returns Dataset4dstem.""" + return self._reduce_stack(np.std, "std") + + def _reduce_stack(self, func, suffix: str) -> Dataset4dstem: + """Apply reduction function over stack axis.""" + return Dataset4dstem.from_array( + array=func(self.array, axis=0), + name=f"{self.name}_{suffix}", + origin=self.origin[1:], + sampling=self.sampling[1:], + units=self.units[1:], + signal_units=self.signal_units, + ) + + def _get_frame(self, idx: int) -> Dataset4dstem: + """Extract a single 4D frame from the stack.""" + if idx < 0: + idx = len(self) + idx + if idx < 0 or idx >= len(self): + raise IndexError(f"Frame index {idx} out of range for {len(self)} frames") + + frame = Dataset4dstem.from_array( + array=self.array[idx], + name=f"{self.name}_frame{idx}", + origin=self.origin[1:], + sampling=self.sampling[1:], + units=self.units[1:], + signal_units=self.signal_units, + ) + + # Inherit STEM metadata + frame.metadata["r_to_q_rotation_cw_deg"] = self.metadata.get("r_to_q_rotation_cw_deg") + frame.metadata["ellipticity"] = self.metadata.get("ellipticity") + + return frame + + # ------------------------------------------------------------------------- + # Virtual imaging + # ------------------------------------------------------------------------- + def get_virtual_image( + self, + mask: ArrayLike | None = None, + mode: str | None = None, + geometry: tuple | list | None = None, + name: str = "virtual_image", + attach: bool = True, + ) -> Dataset3d: + """Compute virtual image stack for all frames. + + Parameters + ---------- + mask : np.ndarray, optional + Custom mask matching diffraction pattern shape (k_row, k_col). + mode : str, optional + Mask mode: "circle" or "annular". + geometry : tuple or list, optional + Detector geometry in pixels. Format depends on mode: + + For "circle" mode: + - ``(None, radius)`` : Auto-fit center per frame. + - ``((cy, cx), radius)`` : Fixed center for all frames. + - ``[((cy0, cx0), r0), ...]`` : Per-frame geometry list. + + For "annular" mode: + - ``(None, (r_inner, r_outer))`` : Auto-fit center per frame. + - ``((cy, cx), (r_inner, r_outer))`` : Fixed center. + - ``[((cy0, cx0), (r0_in, r0_out)), ...]`` : Per-frame list. + + name : str, optional + Name for the virtual image. Default: "virtual_image". + attach : bool, optional + Store result in virtual_images dict. Default: True. + + Returns + ------- + Dataset3d + Virtual image stack with shape (n_frames, scan_row, scan_col). + + Notes + ----- + All geometry values are in pixels. To convert from mrad: + + >>> radius_px = radius_mrad / data.sampling[-1] + + Examples + -------- + Auto-fit center per frame: + + >>> bf = data.get_virtual_image(mode="circle", geometry=(None, 20)) + >>> adf = data.get_virtual_image(mode="annular", geometry=(None, (30, 80))) + + Fixed center for all frames: + + >>> k_center = (data.shape[-2] // 2, data.shape[-1] // 2) + >>> bf = data.get_virtual_image(mode="circle", geometry=(k_center, 20)) + + Per-frame geometry: + + >>> geometries = [(center, 20) for center in fitted_centers] + >>> bf = data.get_virtual_image(mode="circle", geometry=geometries) + """ + dp_shape = self.array.shape[-2:] + n_frames = len(self) + + if mask is not None: + if mask.shape != dp_shape: + raise ValueError(f"Mask shape {mask.shape} != diffraction pattern shape {dp_shape}") + virtual_stack = np.sum(self.array * mask, axis=(-1, -2)) + elif mode is not None: + # Full auto-detection: geometry=None + if geometry is None: + geometries = self._auto_fit_centers(mode, None) + virtual_stack = self._compute_per_frame_virtual(mode, geometries, dp_shape) + geometry = geometries + # Per-frame geometry list + elif isinstance(geometry, list): + if len(geometry) != n_frames: + raise ValueError( + f"geometry list length ({len(geometry)}) must match " + f"number of frames ({n_frames})" + ) + virtual_stack = self._compute_per_frame_virtual(mode, geometry, dp_shape) + else: + # Single geometry tuple: (center_or_none, radius_or_radii) + center, radius_or_radii = geometry + if center is None: + # Auto-fit center per frame (radius may also be None for auto) + geometries = self._auto_fit_centers(mode, radius_or_radii) + virtual_stack = self._compute_per_frame_virtual(mode, geometries, dp_shape) + geometry = geometries # Store fitted geometries + else: + # Fixed center for all frames + if mode == "circle": + final_mask = create_circle_mask(dp_shape, center, radius_or_radii) + elif mode == "annular": + final_mask = create_annular_mask(dp_shape, center, radius_or_radii) + else: + raise ValueError(f"Unknown mode '{mode}'. Use 'circle' or 'annular'.") + virtual_stack = np.sum(self.array * final_mask, axis=(-1, -2)) + else: + raise ValueError("Provide either mask or both mode and geometry") + + vi = Dataset3d.from_array( + array=virtual_stack, + name=name, + origin=self.origin[:3], + sampling=self.sampling[:3], + units=self.units[:3], + signal_units=self.signal_units, + ) + + if attach: + self._virtual_images[name] = vi + self._virtual_detectors[name] = { + "mask": mask.copy() if mask is not None else None, + "mode": mode, + "geometry": geometry, + } + + return vi + + def _auto_fit_centers(self, mode: str, radius_or_radii) -> list: + """Fit probe center for each frame and return list of geometries. + + If radius_or_radii is None, also auto-detect the radius. + """ + geometries = [] + for i in range(len(self)): + dp_mean = np.mean(self.array[i], axis=(0, 1)) + cy, cx, detected_radius = fit_probe_circle(dp_mean, show=False) + radius = detected_radius if radius_or_radii is None else radius_or_radii + geometries.append(((cy, cx), radius)) + return geometries + + def _compute_per_frame_virtual(self, mode: str, geometries: list, dp_shape: tuple) -> np.ndarray: + """Compute virtual images with per-frame geometry.""" + virtual_stack = np.zeros((len(self), self.shape[1], self.shape[2]), dtype=self.array.dtype) + for i, geom in enumerate(geometries): + center, radius_or_radii = geom + if mode == "circle": + mask = create_circle_mask(dp_shape, center, radius_or_radii) + elif mode == "annular": + mask = create_annular_mask(dp_shape, center, radius_or_radii) + else: + raise ValueError(f"Unknown mode '{mode}'") + virtual_stack[i] = np.sum(self.array[i] * mask, axis=(-1, -2)) + return virtual_stack + + def show_virtual_images(self, figsize: tuple[int, int] | None = None, **kwargs) -> tuple: + """Display all virtual images stored in the dataset. + + Parameters + ---------- + figsize : tuple[int, int] | None + Figure size. If None, auto-calculated. + **kwargs + Arguments passed to show_2d (cmap, norm, cbar, etc.) + + Returns + ------- + tuple + (fig, axs) from matplotlib. + """ + if not self.virtual_images: + print("No virtual images. Create with get_virtual_image().") + return None, None + + # Each virtual image is Dataset3d - show first frame of each + arrays = [vi.array[0] for vi in self.virtual_images.values()] + titles = [f"{name} (frame 0)" for name in self.virtual_images.keys()] + + n = len(arrays) + if figsize is None: + figsize = (4 * min(n, 4), 4 * ((n + 3) // 4)) + + return show_2d(arrays, title=titles, figax_size=figsize, **kwargs) + + def regenerate_virtual_images(self) -> None: + """Regenerate virtual images from stored detector information.""" + if not self._virtual_detectors: + return + + self._virtual_images.clear() + + for name, info in self._virtual_detectors.items(): + try: + if info["mode"] is not None and info["geometry"] is not None: + self.get_virtual_image( + mode=info["mode"], + geometry=info["geometry"], + name=name, + attach=True, + ) + else: + print(f"Warning: Cannot regenerate '{name}' - insufficient detector info.") + except Exception as e: + print(f"Warning: Failed to regenerate '{name}': {e}") + + def update_virtual_detector( + self, + name: str, + mask: np.ndarray | None = None, + mode: str | None = None, + geometry: tuple | None = None, + ) -> None: + """Update virtual detector and regenerate the corresponding image. + + Parameters + ---------- + name : str + Name of virtual detector to update. + mask : np.ndarray | None + New mask (must match DP dimensions). + mode : str | None + New mode ("circle" or "annular"). + geometry : tuple | None + New geometry. + """ + if name not in self._virtual_detectors: + raise ValueError(f"Detector '{name}' not found. Available: {list(self._virtual_detectors.keys())}") + + self._virtual_detectors[name]["mask"] = mask.copy() if mask is not None else None + self._virtual_detectors[name]["mode"] = mode + self._virtual_detectors[name]["geometry"] = geometry + + self.get_virtual_image(mask=mask, mode=mode, geometry=geometry, name=name, attach=True) + + def clear_virtual_images(self) -> None: + """Clear virtual images while keeping detector information.""" + self._virtual_images.clear() + + def clear_all_virtual_data(self) -> None: + """Clear both virtual images and detector information.""" + self._virtual_images.clear() + self._virtual_detectors.clear() + + def show(self, *args, **kwargs): + """Not implemented for 5D data. + + Raises + ------ + NotImplementedError + Always raised. Use alternative methods for visualization. + + See Also + -------- + show_virtual_images : Display virtual images. + data[i].show() : Show a single 4D frame. + data[i].dp_mean.show() : Show mean diffraction pattern. + """ + raise NotImplementedError( + "show() is not meaningful for 5D data. Use:\n" + " - data[i].show() for a single frame\n" + " - data[i].dp_mean.show() for mean DP\n" + " - data.show_virtual_images() for virtual images\n" + " - data.get_virtual_image(...).show() for virtual image stack" + ) + + # ------------------------------------------------------------------------- + # Copy + # ------------------------------------------------------------------------- + def _copy_custom_attributes(self, new_dataset) -> None: + """Copy Dataset5dstem-specific attributes.""" + super()._copy_custom_attributes(new_dataset) + new_dataset._stack_type = self._stack_type + new_dataset._stack_values = self._stack_values.copy() if self._stack_values is not None else None + new_dataset._virtual_images = {} + new_dataset._virtual_detectors = { + name: {"mask": None, "mode": info["mode"], "geometry": info["geometry"]} + for name, info in self._virtual_detectors.items() + } diff --git a/src/quantem/core/io/__init__.py b/src/quantem/core/io/__init__.py index 2780eae4..c6fa56a0 100644 --- a/src/quantem/core/io/__init__.py +++ b/src/quantem/core/io/__init__.py @@ -1,5 +1,6 @@ from quantem.core.io.file_readers import read_2d as read_2d from quantem.core.io.file_readers import read_4dstem as read_4dstem +from quantem.core.io.file_readers import read_5dstem as read_5dstem from quantem.core.io.file_readers import ( read_emdfile_to_4dstem as read_emdfile_to_4dstem, ) diff --git a/src/quantem/core/io/file_readers.py b/src/quantem/core/io/file_readers.py index cb36f1de..29762548 100644 --- a/src/quantem/core/io/file_readers.py +++ b/src/quantem/core/io/file_readers.py @@ -1,13 +1,16 @@ import importlib +import json from os import PathLike from pathlib import Path import h5py +import numpy as np from quantem.core.datastructures import Dataset as Dataset from quantem.core.datastructures import Dataset2d as Dataset2d from quantem.core.datastructures import Dataset3d as Dataset3d from quantem.core.datastructures import Dataset4dstem as Dataset4dstem +from quantem.core.datastructures import Dataset5dstem as Dataset5dstem def read_4dstem( @@ -93,6 +96,170 @@ def read_4dstem( return dataset +def read_5dstem( + file_path: str | PathLike, + file_type: str | None = None, + stack_type: str = "auto", + **kwargs, +) -> Dataset5dstem: + """ + File reader for 5D-STEM data. + + Supports: + - Nion Swift h5 files (auto-detected from 'properties' attribute) + - rosettasciio formats with 5D data + + Parameters + ---------- + file_path : str | PathLike + Path to data file (.h5, .hdf5, or rosettasciio-supported formats). + file_type : str | None, optional + File reader type (e.g., "hdf5", "emd"). If None, auto-detects from extension. + stack_type : str, optional + Type of stack dimension. Options: "time", "tilt", "energy", "dose", "focus", "generic". + Default "auto" detects from metadata (Nion Swift is_sequence=True -> "time"). + **kwargs : dict + Additional keyword arguments passed to Dataset5dstem constructor. + + Returns + ------- + Dataset5dstem + 5D dataset with shape (stack, scan_row, scan_col, k_row, k_col). + + Examples + -------- + Load Nion Swift time series (auto-detects stack_type from is_sequence): + + >>> data = read_5dstem("time_series.h5") + >>> data.shape + (10, 512, 512, 12, 12) + >>> data.stack_type + 'time' + + Load as tilt series (override auto-detection): + + >>> data = read_5dstem("tilt_series.h5", stack_type="tilt") + >>> data.stack_type + 'tilt' + + Access calibrations extracted from file metadata: + + >>> data.sampling # [stack, scan_row, scan_col, k_row, k_col] + array([1.0, 0.5, 0.5, 0.006, 0.006]) + >>> data.units + ['pixels', 'nm', 'nm', 'rad', 'rad'] + """ + file_path = Path(file_path) + + # Try Nion Swift h5 format first + if file_path.suffix.lower() in [".h5", ".hdf5"]: + try: + with h5py.File(file_path, "r") as f: + if "data" in f and "properties" in f["data"].attrs: + return _read_nion_swift_5dstem(file_path, stack_type, **kwargs) + except Exception: + pass # Fall through to rsciio + + # Fall back to rosettasciio + if file_type is None: + file_type = file_path.suffix.lower().lstrip(".") + file_reader = importlib.import_module(f"rsciio.{file_type}").file_reader + data_list = file_reader(file_path) + + # Find first 5D dataset + five_d_datasets = [(i, d) for i, d in enumerate(data_list) if d["data"].ndim == 5] + if len(five_d_datasets) == 0: + print(f"No 5D datasets found in {file_path}. Available datasets:") + for i, d in enumerate(data_list): + print(f" Dataset {i}: shape {d['data'].shape}, ndim={d['data'].ndim}") + raise ValueError("No 5D dataset found in file") + dataset_index, imported_data = five_d_datasets[0] + if len(data_list) > 1: + print( + f"File contains {len(data_list)} dataset(s). Using dataset {dataset_index} " + f"with shape {imported_data['data'].shape}" + ) + + # Extract calibrations + imported_axes = imported_data["axes"] + sampling = kwargs.pop("sampling", [ax["scale"] for ax in imported_axes]) + origin = kwargs.pop("origin", [ax["offset"] for ax in imported_axes]) + units = kwargs.pop( + "units", + ["pixels" if ax["units"] == "1" else ax["units"] for ax in imported_axes], + ) + if stack_type == "auto": + stack_type = "generic" + + return Dataset5dstem.from_array( + array=imported_data["data"], + sampling=sampling, + origin=origin, + units=units, + stack_type=stack_type, + **kwargs, + ) + + +def _read_nion_swift_5dstem( + file_path: str | PathLike, + stack_type: str = "auto", + **kwargs, +) -> Dataset5dstem: + """ + Read Nion Swift 5D-STEM h5 file. + + Nion Swift stores data with: + - f['data'] containing the array + - f['data'].attrs['properties'] containing JSON metadata + + Parameters + ---------- + file_path : str | PathLike + Path to Nion Swift h5 file + stack_type : str, optional + Stack type or "auto" to detect from metadata + + Returns + ------- + Dataset5dstem + """ + with h5py.File(file_path, "r") as f: + data = f["data"][:] + props = json.loads(f["data"].attrs["properties"]) + if data.ndim != 5: + raise ValueError(f"Expected 5D data, got {data.ndim}D with shape {data.shape}") + + # Extract calibrations + cals = props.get("dimensional_calibrations", []) + if len(cals) == 5: + origin = np.array([c.get("offset", 0.0) for c in cals]) + sampling = np.array([c.get("scale", 1.0) for c in cals]) + units = [c.get("units", "") or "pixels" for c in cals] + else: + origin = np.zeros(5) + sampling = np.ones(5) + units = ["pixels"] * 5 + + # Determine stack type from metadata + if stack_type == "auto": + stack_type = "time" if props.get("is_sequence", False) else "generic" + + # Get intensity calibration + intensity_cal = props.get("intensity_calibration", {}) + signal_units = intensity_cal.get("units", "arb. units") or "arb. units" + + return Dataset5dstem.from_array( + array=data, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + stack_type=stack_type, + **kwargs, + ) + + def read_2d( file_path: str | PathLike, file_type: str | None = None, diff --git a/src/quantem/core/utils/masks.py b/src/quantem/core/utils/masks.py new file mode 100644 index 00000000..0b884a6b --- /dev/null +++ b/src/quantem/core/utils/masks.py @@ -0,0 +1,64 @@ +import numpy as np + + +def create_circle_mask( + shape: tuple[int, int], + center: tuple[float, float], + radius: float, +) -> np.ndarray: + """ + Create a circular mask for virtual image formation. + + Parameters + ---------- + shape : tuple[int, int] + Shape of the mask (rows, cols) + center : tuple[float, float] + Center coordinates (cy, cx) of the circle + radius : float + Radius of the circle + + Returns + ------- + np.ndarray + Boolean mask with True inside the circle + """ + cy, cx = center + y, x = np.ogrid[: shape[0], : shape[1]] + + # Calculate distance from center + distance = np.sqrt((y - cy) ** 2 + (x - cx) ** 2) + + return distance <= radius + + +def create_annular_mask( + shape: tuple[int, int], + center: tuple[float, float], + radii: tuple[float, float], +) -> np.ndarray: + """ + Create an annular (ring-shaped) mask for virtual image formation. + + Parameters + ---------- + shape : tuple[int, int] + Shape of the mask (rows, cols) + center : tuple[float, float] + Center coordinates (cy, cx) of the annulus + radii : tuple[float, float] + Inner and outer radii (r_inner, r_outer) of the annulus + + Returns + ------- + np.ndarray + Boolean mask with True inside the annular region + """ + cy, cx = center + r_inner, r_outer = radii + y, x = np.ogrid[: shape[0], : shape[1]] + + # Calculate distance from center + distance = np.sqrt((y - cy) ** 2 + (x - cx) ** 2) + + return (distance >= r_inner) & (distance <= r_outer) diff --git a/src/quantem/core/utils/validators.py b/src/quantem/core/utils/validators.py index 02f8a935..fe57110a 100644 --- a/src/quantem/core/utils/validators.py +++ b/src/quantem/core/utils/validators.py @@ -49,6 +49,14 @@ def ensure_valid_array( TypeError If the input could not be converted to a NumPy array """ + # Check for torch GPU tensors and provide helpful error + if config.get("has_torch"): + if isinstance(array, torch.Tensor) and array.is_cuda: + raise TypeError( + f"Cannot convert torch GPU tensor (device={array.device}) to numpy array. " + "Use .cpu() to move to CPU first: Dataset.from_array(tensor.cpu())" + ) + is_cupy = False if config.get("has_cupy"): if isinstance(array, cp.ndarray): diff --git a/tests/conftest.py b/tests/conftest.py index 3178a8fd..7cdcf3c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,67 @@ +import json +from pathlib import Path + +import h5py +import numpy as np import pytest +def create_mock_nion_swift_h5(path: Path, shape=(4, 6, 7, 3, 5), is_sequence=True): + """Create a mock Nion Swift h5 file with realistic metadata structure. + + Parameters + ---------- + path : Path + Output file path. + shape : tuple + Data shape (frames, scan_row, scan_col, k_row, k_col). + is_sequence : bool + Whether to set is_sequence=True in properties. + """ + properties = { + "type": "data-item", + "uuid": "test-uuid", + "is_sequence": is_sequence, + "intensity_calibration": {"offset": 0.0, "scale": 1.0, "units": "counts" if is_sequence else ""}, + "dimensional_calibrations": [ + {"offset": 0.0, "scale": 1.0, "units": ""}, # frames + {"offset": -1.0, "scale": 0.5, "units": "nm"}, # scan_row + {"offset": -1.5, "scale": 0.5, "units": "nm"}, # scan_col + {"offset": -0.036, "scale": 0.006, "units": "rad"}, # k_row + {"offset": -0.036, "scale": 0.006, "units": "rad"}, # k_col + ], + "collection_dimension_count": 2, + "datum_dimension_count": 2, + "metadata": { + "instrument": { + "high_tension": 60000.0, # 60 keV + "defocus": 0.0, + "ImageScanned": { + "probe_ha": 0.035, # 35 mrad half-angle + "C10": 0.0, + "C30": 0.0, + }, + }, + "scan": { + "scan_size": [shape[1], shape[2]], + }, + }, + } + + with h5py.File(path, "w") as f: + data = np.random.rand(*shape).astype(np.float32) + dset = f.create_dataset("data", data=data) + dset.attrs["properties"] = json.dumps(properties) + + +@pytest.fixture +def mock_nion_5dstem_file(tmp_path): + """Create a temporary mock Nion Swift 5D-STEM h5 file.""" + path = tmp_path / "mock_nion_5dstem.h5" + create_mock_nion_swift_h5(path, shape=(4, 6, 7, 3, 5)) + return path + + def pytest_addoption(parser): parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") diff --git a/tests/datastructures/test_dataset5dstem.py b/tests/datastructures/test_dataset5dstem.py new file mode 100644 index 00000000..e5ffcec8 --- /dev/null +++ b/tests/datastructures/test_dataset5dstem.py @@ -0,0 +1,160 @@ +"""Tests for Dataset5dstem class.""" +import numpy as np +import pytest +from quantem.core.datastructures.dataset5dstem import Dataset5dstem +from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.datastructures.dataset3d import Dataset3d + + +@pytest.fixture +def sample_dataset(): + """Create a sample 5D-STEM dataset with distinct sizes for clarity.""" + array = np.random.rand(4, 6, 7, 3, 5) # (frames, scan_row, scan_col, k_row, k_col) + return Dataset5dstem.from_array(array=array, stack_type="time") + + +class TestDataset5dstem: + """Core Dataset5dstem tests.""" + + def test_from_array(self): + """Test creating Dataset5dstem from array.""" + array = np.random.rand(4, 6, 7, 3, 5) + data = Dataset5dstem.from_array(array=array, stack_type="tilt") + assert data.shape == (4, 6, 7, 3, 5) + assert data.stack_type == "tilt" + assert len(data) == 4 + + def test_from_4dstem(self): + """Test creating Dataset5dstem from list of Dataset4dstem.""" + datasets = [Dataset4dstem.from_array(np.random.rand(6, 7, 3, 5)) for _ in range(4)] + data = Dataset5dstem.from_4dstem(datasets, stack_type="tilt") + assert data.shape == (4, 6, 7, 3, 5) + assert data.stack_type == "tilt" + + def test_indexing_and_iteration(self, sample_dataset): + """Test data[i] and iteration return Dataset4dstem frames.""" + frame = sample_dataset[0] # -> Dataset4dstem + assert isinstance(frame, Dataset4dstem) + assert frame.shape == (6, 7, 3, 5) + frames = list(sample_dataset) + assert len(frames) == 4 + assert all(isinstance(f, Dataset4dstem) for f in frames) + + def test_stack_reductions(self, sample_dataset): + """Test stack reduction methods return Dataset4dstem with correct values.""" + mean = sample_dataset.stack_mean() # -> Dataset4dstem + assert isinstance(mean, Dataset4dstem) + assert mean.shape == (6, 7, 3, 5) + arr = sample_dataset.array + assert np.allclose(sample_dataset.stack_mean().array, np.mean(arr, axis=0)) + assert np.allclose(sample_dataset.stack_sum().array, np.sum(arr, axis=0)) + assert np.allclose(sample_dataset.stack_max().array, np.max(arr, axis=0)) + assert np.allclose(sample_dataset.stack_min().array, np.min(arr, axis=0)) + assert np.allclose(sample_dataset.stack_std().array, np.std(arr, axis=0)) + + def test_slicing(self, sample_dataset): + """Test common slicing patterns.""" + substack = sample_dataset[1:3] # -> Dataset5dstem (2, 6, 7, 3, 5) + assert isinstance(substack, Dataset5dstem) + assert substack.shape == (2, 6, 7, 3, 5) + assert substack.stack_type == "time" # metadata preserved + position = sample_dataset[:, 2, 3] # -> Dataset3d (4, 3, 5) + assert isinstance(position, Dataset3d) + assert position.shape == (4, 3, 5) + cropped = sample_dataset[:, :, :, 1:3, 1:4] # -> Dataset5dstem (4, 6, 7, 2, 3) + assert isinstance(cropped, Dataset5dstem) + assert cropped.shape == (4, 6, 7, 2, 3) + + def test_get_virtual_image(self, sample_dataset): + """Test virtual image creation returns Dataset3d stack.""" + vi = sample_dataset.get_virtual_image(mode="circle", geometry=((1, 2), 1), name="bf") + assert isinstance(vi, Dataset3d) + assert vi.shape == (4, 6, 7) + assert "bf" in sample_dataset.virtual_images + + def test_validation(self): + """Test validation errors.""" + array = np.random.rand(4, 6, 7, 3, 5) + with pytest.raises(ValueError, match="stack_values length"): + Dataset5dstem.from_array(array=array, stack_values=np.array([1, 2])) + ds1 = Dataset4dstem.from_array(np.random.rand(6, 7, 3, 5)) + ds2 = Dataset4dstem.from_array(np.random.rand(8, 9, 3, 5)) + with pytest.raises(ValueError, match="shape"): + Dataset5dstem.from_4dstem([ds1, ds2]) + + def test_virtual_image_management(self, sample_dataset): + """Test virtual image clear, regenerate, update methods.""" + # Create virtual images + sample_dataset.get_virtual_image(mode="circle", geometry=((1, 2), 1), name="bf") + sample_dataset.get_virtual_image(mode="annular", geometry=((1, 2), (1, 2)), name="adf") + assert len(sample_dataset.virtual_images) == 2 + + # Clear images only + sample_dataset.clear_virtual_images() + assert len(sample_dataset.virtual_images) == 0 + assert len(sample_dataset.virtual_detectors) == 2 + + # Regenerate + sample_dataset.regenerate_virtual_images() + assert len(sample_dataset.virtual_images) == 2 + + # Update detector + sample_dataset.update_virtual_detector("bf", mode="circle", geometry=((1, 2), 2)) + assert sample_dataset.virtual_detectors["bf"]["geometry"] == ((1, 2), 2) + + # Clear all + sample_dataset.clear_all_virtual_data() + assert len(sample_dataset.virtual_images) == 0 + assert len(sample_dataset.virtual_detectors) == 0 + + def test_virtual_image_per_frame_geometry(self, sample_dataset): + """Test virtual image with per-frame geometry.""" + geometries = [((1, 2), 1) for _ in range(4)] + vi = sample_dataset.get_virtual_image(mode="circle", geometry=geometries, name="bf_perframe") + assert vi.shape == (4, 6, 7) + assert len(sample_dataset.virtual_detectors["bf_perframe"]["geometry"]) == 4 + + def test_virtual_image_auto_fit_center(self): + """Test virtual image with auto-fit center (requires realistic DP).""" + # Create data with a clear probe pattern + n_frames, scan_r, scan_c, k_r, k_c = 3, 4, 4, 32, 32 + array = np.zeros((n_frames, scan_r, scan_c, k_r, k_c), dtype=np.float32) + # Add circular probe at center + cy, cx = k_r // 2, k_c // 2 + y, x = np.ogrid[:k_r, :k_c] + mask = (y - cy) ** 2 + (x - cx) ** 2 <= 8 ** 2 + array[:, :, :, mask] = 1.0 + + data = Dataset5dstem.from_array(array, stack_type="time") + vi = data.get_virtual_image(mode="circle", geometry=(None, 5), name="bf_auto") + assert vi.shape == (3, 4, 4) + assert isinstance(data.virtual_detectors["bf_auto"]["geometry"], list) + assert len(data.virtual_detectors["bf_auto"]["geometry"]) == 3 + + def test_virtual_image_full_auto(self): + """Test virtual image with full auto-detection (center and radius).""" + # Create data with a clear probe pattern + n_frames, scan_r, scan_c, k_r, k_c = 3, 4, 4, 32, 32 + array = np.zeros((n_frames, scan_r, scan_c, k_r, k_c), dtype=np.float32) + # Add circular probe at center with radius 8 + cy, cx = k_r // 2, k_c // 2 + y, x = np.ogrid[:k_r, :k_c] + mask = (y - cy) ** 2 + (x - cx) ** 2 <= 8 ** 2 + array[:, :, :, mask] = 1.0 + + data = Dataset5dstem.from_array(array, stack_type="time") + # geometry=None triggers full auto-detection + vi = data.get_virtual_image(mode="circle", geometry=None, name="bf_full_auto") + assert vi.shape == (3, 4, 4) + geoms = data.virtual_detectors["bf_full_auto"]["geometry"] + assert isinstance(geoms, list) + assert len(geoms) == 3 + # Check that radius was auto-detected (should be close to 8) + for center, radius in geoms: + assert 6 < radius < 10 # Allow some tolerance + + def test_show_raises_error(self, sample_dataset): + """Test that show() raises NotImplementedError with helpful message.""" + with pytest.raises(NotImplementedError, match="not meaningful for 5D"): + sample_dataset.show() + diff --git a/tests/io/test_read_5dstem.py b/tests/io/test_read_5dstem.py new file mode 100644 index 00000000..05e4cfc9 --- /dev/null +++ b/tests/io/test_read_5dstem.py @@ -0,0 +1,44 @@ +"""Tests for read_5dstem with mocked Nion Swift metadata.""" +import json +import numpy as np +import pytest +import h5py +from quantem.core.io import read_5dstem +from quantem.core.datastructures import Dataset5dstem + + +class TestRead5dstemNionSwift: + """Tests for reading Nion Swift 5D-STEM data.""" + + def test_read_nion_5dstem(self, mock_nion_5dstem_file): + """Test reading Nion Swift file extracts shape, type, and calibrations.""" + data = read_5dstem(mock_nion_5dstem_file) + assert isinstance(data, Dataset5dstem) + assert data.shape == (4, 6, 7, 3, 5) + assert data.stack_type == "time" # is_sequence=True + assert np.allclose(data.sampling, [1.0, 0.5, 0.5, 0.006, 0.006]) + assert np.allclose(data.origin[:3], [0.0, -1.0, -1.5]) + assert data.units[1:3] == ["nm", "nm"] + assert data.units[3:] == ["rad", "rad"] + assert data.signal_units == "counts" + + def test_read_override_stack_type(self, mock_nion_5dstem_file): + """Test stack_type can be overridden.""" + data = read_5dstem(mock_nion_5dstem_file, stack_type="tilt") + assert data.stack_type == "tilt" + + def test_read_no_sequence(self, tmp_path): + """Test is_sequence=False -> stack_type='generic', empty units -> 'arb. units'.""" + path = tmp_path / "no_seq.h5" + properties = { + "is_sequence": False, + "dimensional_calibrations": [{"offset": 0, "scale": 1, "units": ""}] * 5, + "intensity_calibration": {"units": ""}, + "metadata": {}, + } + with h5py.File(path, "w") as f: + dset = f.create_dataset("data", data=np.zeros((2, 3, 3, 4, 4), dtype=np.float32)) + dset.attrs["properties"] = json.dumps(properties) + data = read_5dstem(path) + assert data.stack_type == "generic" + assert data.signal_units == "arb. units"