Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,4 @@ ipynb-playground/
# widget (JS build artifacts)
node_modules/
widget/src/quantem/widget/static/
*.ipynb
8 changes: 5 additions & 3 deletions src/quantem/core/datastructures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from quantem.core.datastructures.dataset import Dataset as Dataset
from quantem.core.datastructures.vector import Vector as Vector

from quantem.core.datastructures.dataset4dstem import Dataset4dstem as Dataset4dstem
from quantem.core.datastructures.dataset4d import Dataset4d as Dataset4d
from quantem.core.datastructures.dataset3d import Dataset3d as Dataset3d
from quantem.core.datastructures.dataset2d import Dataset2d as Dataset2d
from quantem.core.datastructures.dataset3d import Dataset3d as Dataset3d
from quantem.core.datastructures.dataset4d import Dataset4d as Dataset4d
from quantem.core.datastructures.dataset4dstem import Dataset4dstem as Dataset4dstem
from quantem.core.datastructures.dataset5d import Dataset5d as Dataset5d
from quantem.core.datastructures.dataset5dstem import Dataset5dstem as Dataset5dstem
146 changes: 134 additions & 12 deletions src/quantem/core/datastructures/dataset3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,50 @@ def from_array(
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
) -> Self:
"""Create a Dataset3d from a 3D array.

Parameters
----------
array : NDArray | Any
3D array with shape (n_frames, height, width)
name : str | None
Dataset name. Default: "3D dataset"
origin : NDArray | tuple | list | float | int | None
Origin for each dimension. Default: [0, 0, 0]
sampling : NDArray | tuple | list | float | int | None
Sampling for each dimension. Default: [1, 1, 1]
units : list[str] | tuple | list | None
Units for each dimension. Default: ["index", "pixels", "pixels"]
signal_units : str
Units for array values. Default: "arb. units"

Returns
-------
Dataset3d

Examples
--------
>>> import numpy as np
>>> from quantem.core.datastructures import Dataset3d
>>> arr = np.random.rand(10, 64, 64)
>>> data = Dataset3d.from_array(arr)
>>> data.shape
(10, 64, 64)

With calibration:

>>> data = Dataset3d.from_array(
... arr,
... sampling=[1, 0.1, 0.1],
... units=["frame", "nm", "nm"],
... )

Visualize:

>>> data.show() # all frames in grid
>>> data.show(index=0) # single frame
>>> data.show(ncols=2) # 2 columns
"""
array = ensure_valid_array(array, ndim=3)
return cls(
array=array,
Expand All @@ -90,7 +134,37 @@ def from_shape(
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
) -> Self:
"""Create a new Dataset3d filled with a constant value."""
"""Create a Dataset3d filled with a constant value.

Parameters
----------
shape : tuple[int, int, int]
Shape (n_frames, height, width)
name : str
Dataset name. Default: "constant 3D dataset"
fill_value : float
Value to fill array with. Default: 0.0
origin : NDArray | tuple | list | float | int | None
Origin for each dimension
sampling : NDArray | tuple | list | float | int | None
Sampling for each dimension
units : list[str] | tuple | list | None
Units for each dimension
signal_units : str
Units for array values

Returns
-------
Dataset3d

Examples
--------
>>> data = Dataset3d.from_shape((10, 64, 64))
>>> data.shape
(10, 64, 64)
>>> data.array.max()
0.0
"""
array = np.full(shape, fill_value, dtype=np.float32)
return cls.from_array(
array=array,
Expand All @@ -107,24 +181,72 @@ def to_dataset2d(self):

def show(
self,
index: int = 0,
index: int | None = None,
scalebar: ScalebarConfig | bool = True,
title: str | None = None,
suptitle: str | None = None,
ncols: int = 4,
returnfig: bool = False,
**kwargs,
):
"""
Display a 2D slice of the 3D dataset.
Display 2D slices of the 3D dataset.

Parameters
----------
index : int
Index of the 2D slice to display (along axis 0).
scalebar: ScalebarConfig or bool
If True, displays scalebar
title: str
Title of Dataset
**kwargs: dict
Keyword arguments for show_2d
index : int | None
Index of the 2D slice to display. If None, shows all slices in a grid.
scalebar : ScalebarConfig or bool
If True, displays scalebar.
title : str | None
Title for the plot. If None, uses "Frame 0", "Frame 1", etc.
suptitle : str | None
Figure super title displayed above all subplots.
ncols : int
Maximum columns when showing all slices. Default: 4.
returnfig : bool
If True, returns (fig, axes). Default: False.
**kwargs : dict
Keyword arguments for show_2d (cmap, cbar, norm, etc.).

Examples
--------
>>> data.show() # show all frames in grid
>>> data.show(index=0) # show single frame
>>> data.show(ncols=3) # 3 columns
>>> data.show(suptitle="Diffraction patterns") # with super title
>>> fig, axes = data.show(returnfig=True) # get figure for customization
"""
from quantem.core.visualization import show_2d

if index is not None:
# Handle negative index
actual_index = index if index >= 0 else self.shape[0] + index
default_title = title if title is not None else f"Frame {actual_index}"
result = self[index].show(scalebar=scalebar, title=default_title, **kwargs)
return result if returnfig else None

# Show all frames in a grid
n = self.shape[0]
nrows = (n + ncols - 1) // ncols
arrays = []
titles = []
for row in range(nrows):
row_arrays = []
row_titles = []
for col in range(ncols):
i = row * ncols + col
if i < n:
row_arrays.append(self.array[i])
row_titles.append(f"Frame {i}" if title is None else f"{title} {i}")
else:
row_arrays.append(np.zeros_like(self.array[0]))
row_titles.append("")
arrays.append(row_arrays)
titles.append(row_titles)

return self[index].show(scalebar=scalebar, title=title, **kwargs)
fig, axes = show_2d(arrays, scalebar=scalebar, title=titles, **kwargs)
if suptitle is not None:
fig.suptitle(suptitle, fontsize=14)
fig.subplots_adjust(top=0.92)
return (fig, axes) if returnfig else None
72 changes: 16 additions & 56 deletions src/quantem/core/datastructures/dataset4dstem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from quantem.core.datastructures.dataset2d import Dataset2d
from quantem.core.datastructures.dataset4d import Dataset4d
from quantem.core.utils.masks import create_annular_mask, create_circle_mask
from quantem.core.utils.validators import ensure_valid_array
from quantem.core.visualization import show_2d
from quantem.core.visualization.visualization_utils import ScalebarConfig
Expand All @@ -16,7 +17,7 @@ class Dataset4dstem(Dataset4d):
"""A 4D-STEM dataset class that inherits from Dataset4d.

This class represents a 4D scanning transmission electron microscopy (STEM) dataset,
where the data consists of a 4D array with dimensions (scan_y, scan_x, dp_y, dp_x).
where the data consists of a 4D array with dimensions (scan_row, scan_col, k_row, k_col).
The first two dimensions represent real space scanning positions, while the latter
two dimensions represent reciprocal space diffraction patterns.

Expand Down Expand Up @@ -90,6 +91,17 @@ def __init__(
self._virtual_images = {}
self._virtual_detectors = {} # Store detector information for regeneration

def __repr__(self) -> str:
return f"Dataset4dstem(shape={self.shape}, dtype={self.array.dtype})"

def __str__(self) -> str:
return (
f"Dataset4dstem '{self.name}'\n"
f" shape: {self.shape}\n"
f" scan sampling: {self.sampling[:2]} {self.units[:2]}\n"
f" k sampling: {self.sampling[2:]} {self.units[2:]}"
)

@classmethod
def from_file(cls, file_path: str, file_type: str) -> "Dataset4dstem":
"""
Expand Down Expand Up @@ -365,6 +377,7 @@ def get_virtual_image(
final_mask = mask
elif mode is not None and geometry is not None:
# Create mask from mode and geometry
dp_shape = self.array.shape[-2:]
if mode == "circle":
if (
len(geometry) != 2
Expand All @@ -373,14 +386,14 @@ def get_virtual_image(
):
raise ValueError("For circle mode, geometry must be ((cy, cx), r)")
center, radius = geometry
final_mask = self._create_circle_mask(center, radius)
final_mask = create_circle_mask(dp_shape, center, radius)
elif mode == "annular":
if len(geometry) != 2 or len(geometry[0]) != 2 or len(geometry[1]) != 2:
raise ValueError(
"For annular mode, geometry must be ((cy, cx), (r_inner, r_outer))"
)
center, radii = geometry
final_mask = self._create_annular_mask(center, radii)
final_mask = create_annular_mask(dp_shape, center, radii)
else:
raise ValueError(
f"Unknown mode '{mode}'. Supported modes are 'circle' and 'annular'"
Expand Down Expand Up @@ -466,59 +479,6 @@ def get_virtual_image(

return virtual_image_dataset

def _create_circle_mask(self, center: tuple[float, float], radius: float) -> np.ndarray:
"""
Create a circular mask for virtual image formation.

Parameters
----------
center : tuple[float, float]
Center coordinates (cy, cx) of the circle
radius : float
Radius of the circle

Returns
-------
np.ndarray
Boolean mask with True inside the circle
"""
cy, cx = center
dp_shape = self.array.shape[-2:] # Get diffraction pattern dimensions
y, x = np.ogrid[: dp_shape[0], : dp_shape[1]]

# Calculate distance from center
distance = np.sqrt((y - cy) ** 2 + (x - cx) ** 2)

return distance <= radius

def _create_annular_mask(
self, center: tuple[float, float], radii: tuple[float, float]
) -> np.ndarray:
"""
Create an annular (ring-shaped) mask for virtual image formation.

Parameters
----------
center : tuple[float, float]
Center coordinates (cy, cx) of the annulus
radii : tuple[float, float]
Inner and outer radii (r_inner, r_outer) of the annulus

Returns
-------
np.ndarray
Boolean mask with True inside the annular region
"""
cy, cx = center
r_inner, r_outer = radii
dp_shape = self.array.shape[-2:] # Get diffraction pattern dimensions
y, x = np.ogrid[: dp_shape[0], : dp_shape[1]]

# Calculate distance from center
distance = np.sqrt((y - cy) ** 2 + (x - cx) ** 2)

return (distance >= r_inner) & (distance <= r_outer)

def show_virtual_images(self, figsize: tuple[int, int] | None = None, **kwargs) -> tuple:
"""
Display all virtual images stored in the dataset using show_2d.
Expand Down
Loading