Commits
33 commits
b165443
adding Lattice class
cophus Sep 7, 2025
2ba99f1
first attempt at adding atoms
cophus Sep 7, 2025
66985f8
Polarization measurement and plotting
cophus Sep 7, 2025
781dbc8
Updating polarization plots
cophus Sep 7, 2025
ee6b54f
Adding csv output to Vector
cophus Sep 7, 2025
4b2ce91
Added iterative refinement for calculating lattice vectors (especiall…
darshan-mali Sep 10, 2025
e8cb38b
Removed unnecessary band aids
darshan-mali Sep 11, 2025
56f7ab9
Added fractional coordinate based polarization calculation.
darshan-mali Sep 16, 2025
afc9675
Fixed typo in comments of previous commit
darshan-mali Sep 16, 2025
6b089ca
Fixed rotation and cropping of polarization image.
darshan-mali Sep 16, 2025
e15dce4
small bug fix to make work for single atom structures
smribet Sep 22, 2025
f44e9fd
Added cifreader. Also fixed minor bug in polarization image.
darshan-mali Sep 22, 2025
685ccc1
Removed cifreader.py from tracking
darshan-mali Sep 23, 2025
54a7bed
Removed CIFReader from init
darshan-mali Sep 25, 2025
0bf173e
Fixed polarization plot. Used lattice vector calculation throughout.
darshan-mali Oct 3, 2025
fe845cf
Merge branch 'dev' into imaging
darshan-mali Oct 10, 2025
714de6d
Fixed error in merging
darshan-mali Oct 10, 2025
55e5ba9
Fixed polarization (removed accidental filtering).
darshan-mali Oct 10, 2025
91cd01a
Fixed dimension handling of origin, u, v
darshan-mali Oct 13, 2025
da47abd
Added order parameter calculation. Also added helper functions for pl…
darshan-mali Oct 21, 2025
472364e
Implemented GMM using torch. Removed skimage as a dependency.
darshan-mali Oct 24, 2025
3e6ec3d
Merge remote-tracking branch 'origin/dev' into imaging
darshan-mali Oct 24, 2025
08ceec3
Added pytests. Fixed pytest errors.
darshan-mali Oct 25, 2025
00fc201
Fixed block indent in visulaization.py that was causing pytest failure.
darshan-mali Oct 27, 2025
95cf28c
Updated order parameter plots. Also enabled plotting with custom colo…
darshan-mali Oct 28, 2025
2595b35
Fixed pytest bug in TorchGMM. This bug occured only in Github. The py…
darshan-mali Oct 28, 2025
26697c2
Made terminology consistent. Minor updates to defaults.
darshan-mali Oct 30, 2025
564eb63
Added all changes requested in first round of PR comments except visu…
darshan-mali Jan 7, 2026
d5459cb
Merge remote-tracking branch 'upstream/dev' into imaging
darshan-mali Jan 7, 2026
054047f
Merge remote-tracking branch 'upstream/dev' into imaging
darshan-mali Jan 12, 2026
322061e
Added pytests for Lattice AutoSerialize implementation. Reorganised a…
darshan-mali Jan 21, 2026
143784e
Merge remote-tracking branch 'upstream/dev' into imaging
darshan-mali Jan 21, 2026
1700c6e
Added visualize_order_parameter(). Added plot_polarization_legend(). …
darshan-mali Jan 21, 2026
114 changes: 103 additions & 11 deletions src/quantem/core/datastructures/dataset.py
@@ -1,7 +1,7 @@
import os
import numbers
import os
from pathlib import Path
from typing import Any, Literal, Optional, Self, Union, overload
from typing import Any, Literal, Optional, Self, overload

import numpy as np
from numpy.typing import DTypeLike, NDArray
@@ -52,7 +52,9 @@ def __init__(
super().__init__()
arr = ensure_valid_array(array)
if not isinstance(arr, np.ndarray):
raise TypeError("Dataset requires a NumPy array (CuPy is not supported on this branch).")
raise TypeError(
"Dataset requires a NumPy array (CuPy is not supported on this branch)."
)
self._array = arr
self.name = name
self.origin = origin
@@ -97,7 +99,9 @@ def from_array(
"""
validated_array = ensure_valid_array(array)
if not isinstance(validated_array, np.ndarray):
raise TypeError("Dataset requires a NumPy array (CuPy is not supported on this branch).")
raise TypeError(
"Dataset requires a NumPy array (CuPy is not supported on this branch)."
)
_ndim = validated_array.ndim

# Set defaults if None
@@ -126,7 +130,9 @@ def array(self) -> NDArray:
def array(self, value: NDArray) -> None:
arr = ensure_valid_array(value, ndim=self.ndim) # want to allow changing dtype
if not isinstance(arr, np.ndarray):
raise TypeError("Dataset requires a NumPy array (CuPy is not supported on this branch).")
raise TypeError(
"Dataset requires a NumPy array (CuPy is not supported on this branch)."
)
self._array = arr
# self._array = ensure_valid_array(value, dtype=self.dtype, ndim=self.ndim)

@@ -593,6 +599,7 @@ def bin(
reshape_dims.append(effective_lengths[a1])
running_axis += 1

# --- Perform block reduction ---
array_view = self.array[tuple(slices)].reshape(tuple(reshape_dims))
array_binned = np.sum(array_view, axis=tuple(reduce_axes))
if reducer_norm == "mean":
@@ -628,11 +635,11 @@ def bin(

def fourier_resample(
self,
out_shape: Optional[tuple[int, ...]] = None,
factors: Optional[Union[float, tuple[float, ...]]] = None,
axes: Optional[tuple[int, ...]] = None,
out_shape: tuple[int, ...] | None = None,
factors: float | tuple[float, ...] | None = None,
axes: tuple[int, ...] | None = None,
modify_in_place: bool = False,
) -> Optional["Dataset"]:
) -> Self | None:
"""
Fourier resample the dataset by centered cropping (downsample) or zero padding (upsample).
The operation is performed in the Fourier domain using fftshift alignment and default FFT
@@ -676,7 +683,9 @@ def fourier_resample(
factors = tuple(float(f) for f in factors)
if len(factors) != len(axes):
raise ValueError("factors length must match number of axes.")
out_shape = tuple(max(1, int(round(self.shape[a1] * f))) for a1, f in zip(axes, factors))
out_shape = tuple(
max(1, int(round(self.shape[a1] * f))) for a1, f in zip(axes, factors)
)
else:
if len(out_shape) != len(axes):
raise ValueError("out_shape length must match number of axes.")
@@ -768,6 +777,87 @@ def _shift_center_index(n: int) -> int:
ds.origin = new_origin
return ds

def transpose(
self,
order: tuple[int, ...] | None = None,
modify_in_place: bool = False,
) -> Self | None:
"""
Transpose (permute) axes of the dataset and reorder metadata accordingly.

Parameters
----------
order : tuple[int, ...], optional
A permutation of range(self.ndim). If None, axes are reversed (NumPy's default).
modify_in_place : bool, default False
If True, modify this dataset in place. Otherwise return a new Dataset.

Returns
-------
Dataset or None
Transposed dataset if modify_in_place is False, otherwise None.
"""
if order is None:
order = tuple(range(self.ndim - 1, -1, -1))

if len(order) != self.ndim or set(order) != set(range(self.ndim)):
raise ValueError(f"'order' must be a permutation of 0..{self.ndim - 1}; got {order!r}")

array_t = self.array.transpose(order)

# Reorder metadata to match new axis order
new_origin = self.origin[list(order)].copy()
new_sampling = self.sampling[list(order)].copy()
new_units = [self.units[ax] for ax in order]

if modify_in_place:
# Use private attrs to avoid dtype/ndim enforcement in the setter
self._array = array_t
self._origin = new_origin
self._sampling = new_sampling
self._units = new_units
return None

# Create a new Dataset without extra array copies
return type(self).from_array(
array=array_t,
name=self.name, # keep name unchanged for now
origin=new_origin,
sampling=new_sampling,
units=new_units,
signal_units=self.signal_units,
)

def astype(
self,
dtype: DTypeLike,
copy: bool = True,
modify_in_place: bool = False,
) -> Self | None:
"""
Cast the array to a new dtype. Metadata is unchanged.

Parameters
----------
dtype : DTypeLike
Target dtype (e.g., np.float32, "complex64", etc.).
copy : bool, default True
If False and no cast is needed, a view may be returned by the backend.
modify_in_place : bool, default False
If True, modify this dataset in place. Otherwise return a new Dataset.

Returns
-------
Dataset or None
Dtype-cast dataset if modify_in_place is False, otherwise None.
"""
array_cast = self.array.astype(dtype, copy=copy)

if modify_in_place:
# Bypass the array setter so we can actually change dtype
self._array = array_cast
return None

# Otherwise return a new Dataset carrying over the existing metadata
return type(self).from_array(
array=array_cast,
name=self.name,
origin=self.origin,
sampling=self.sampling,
units=self.units,
signal_units=self.signal_units,
)

def __getitem__(self, index) -> Self:
"""
General indexing method for Dataset objects.
@@ -806,7 +896,9 @@ def __getitem__(self, index) -> Self:
kept_axes = [i for i, idx in enumerate(index) if not isinstance(idx, (int, np.integer))]

# Slice/reduce metadata accordingly
new_origin = np.asarray(self.origin)[kept_axes] if np.ndim(self.origin) > 0 else self.origin
new_origin = (
np.asarray(self.origin)[kept_axes] if np.ndim(self.origin) > 0 else self.origin
)
new_sampling = (
np.asarray(self.sampling)[kept_axes] if np.ndim(self.sampling) > 0 else self.sampling
)
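
For orientation, a minimal usage sketch of the new `Dataset.transpose` and `Dataset.astype` methods added in this file. The array contents, sampling values, and axis units below are illustrative, not part of the patch, and `from_array` is assumed to accept the keyword arguments shown in the diff.

```python
import numpy as np

from quantem.core.datastructures.dataset import Dataset

# Hypothetical 3-axis dataset: 4 frames of 8 x 6 pixels
ds = Dataset.from_array(
    np.zeros((4, 8, 6), dtype=np.float64),
    name="stack",
    origin=(0.0, 0.0, 0.0),
    sampling=(1.0, 0.1, 0.1),
    units=["frame", "nm", "nm"],
)

# Permute axes; origin, sampling, and units are reordered to match
ds_t = ds.transpose(order=(1, 2, 0))  # ds_t.shape -> (8, 6, 4)

# Cast the array dtype; metadata is left unchanged
ds32 = ds.astype(np.float32)

# In-place variants modify self and return None
ds.astype(np.float32, modify_in_place=True)
```
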
183 changes: 162 additions & 21 deletions src/quantem/core/datastructures/vector.py
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import (
Any,
List,
@@ -161,7 +162,7 @@
@classmethod
def from_shape(
cls,
shape: Tuple[int, ...],
shape: Union[int, np.integer, Tuple[int, ...], Sequence[int]],
num_fields: Optional[int] = None,
fields: Optional[List[str]] = None,
units: Optional[List[str]] = None,
@@ -172,25 +173,42 @@

Parameters
----------
shape : Tuple[int, ...]
The shape of the vector (dimensions)
num_fields : Optional[int]
Number of fields in the vector
name : Optional[str]
Name of the vector
fields : Optional[List[str]]
List of field names
units : Optional[List[str]]
List of units for each field
shape
The fixed indexed dimensions of the ragged vector.
Accepts:
- int / np.integer -> treated as (int,)
- tuple[int, ...] -> used as-is
- sequence[int] -> converted to tuple[int, ...]
- () -> 0-D (no indexed dims)
num_fields
Number of fields in the vector (ignored if `fields` is provided).
fields
List of field names (mutually exclusive with `num_fields`).
units
Unit strings per field. If None, defaults are used.
name
Optional name.

Returns
-------
Vector
A new Vector instance
A new Vector instance.
"""
validated_shape = validate_shape(shape)
# --- Normalize 'shape' to a tuple[int, ...] to satisfy validate_shape ---
if isinstance(shape, (int, np.integer)):
shape_tuple: Tuple[int, ...] = (int(shape),)
elif isinstance(shape, tuple):
shape_tuple = tuple(int(s) for s in shape)
elif isinstance(shape, Sequence):
shape_tuple = tuple(int(s) for s in shape)
else:
raise TypeError(f"Unsupported type for shape: {type(shape)}")

# validate_shape expects a tuple and applies the project-specific checks
validated_shape = validate_shape(shape_tuple)
ndim = len(validated_shape)

# --- Fields / num_fields handling (unchanged) ---
if fields is not None:
validated_fields = validate_fields(fields)
validated_num_fields = len(validated_fields)
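
For reference, a short sketch of the inputs the relaxed `from_shape` signature is intended to accept; the field names below are illustrative.

```python
import numpy as np

from quantem.core.datastructures.vector import Vector

# These calls should all produce equivalent 1-D ragged vectors with 3 fields
v_from_int = Vector.from_shape(5, fields=["x", "y", "intensity"])
v_from_tuple = Vector.from_shape((5,), fields=["x", "y", "intensity"])
v_from_list = Vector.from_shape([5], fields=["x", "y", "intensity"])

# NumPy integer scalars are handled by the same branch as Python ints
v_from_np = Vector.from_shape(np.int64(5), fields=["x", "y", "intensity"])
```
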
@@ -446,16 +464,18 @@ def __getitem__(
np.asarray(i) if isinstance(i, (list, np.ndarray)) else i for i in normalized
)

# Check if we should return a numpy array (all indices are integers)
return_np = all(isinstance(i, (int, np.integer)) for i in idx_converted[: len(self.shape)])
# Check if we should return a single-cell view (all indices are integers)
return_cell = all(
isinstance(i, (int, np.integer)) for i in idx_converted[: len(self.shape)]
)
if len(idx_converted) < len(self.shape):
return_np = False
return_cell = False

if return_np:
view = self._data
for i in idx_converted:
view = view[i]
return cast(NDArray[Any], view)
if return_cell:
# Return a CellView so atoms[0]['x'] works;
# still behaves like ndarray via __array__ when used numerically.
indices_tuple = tuple(int(i) for i in idx_converted[: len(self.shape)])
return _CellView(self, indices_tuple)

# Handle fancy indexing and slicing
def get_indices(dim_idx: Any, dim_size: int) -> np.ndarray:
@@ -1024,3 +1044,124 @@
def __array__(self) -> np.ndarray:
"""Convert to numpy array when needed."""
return self.flatten()


class _CellView:
"""
View over a single Vector cell (fixed indices over the indexed dims).
Supports item access by field name, e.g., v[0]['x'] -> 1D array for that cell.
Behaves like a numpy array via __array__ for backward compatibility.
"""

def __init__(self, vector: "Vector", indices: Tuple[int, ...]) -> None:
self.vector = vector
self.indices = indices # tuple of ints, one per indexed dimension

@property
def array(self) -> NDArray:
ref = self.vector._data
for i in self.indices:
ref = ref[i]
return ref # shape: (rows, num_fields)

def __array__(self) -> np.ndarray:
# Allows numpy to transparently consume this as an ndarray
return self.array

def __getitem__(self, field_name: str) -> NDArray:
if not isinstance(field_name, str):
raise TypeError("Use a field name string, e.g. cell['x']")
if field_name not in self.vector._fields:
raise KeyError(f"Field '{field_name}' not found.")
j = self.vector._fields.index(field_name)
return self.array[:, j]

def save_csv(
self,
filename: str,
*,
# Jupyter-friendly defaults:
jupyter_friendly: bool = True,
include_units: bool = True,
delimiter: str = ",",
float_fmt: str = "%.6g",
append_csv_ext: bool = True,
create_dirs: bool = True,
# Legacy/optional extras (ignored when jupyter_friendly=True):
add_comment_header: bool = False, # writes a leading "# ..." line
add_units_row: bool = False, # writes a separate units row
) -> str:
"""
Save this cell's rows to a CSV file.

If jupyter_friendly=True (default), writes a single header row suitable
for JupyterLab's CSV viewer. Units are merged into the column names
as 'field (unit)'. No extra header lines.

If jupyter_friendly=False, you can enable:
- add_comment_header=True -> a commented first line
- add_units_row=True -> a second line with units only
"""
import csv
import os

import numpy as np

path = os.fspath(filename)
if append_csv_ext and not path.lower().endswith(".csv"):
path += ".csv"

parent = os.path.dirname(path)
if parent and create_dirs:
os.makedirs(parent, exist_ok=True)

arr = self.array
fields = list(self.vector.fields)
units = list(self.vector.units)

# Build header row
if jupyter_friendly:
if include_units:
header = [f"{n} ({u})" for n, u in zip(fields, units)]
else:
header = fields
write_comment = False
write_units_row = False
else:
header = fields
write_comment = bool(add_comment_header)
write_units_row = bool(add_units_row and include_units)

# Prepare a small formatter to apply float_fmt to numeric values
def fmt_row(row: np.ndarray) -> list[str]:
out = []
for v in row:
try:
out.append(float_fmt % float(v))
except Exception:
out.append(str(v))
return out

with open(path, "w", newline="") as f:
w = csv.writer(f, delimiter=delimiter, quoting=csv.QUOTE_MINIMAL)

# Optional legacy comment header
if write_comment:
vec_name = getattr(self.vector, "name", "Vector")
idx_str = ", ".join(str(i) for i in self.indices)
nrows = 0 if (not isinstance(arr, np.ndarray)) else int(arr.shape[0])
f.write(f"# {vec_name} — cell indices ({idx_str}), rows={nrows}\n")

# Header row (always)
w.writerow(header)

# Optional legacy separate units row
if write_units_row:
w.writerow(units)

# Data rows
if isinstance(arr, np.ndarray) and arr.size:
for r in range(arr.shape[0]):
w.writerow(fmt_row(arr[r, :]))

return path
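
A hedged sketch of how the new `_CellView` and `save_csv` are meant to be used. The cell-population step via item assignment below is an assumption made for illustration; only field access, the `__array__` conversion, and the CSV call are defined in this diff.

```python
import numpy as np

from quantem.core.datastructures.vector import Vector

atoms = Vector.from_shape((2,), fields=["x", "y", "intensity"], name="atoms")

# Assumed population step: assign a (rows, num_fields) array to one cell
atoms[0] = np.array([[1.0, 2.0, 100.0],
                     [3.0, 4.0, 80.0]])

cell = atoms[0]            # _CellView over the first cell
x_column = cell["x"]       # 1-D array of the 'x' field for this cell
full = np.asarray(cell)    # (rows, num_fields) array via __array__

# Jupyter-friendly CSV: one header row, units merged as "field (unit)"
path = cell.save_csv("atoms_cell0", float_fmt="%.4g")  # -> "atoms_cell0.csv"
```
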