Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 132 additions & 13 deletions src/quantem/core/datastructures/dataset3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from quantem.core.datastructures.dataset import Dataset
from quantem.core.utils.validators import ensure_valid_array
from quantem.core.visualization import show_2d
from quantem.core.visualization.visualization_utils import ScalebarConfig


Expand Down Expand Up @@ -68,6 +69,50 @@ def from_array(
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
) -> Self:
"""Create a Dataset3d from a 3D array.

Parameters
----------
array : NDArray | Any
3D array with shape (n_frames, height, width)
name : str | None
Dataset name. Default: "3D dataset"
origin : NDArray | tuple | list | float | int | None
Origin for each dimension. Default: [0, 0, 0]
sampling : NDArray | tuple | list | float | int | None
Sampling for each dimension. Default: [1, 1, 1]
units : list[str] | tuple | list | None
Units for each dimension. Default: ["index", "pixels", "pixels"]
signal_units : str
Units for array values. Default: "arb. units"

Returns
-------
Dataset3d

Examples
--------
>>> import numpy as np
>>> from quantem.core.datastructures import Dataset3d
>>> arr = np.random.rand(10, 64, 64)
>>> data = Dataset3d.from_array(arr)
>>> data.shape
(10, 64, 64)

With calibration:

>>> data = Dataset3d.from_array(
... arr,
... sampling=[1, 0.1, 0.1],
... units=["frame", "nm", "nm"],
... )

Visualize:

>>> data.show() # all frames in grid
>>> data.show(index=0) # single frame
>>> data.show(ncols=2) # 2 columns
"""
array = ensure_valid_array(array, ndim=3)
return cls(
array=array,
Expand All @@ -90,7 +135,37 @@ def from_shape(
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
) -> Self:
"""Create a new Dataset3d filled with a constant value."""
"""Create a Dataset3d filled with a constant value.

Parameters
----------
shape : tuple[int, int, int]
Shape (n_frames, height, width)
name : str
Dataset name. Default: "constant 3D dataset"
fill_value : float
Value to fill array with. Default: 0.0
origin : NDArray | tuple | list | float | int | None
Origin for each dimension
sampling : NDArray | tuple | list | float | int | None
Sampling for each dimension
units : list[str] | tuple | list | None
Units for each dimension
signal_units : str
Units for array values

Returns
-------
Dataset3d

Examples
--------
>>> data = Dataset3d.from_shape((10, 64, 64))
>>> data.shape
(10, 64, 64)
>>> data.array.max()
0.0
"""
array = np.full(shape, fill_value, dtype=np.float32)
return cls.from_array(
array=array,
Expand All @@ -107,24 +182,68 @@ def to_dataset2d(self):

def show(
self,
index: int = 0,
index: int | None = None,
scalebar: ScalebarConfig | bool = True,
title: str | None = None,
suptitle: str | None = None,
ncols: int = 4,
returnfig: bool = False,
**kwargs,
):
"""
Display a 2D slice of the 3D dataset.
Display 2D slices of the 3D dataset.

Parameters
----------
index : int
Index of the 2D slice to display (along axis 0).
scalebar: ScalebarConfig or bool
If True, displays scalebar
title: str
Title of Dataset
**kwargs: dict
Keyword arguments for show_2d
"""
index : int | None
Index of the 2D slice to display. If None, shows all slices in a grid.
scalebar : ScalebarConfig or bool
If True, displays scalebar.
title : str | None
Title for the plot. If None, uses "Frame 0", "Frame 1", etc.
suptitle : str | None
Figure super title displayed above all subplots.
ncols : int
Maximum columns when showing all slices. Default: 4.
returnfig : bool
If True, returns (fig, axes). Default: False.
**kwargs : dict
Keyword arguments for show_2d (cmap, cbar, norm, etc.).

return self[index].show(scalebar=scalebar, title=title, **kwargs)
Examples
--------
>>> data.show() # show all frames in grid
>>> data.show(index=0) # show single frame
>>> data.show(ncols=3) # 3 columns
>>> data.show(suptitle="Diffraction patterns") # with super title
>>> fig, axes = data.show(returnfig=True) # get figure for customization
"""
if index is not None:
# Handle negative index
actual_index = index if index >= 0 else self.shape[0] + index
default_title = title if title is not None else f"Frame {actual_index}"
result = self[index].show(scalebar=scalebar, title=default_title, **kwargs)
return result if returnfig else None
# Show all frames in a grid
n = self.shape[0]
nrows = (n + ncols - 1) // ncols
arrays = []
titles = []
for row in range(nrows):
row_arrays = []
row_titles = []
for col in range(ncols):
i = row * ncols + col
if i < n:
row_arrays.append(self.array[i])
row_titles.append(f"Frame {i}" if title is None else f"{title} {i}")
else:
row_arrays.append(np.zeros_like(self.array[0]))
row_titles.append("")
arrays.append(row_arrays)
titles.append(row_titles)
fig, axes = show_2d(arrays, scalebar=scalebar, title=titles, **kwargs)
if suptitle is not None:
fig.suptitle(suptitle, fontsize=14)
fig.subplots_adjust(top=0.92)
return (fig, axes) if returnfig else None