diff --git a/ndviewer_light/core.py b/ndviewer_light/core.py index 923fb40..2c2e258 100644 --- a/ndviewer_light/core.py +++ b/ndviewer_light/core.py @@ -1447,6 +1447,13 @@ def __init__(self, dataset_path: str = ""): self._max_fov_per_time: Dict[int, int] = {} # timepoint -> max FOV index seen self._image_height: int = 0 self._image_width: int = 0 + self._pixel_size_um: Optional[float] = None # XY pixel size in micrometers + self._dz_um: Optional[float] = None # Z step size in micrometers + self._is_ome_format: bool = False # True if dataset is OME-TIFF format + self._ome_file_index: Dict[int, str] = {} # flat_fov_idx -> OME filepath + self._pull_mode: bool = ( + False # True if using pre-built 6D array (fast navigation) + ) self._plane_cache = MemoryBoundedLRUCache(PLANE_CACHE_MAX_MEMORY_BYTES) self._updating_sliders: bool = False # Prevent recursive updates self._acquisition_active: bool = False # True during live acquisition @@ -1566,6 +1573,10 @@ def _setup_ui(self): fov_layout.addWidget(self._fov_slider) slider_layout.addWidget(self._fov_slider_container) + # Store slider container reference (will be moved into NDV's layout later) + self._slider_container = slider_container + + # Initially add to our layout (will be repositioned for pull mode) layout.addWidget(slider_container) self.setLayout(layout) @@ -1578,28 +1589,34 @@ def _on_time_slider_changed(self, value: int): self._current_time_idx = value self._time_label.setText(f"T: {value}") - # Update FOV slider max for this timepoint - self._updating_sliders = True - try: - available_fov_max = self._max_fov_per_time.get(value, 0) - self._fov_slider.setMaximum(available_fov_max) - - # Clamp current FOV if it exceeds available range - if self._current_fov_idx > available_fov_max: - self._current_fov_idx = available_fov_max - self._fov_slider.setValue(available_fov_max) - - # Update FOV label to reflect current FOV after any clamping - if self._fov_labels and self._current_fov_idx < len(self._fov_labels): - self._fov_label.setText( - f"FOV: {self._fov_labels[self._current_fov_idx]}" - ) - else: - self._fov_label.setText(f"FOV: {self._current_fov_idx}") - finally: - self._updating_sliders = False + if self._pull_mode: + # Pull mode: navigate NDV directly (fast, no array rebuild) + self._navigate_ndv("time", value) + else: + # Push mode: update FOV slider range and rebuild array + self._updating_sliders = True + try: + available_fov_max = self._max_fov_per_time.get(value, 0) + self._fov_slider.setMaximum(available_fov_max) + + # Clamp current FOV if it exceeds available range + if self._current_fov_idx > available_fov_max: + self._current_fov_idx = available_fov_max + self._fov_slider.setValue(available_fov_max) + + # Update FOV label to reflect current FOV after any clamping + if self._fov_labels and self._current_fov_idx < len( + self._fov_labels + ): + self._fov_label.setText( + f"FOV: {self._fov_labels[self._current_fov_idx]}" + ) + else: + self._fov_label.setText(f"FOV: {self._current_fov_idx}") + finally: + self._updating_sliders = False - self._load_current_position() + self._load_current_position() def _on_fov_slider_changed(self, value: int): """Handle FOV slider change.""" @@ -1612,7 +1629,44 @@ def _on_fov_slider_changed(self, value: int): self._fov_label.setText(f"FOV: {self._fov_labels[value]}") else: self._fov_label.setText(f"FOV: {value}") - self._load_current_position() + + if self._pull_mode: + # Pull mode: navigate NDV directly (fast, no array rebuild) + self._navigate_ndv("fov", value) + else: + # Push mode: rebuild array for new FOV + self._load_current_position() + + def _navigate_ndv(self, dim: str, value: int): + """Navigate NDV viewer to a specific dimension index. + + Used in pull mode for fast navigation without array rebuilds. + + Args: + dim: Dimension name ("time", "fov", "z", "channel") + value: Index value to navigate to + """ + if not self.ndv_viewer: + return + + try: + # NDV ArrayViewer uses display_model.current_index + if hasattr(self.ndv_viewer, "display_model"): + dm = self.ndv_viewer.display_model + if hasattr(dm, "current_index") and dim in dm.current_index: + dm.current_index[dim] = value + return + + # Fallback for older NDV versions + if hasattr(self.ndv_viewer, "dims"): + dims = self.ndv_viewer.dims + if hasattr(dims, "current_step"): + current = dict(dims.current_step) + if dim in current: + current[dim] = value + dims.current_step = current + except Exception as e: + logger.debug("Failed to navigate NDV: %s", e) def _load_current_position(self): """Load data for current position, dispatching to appropriate loader. @@ -1724,6 +1778,7 @@ def start_acquisition( self._file_index.clear() self._plane_cache.clear() self._max_fov_per_time.clear() + self._pull_mode = False # Push mode: rebuild arrays on navigation # Store configuration self._channel_names = list(channels) @@ -1797,6 +1852,12 @@ def _rebuild_viewer_for_acquisition(self): xarr.attrs["luts"] = self._luts xarr.attrs["channel_names"] = self._channel_names + # Include pixel size metadata if available (for scale display and 3D rendering) + if self._pixel_size_um is not None: + xarr.attrs["pixel_size_um"] = self._pixel_size_um + if self._dz_um is not None: + xarr.attrs["dz_um"] = self._dz_um + self._xarray_data = xarr self._set_ndv_data(xarr) @@ -1970,10 +2031,13 @@ def _load_single_plane( ) -> np.ndarray: """Load a single image plane from cache or disk. + Handles both single-TIFF (one plane per file) and OME-TIFF (all planes + in one file per FOV) formats. + Args: t: Timepoint index fov_idx: FOV index - z: Z-level value + z: Z-level value (index into _z_levels for single-TIFF, direct index for OME) channel: Channel name Returns: @@ -1999,10 +2063,19 @@ def _load_single_plane( return np.zeros((self._image_height, self._image_width), dtype=np.uint16) try: - with tf.TiffFile(filepath) as tif: - plane = tif.pages[0].asarray() - self._plane_cache.put(cache_key, plane) - return plane + # Check if this is OME-TIFF format (multi-plane file) + is_ome = getattr(self, "_is_ome_format", False) + + if is_ome: + # OME-TIFF: read specific plane from multi-dimensional file + plane = self._load_ome_plane(filepath, t, z, channel) + else: + # Single-TIFF: one plane per file + with tf.TiffFile(filepath) as tif: + plane = tif.pages[0].asarray() + + self._plane_cache.put(cache_key, plane) + return plane except FileNotFoundError: logger.warning("Image file not found (may have been deleted): %s", filepath) except PermissionError as e: @@ -2022,6 +2095,51 @@ def _load_single_plane( # Return zeros on error - user sees black image return np.zeros((self._image_height, self._image_width), dtype=np.uint16) + def _load_ome_plane( + self, filepath: str, t: int, z: int, channel: str + ) -> np.ndarray: + """Load a single plane from an OME-TIFF file. + + Args: + filepath: Path to OME-TIFF file + t: Timepoint index + z: Z-level index + channel: Channel name + + Returns: + Image plane as numpy array + """ + # Get channel index from name + try: + c_idx = self._channel_names.index(channel) + except ValueError: + logger.warning("Channel '%s' not found in channel list", channel) + return np.zeros((self._image_height, self._image_width), dtype=np.uint16) + + with tf.TiffFile(filepath) as tif: + series = tif.series[0] + axes = series.axes + shape = series.shape + + # Build index based on axes order (commonly TZCYX or TCYX) + idx = [] + for ax in axes: + if ax == "T": + idx.append(t) + elif ax == "Z": + idx.append(z) + elif ax == "C": + idx.append(c_idx) + elif ax in ("Y", "X"): + idx.append(slice(None)) + else: + # Unknown axis, take first element + idx.append(0) + + # Read the specific plane + data = series.asarray()[tuple(idx)] + return data + def _load_current_fov(self): """Load and display data for the current FOV position. @@ -2087,6 +2205,12 @@ def _update_ndv_data(self, data): xarr.attrs["luts"] = self._luts xarr.attrs["channel_names"] = self._channel_names + # Include pixel size metadata if available (for scale display and 3D rendering) + if self._pixel_size_um is not None: + xarr.attrs["pixel_size_um"] = self._pixel_size_um + if self._dz_um is not None: + xarr.attrs["dz_um"] = self._dz_um + self._xarray_data = xarr # Try in-place update to avoid flickering @@ -3011,30 +3135,89 @@ def _data_structure_changed( return True def load_dataset(self, path: str): - """Load dataset and display in NDV.""" + """Load dataset with pre-built 6D array for fast navigation. + + Uses a hybrid approach: + - Builds 6D array once (like original implementation) for fast slicing + - Custom T/FOV sliders navigate via NDV's API (no array rebuilds) + - NDV's built-in time/fov sliders are hidden to avoid duplicates + + This provides the unified slider UI while maintaining performance. + """ # Close any previously open file handles before loading new dataset self._close_open_handles() - # Reset state when loading a new dataset to ensure clean slate. - # This prevents stale channel controls from persisting when switching - # between datasets with different channel configurations. + # Stop any running play animations and pending loads + self._stop_play_animation(self._time_play_timer, self._time_play_btn) + self._stop_play_animation(self._fov_play_timer, self._fov_play_btn) + if self._load_debounce_timer and self._load_debounce_timer.isActive(): + self._load_debounce_timer.stop() + self._load_pending = False + + # Reset state self._last_sig = None self._xarray_data = None - + self._pull_mode = True # Enable fast navigation mode self.dataset_path = path self.status_label.setText(f"Loading: {Path(path).name}...") QApplication.processEvents() try: + # Build 6D array using the optimized lazy loading path data = self._create_lazy_array(Path(path)) if data is not None: - self._xarray_data = data # Store for profiling + self._xarray_data = data self._open_handles = data.attrs.get("_open_tifs", []) - # Always do full rebuild when explicitly loading a new dataset. - # This ensures channels/LUTs are properly reset. + + # Extract metadata for slider configuration + self._channel_names = data.attrs.get("channel_names", []) + self._luts = data.attrs.get("luts", {}) + self._pixel_size_um = data.attrs.get("pixel_size_um") + self._dz_um = data.attrs.get("dz_um") + + # Get dimension sizes for slider ranges + n_time = data.sizes.get("time", 1) + n_fov = data.sizes.get("fov", 1) + + # Build FOV labels from discovered FOVs + fmt = detect_format(Path(path)) + fovs = self._discover_fovs(Path(path), fmt) + self._fov_labels = [f"{f['region']}:{f['fov']}" for f in fovs] + + # Configure custom sliders + self._max_time_idx = n_time - 1 + self._updating_sliders = True + try: + # Time slider + self._time_slider.setMaximum(self._max_time_idx) + self._time_slider.setValue(0) + self._time_label.setText("T: 0") + self._time_container.setVisible(self._max_time_idx > 0) + + # FOV slider + max_fov = n_fov - 1 + self._fov_slider.setMaximum(max_fov) + self._fov_slider.setValue(0) + if self._fov_labels: + self._fov_label.setText(f"FOV: {self._fov_labels[0]}") + else: + self._fov_label.setText("FOV: 0") + finally: + self._updating_sliders = False + + # Reset navigation state + self._current_time_idx = 0 + self._current_fov_idx = 0 + + # Display the data (builds NDV viewer with 6D array) self._set_ndv_data(data) - # Update status (keep it stable during live acquisition; avoid printing dims like time=...) + # Hide NDV's time/fov sliders since we use custom ones + self._hide_ndv_dimension_sliders(["time", "fov"]) + + # Move custom sliders into NDV's layout for better visual grouping + self._insert_sliders_into_ndv_layout() + self.status_label.setText(f"Loaded: {Path(path).name}") else: self.status_label.setText("Failed to load dataset") @@ -3044,6 +3227,401 @@ def load_dataset(self, path: str): traceback.print_exc() + def _hide_ndv_dimension_sliders(self, dims_to_hide: List[str]): + """Hide NDV's built-in sliders for specific dimensions. + + Used in pull mode to avoid duplicate sliders - we use custom T/FOV + sliders while NDV handles z/channel. + + Args: + dims_to_hide: List of dimension names to hide (e.g., ["time", "fov"]) + """ + if not self.ndv_viewer: + return + + try: + # Use NDV's official hide_sliders API + # show_remainder=False prevents showing sliders for visible axes (x, y) + if hasattr(self.ndv_viewer, "_view") and hasattr( + self.ndv_viewer._view, "hide_sliders" + ): + self.ndv_viewer._view.hide_sliders(dims_to_hide, show_remainder=False) + except Exception as e: + logger.debug("Could not hide NDV sliders: %s", e) + + def _insert_sliders_into_ndv_layout(self): + """Move custom T/FOV sliders into NDV's internal layout. + + This places our sliders right after NDV's dimension sliders (z, channel) + for a cohesive visual grouping, instead of at the bottom of the window. + + NDV's layout structure: + - _view.frontend_widget() -> QWidget with QVBoxLayout + - [0] QSplitter + - widget(0) -> QWidget with QVBoxLayout + - [0] QWidget (toolbar) + - [1] CanvasBackendDesktop + - [2] _QDimsSliders <- insert after this + - [3] _UpCollapsible (LUT controls) + - [4] QWidget (footer) + """ + if not self.ndv_viewer or not hasattr(self, "_slider_container"): + return + + try: + # Navigate NDV's internal structure + if not hasattr(self.ndv_viewer, "_view"): + return + + frontend = self.ndv_viewer._view.frontend_widget() + if not frontend or not frontend.layout(): + return + + # Get the QSplitter from frontend's layout + splitter_item = frontend.layout().itemAt(0) + if not splitter_item or not splitter_item.widget(): + return + + splitter = splitter_item.widget() + if splitter.count() == 0: + return + + # Get the main content widget (first child of splitter) + main_widget = splitter.widget(0) + if not main_widget or not main_widget.layout(): + return + + main_layout = main_widget.layout() + + # Find the index of _QDimsSliders in the layout + dims_slider_idx = -1 + for i in range(main_layout.count()): + item = main_layout.itemAt(i) + if item and item.widget(): + widget_class = item.widget().__class__.__name__ + if "DimSliders" in widget_class or "Dims" in widget_class: + dims_slider_idx = i + break + + # Remove slider container from our main layout + our_layout = self.layout() + if our_layout: + our_layout.removeWidget(self._slider_container) + + # Insert into NDV's layout right after the dims sliders + if dims_slider_idx >= 0: + main_layout.insertWidget(dims_slider_idx + 1, self._slider_container) + else: + # Fallback: add after canvas (index 2) + insert_pos = min(2, main_layout.count()) + main_layout.insertWidget(insert_pos, self._slider_container) + + logger.debug( + "Inserted custom sliders into NDV layout at position %d", + dims_slider_idx + 1, + ) + except Exception as e: + logger.debug("Could not insert sliders into NDV layout: %s", e) + + def _on_ndv_ndims_requested(self, ndims: int): + """Handle NDV's nDimsRequested signal (fired when 2D/3D toggle is clicked). + + When NDV switches between 2D and 3D modes, it recreates its dimension sliders. + We need to re-hide the time/fov sliders to prevent duplicates with our custom ones. + + Args: + ndims: Number of dimensions requested (2 or 3) + """ + if self._pull_mode: + # Use QTimer to defer hiding until after NDV finishes recreating sliders + from PyQt5.QtCore import QTimer + + QTimer.singleShot( + 50, lambda: self._hide_ndv_dimension_sliders(["time", "fov"]) + ) + + def _scan_dataset_to_internal_state(self, base_path: Path) -> bool: + """Scan filesystem and populate internal state for push-mode architecture. + + This method discovers all files in the dataset and sets up: + - _file_index: maps (t, fov_idx, z, channel) to filepath + - _fov_labels: list of FOV labels like ["A1:0", "A1:1", ...] + - _channel_names: sorted list of channel names + - _z_levels: sorted list of z-level indices + - _image_height, _image_width: image dimensions + - _luts: channel colormaps based on wavelengths + - _max_time_idx: highest timepoint index + - _max_fov_per_time: maps timepoint to max FOV index for that timepoint + + Args: + base_path: Path to the dataset directory + + Returns: + True if successful, False otherwise + """ + if not LAZY_LOADING_AVAILABLE: + return False + + fmt = detect_format(base_path) + fovs = self._discover_fovs(base_path, fmt) + + if not fovs: + logger.warning("No FOVs found in dataset") + return False + + # Clear previous state + with self._file_index_lock: + self._file_index.clear() + self._plane_cache.clear() + self._max_fov_per_time.clear() + + # Build FOV label list and reverse lookup + self._fov_labels = [f"{f['region']}:{f['fov']}" for f in fovs] + fov_to_flat = {(f["region"], f["fov"]): i for i, f in enumerate(fovs)} + + # Scan files based on format + if fmt == "ome_tiff": + return self._scan_ome_tiff_to_state(base_path, fov_to_flat) + else: + return self._scan_single_tiff_to_state(base_path, fov_to_flat) + + def _scan_single_tiff_to_state( + self, base_path: Path, fov_to_flat: Dict[tuple, int] + ) -> bool: + """Scan single-TIFF format dataset into internal state. + + Args: + base_path: Path to dataset directory + fov_to_flat: Maps (region, fov) to flat FOV index + + Returns: + True if successful, False otherwise + """ + channels_seen: set = set() + z_levels_seen: set = set() + times_seen: set = set() + height, width = 0, 0 + + # Scan all timepoint directories + for tp_dir in sorted(base_path.iterdir()): + if not (tp_dir.is_dir() and tp_dir.name.isdigit()): + continue + t = int(tp_dir.name) + has_files = False + + for f in tp_dir.iterdir(): + if f.suffix.lower() not in TIFF_EXTENSIONS: + continue + m = FPATTERN.search(f.name) + if not m: + continue + + region = m.group("r") + fov = int(m.group("f")) + z = int(m.group("z")) + channel = m.group("c") + + # Convert (region, fov) to flat index + flat_fov = fov_to_flat.get((region, fov)) + if flat_fov is None: + continue + + # Populate file index + with self._file_index_lock: + self._file_index[(t, flat_fov, z, channel)] = str(f) + + channels_seen.add(channel) + z_levels_seen.add(z) + has_files = True + + # Get image dimensions from first file + if height == 0: + try: + with tf.TiffFile(str(f)) as tif: + height, width = tif.pages[0].shape[-2:] + except Exception as e: + logger.debug("Failed to read image dimensions: %s", e) + + if has_files: + times_seen.add(t) + + if not self._file_index: + return False + + # Store discovered metadata + self._channel_names = sorted(channels_seen) + self._z_levels = sorted(z_levels_seen) + self._image_height = height + self._image_width = width + self._max_time_idx = max(times_seen) if times_seen else 0 + + # Set up LUTs based on channel wavelengths + self._luts = { + i: wavelength_to_colormap(extract_wavelength(c)) + for i, c in enumerate(self._channel_names) + } + + # Build max FOV per timepoint mapping + for t in times_seen: + fovs_for_t = set() + with self._file_index_lock: + for ft, fov_idx, z, ch in self._file_index.keys(): + if ft == t: + fovs_for_t.add(fov_idx) + if fovs_for_t: + self._max_fov_per_time[t] = max(fovs_for_t) + + # Read acquisition parameters for pixel size (stored for later use) + pixel_size_um, dz_um = read_acquisition_parameters(base_path) + self._pixel_size_um = pixel_size_um + self._dz_um = dz_um + + # Mark as non-OME format + self._is_ome_format = False + self._ome_file_index.clear() + + return True + + def _scan_ome_tiff_to_state( + self, base_path: Path, fov_to_flat: Dict[tuple, int] + ) -> bool: + """Scan OME-TIFF format dataset into internal state. + + Args: + base_path: Path to dataset directory + fov_to_flat: Maps (region, fov) to flat FOV index + + Returns: + True if successful, False otherwise + """ + ome_dir = base_path / "ome_tiff" + if not ome_dir.exists(): + ome_dir = next( + (d for d in base_path.iterdir() if d.is_dir() and d.name.isdigit()), + base_path, + ) + + # Find all OME files and map to FOVs + ome_files: Dict[int, str] = {} # flat_fov_idx -> filepath + for f in ome_dir.glob("*.ome.tif*"): + m = FPATTERN_OME.search(f.name) + if m: + region, fov = m.group("r"), int(m.group("f")) + flat_fov = fov_to_flat.get((region, fov)) + if flat_fov is not None: + ome_files[flat_fov] = str(f) + + if not ome_files: + return False + + # Read metadata from first OME file + first_file = next(iter(ome_files.values())) + try: + with tf.TiffFile(first_file) as tif: + series = tif.series[0] + axes = series.axes + shape = series.shape + shape_dict = dict(zip(axes, shape)) + + n_t = shape_dict.get("T", 1) + n_c = shape_dict.get("C", 1) + n_z = shape_dict.get("Z", 1) + height = shape_dict.get("Y", shape[-2]) + width = shape_dict.get("X", shape[-1]) + + # Extract channel names from OME metadata + channel_names = [] + pixel_size_x, pixel_size_y, pixel_size_z = None, None, None + if tif.ome_metadata: + try: + import xml.etree.ElementTree as ET + + root = ET.fromstring(tif.ome_metadata) + ns = { + "ome": "http://www.openmicroscopy.org/Schemas/OME/2016-06" + } + for ch in root.findall(".//ome:Channel", ns): + name = ch.get("Name") or ch.get("ID", "") + if name: + channel_names.append(name) + + pixel_size_x, pixel_size_y, pixel_size_z = ( + extract_ome_physical_sizes(tif.ome_metadata) + ) + except Exception as e: + logger.debug("Failed to parse OME metadata: %s", e) + + # Fallback channel names if not found in metadata + if not channel_names: + channel_names = [f"Ch{i}" for i in range(n_c)] + + except Exception as e: + logger.error("Failed to read OME file metadata: %s", e) + return False + + # Store metadata + self._channel_names = channel_names + self._z_levels = list(range(n_z)) + self._image_height = height + self._image_width = width + self._max_time_idx = n_t - 1 + self._pixel_size_um = pixel_size_x + self._dz_um = pixel_size_z + + # Set up LUTs + self._luts = { + i: wavelength_to_colormap(extract_wavelength(c)) + for i, c in enumerate(self._channel_names) + } + + # For OME-TIFF, we store the file path per FOV (not per plane) + # The _load_single_plane method needs to handle this differently + # Store OME file paths in a separate attribute for OME-TIFF loading + self._ome_file_index = ome_files + + # Build file index entries for all (t, fov, z, channel) combinations + # For OME-TIFF, the filepath is the same for all planes in a FOV + for flat_fov, filepath in ome_files.items(): + for t in range(n_t): + for z in range(n_z): + for ch_idx, ch_name in enumerate(channel_names): + with self._file_index_lock: + # Store as (t, fov, z, channel_name) -> filepath + # Also store channel index for OME reading + self._file_index[(t, flat_fov, z, ch_name)] = filepath + + # All FOVs available at all timepoints for OME-TIFF + for t in range(n_t): + self._max_fov_per_time[t] = len(ome_files) - 1 + + # Store OME-specific info for plane loading + self._is_ome_format = True + self._ome_axes = axes + self._ome_shape = shape + + return True + + def _configure_sliders_for_dataset(self): + """Configure T and FOV sliders based on discovered dataset structure.""" + self._updating_sliders = True + try: + # Time slider + self._time_slider.setMaximum(self._max_time_idx) + self._time_slider.setValue(0) + self._time_label.setText("T: 0") + self._time_container.setVisible(self._max_time_idx > 0) + + # FOV slider + max_fov = len(self._fov_labels) - 1 if self._fov_labels else 0 + self._fov_slider.setMaximum(max_fov) + self._fov_slider.setValue(0) + if self._fov_labels: + self._fov_label.setText(f"FOV: {self._fov_labels[0]}") + else: + self._fov_label.setText("FOV: 0") + finally: + self._updating_sliders = False + def set_current_index(self, dim: str, value: int) -> bool: """Set the current index for a dimension in the viewer. @@ -3799,6 +4377,12 @@ def _set_ndv_data(self, data: xr.DataArray): old_widget.deleteLater() layout.insertWidget(idx, self.ndv_viewer.widget(), 1) + # Connect to nDimsRequested signal to re-hide sliders when 3D mode is toggled + if hasattr(self.ndv_viewer, "_view") and hasattr( + self.ndv_viewer._view, "nDimsRequested" + ): + self.ndv_viewer._view.nDimsRequested.connect(self._on_ndv_ndims_requested) + # Update channel labels after viewer is ready. self._initiate_channel_label_update()