diff --git a/src/quantem/core/datastructures/dataset3d.py b/src/quantem/core/datastructures/dataset3d.py index 53d6d02c..059c7456 100644 --- a/src/quantem/core/datastructures/dataset3d.py +++ b/src/quantem/core/datastructures/dataset3d.py @@ -5,6 +5,7 @@ from quantem.core.datastructures.dataset import Dataset from quantem.core.utils.validators import ensure_valid_array +from quantem.core.visualization import show_2d from quantem.core.visualization.visualization_utils import ScalebarConfig @@ -68,6 +69,50 @@ def from_array( units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", ) -> Self: + """Create a Dataset3d from a 3D array. + + Parameters + ---------- + array : NDArray | Any + 3D array with shape (n_frames, height, width) + name : str | None + Dataset name. Default: "3D dataset" + origin : NDArray | tuple | list | float | int | None + Origin for each dimension. Default: [0, 0, 0] + sampling : NDArray | tuple | list | float | int | None + Sampling for each dimension. Default: [1, 1, 1] + units : list[str] | tuple | list | None + Units for each dimension. Default: ["index", "pixels", "pixels"] + signal_units : str + Units for array values. Default: "arb. units" + + Returns + ------- + Dataset3d + + Examples + -------- + >>> import numpy as np + >>> from quantem.core.datastructures import Dataset3d + >>> arr = np.random.rand(10, 64, 64) + >>> data = Dataset3d.from_array(arr) + >>> data.shape + (10, 64, 64) + + With calibration: + + >>> data = Dataset3d.from_array( + ... arr, + ... sampling=[1, 0.1, 0.1], + ... units=["frame", "nm", "nm"], + ... ) + + Visualize: + + >>> data.show() # all frames in grid + >>> data.show(index=0) # single frame + >>> data.show(ncols=2) # 2 columns + """ array = ensure_valid_array(array, ndim=3) return cls( array=array, @@ -90,7 +135,37 @@ def from_shape( units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", ) -> Self: - """Create a new Dataset3d filled with a constant value.""" + """Create a Dataset3d filled with a constant value. + + Parameters + ---------- + shape : tuple[int, int, int] + Shape (n_frames, height, width) + name : str + Dataset name. Default: "constant 3D dataset" + fill_value : float + Value to fill array with. Default: 0.0 + origin : NDArray | tuple | list | float | int | None + Origin for each dimension + sampling : NDArray | tuple | list | float | int | None + Sampling for each dimension + units : list[str] | tuple | list | None + Units for each dimension + signal_units : str + Units for array values + + Returns + ------- + Dataset3d + + Examples + -------- + >>> data = Dataset3d.from_shape((10, 64, 64)) + >>> data.shape + (10, 64, 64) + >>> data.array.max() + 0.0 + """ array = np.full(shape, fill_value, dtype=np.float32) return cls.from_array( array=array, @@ -107,24 +182,68 @@ def to_dataset2d(self): def show( self, - index: int = 0, + index: int | None = None, scalebar: ScalebarConfig | bool = True, title: str | None = None, + suptitle: str | None = None, + ncols: int = 4, + returnfig: bool = False, **kwargs, ): """ - Display a 2D slice of the 3D dataset. + Display 2D slices of the 3D dataset. Parameters ---------- - index : int - Index of the 2D slice to display (along axis 0). - scalebar: ScalebarConfig or bool - If True, displays scalebar - title: str - Title of Dataset - **kwargs: dict - Keyword arguments for show_2d - """ + index : int | None + Index of the 2D slice to display. If None, shows all slices in a grid. + scalebar : ScalebarConfig or bool + If True, displays scalebar. + title : str | None + Title for the plot. If None, uses "Frame 0", "Frame 1", etc. + suptitle : str | None + Figure super title displayed above all subplots. + ncols : int + Maximum columns when showing all slices. Default: 4. + returnfig : bool + If True, returns (fig, axes). Default: False. + **kwargs : dict + Keyword arguments for show_2d (cmap, cbar, norm, etc.). - return self[index].show(scalebar=scalebar, title=title, **kwargs) + Examples + -------- + >>> data.show() # show all frames in grid + >>> data.show(index=0) # show single frame + >>> data.show(ncols=3) # 3 columns + >>> data.show(suptitle="Diffraction patterns") # with super title + >>> fig, axes = data.show(returnfig=True) # get figure for customization + """ + if index is not None: + # Handle negative index + actual_index = index if index >= 0 else self.shape[0] + index + default_title = title if title is not None else f"Frame {actual_index}" + result = self[index].show(scalebar=scalebar, title=default_title, **kwargs) + return result if returnfig else None + # Show all frames in a grid + n = self.shape[0] + nrows = (n + ncols - 1) // ncols + arrays = [] + titles = [] + for row in range(nrows): + row_arrays = [] + row_titles = [] + for col in range(ncols): + i = row * ncols + col + if i < n: + row_arrays.append(self.array[i]) + row_titles.append(f"Frame {i}" if title is None else f"{title} {i}") + else: + row_arrays.append(np.zeros_like(self.array[0])) + row_titles.append("") + arrays.append(row_arrays) + titles.append(row_titles) + fig, axes = show_2d(arrays, scalebar=scalebar, title=titles, **kwargs) + if suptitle is not None: + fig.suptitle(suptitle, fontsize=14) + fig.subplots_adjust(top=0.92) + return (fig, axes) if returnfig else None