diff --git a/dascore/core/patch.py b/dascore/core/patch.py index 9ea9bd3a2..b301f9ca1 100644 --- a/dascore/core/patch.py +++ b/dascore/core/patch.py @@ -416,6 +416,7 @@ def iselect(self, *args, **kwargs): gaussian_filter = dascore.proc.gaussian_filter slope_filter = dascore.proc.slope_filter wiener_filter = dascore.proc.wiener_filter + adaptive_spectral_filter = dascore.proc.adaptive_spectral_filter abs = dascore.proc.abs conj = dascore.proc.conj real = dascore.proc.real diff --git a/dascore/proc/__init__.py b/dascore/proc/__init__.py index 40fc327db..cc292d4ff 100644 --- a/dascore/proc/__init__.py +++ b/dascore/proc/__init__.py @@ -18,3 +18,4 @@ from .hampel import hampel_filter from .wiener import wiener_filter from .align import align_to_coord +from .adaptive_spectral_filter import adaptive_spectral_filter diff --git a/dascore/proc/_adaptive_spectral_filter_numba.py b/dascore/proc/_adaptive_spectral_filter_numba.py new file mode 100644 index 000000000..30288f17a --- /dev/null +++ b/dascore/proc/_adaptive_spectral_filter_numba.py @@ -0,0 +1,317 @@ +"""Optional Numba/rocket-fft engine for adaptive spectral filtering.""" + +from __future__ import annotations + +import numba as nb +import numpy as np +import rocket_fft # noqa: F401 # registers FFT overloads with numba. 
from dascore.proc.adaptive_spectral_filter import (
    _finalize_output,
    _prepare_work_arrays,
    _validate_filter_inputs,
)


def _tile_indices_from_parity_index(
    ind: int,
    count1: int,
    parity0: int,
    parity1: int,
) -> tuple[int, int]:
    """
    Map a flattened parity-local index back to full-grid tile indices.

    Parameters
    ----------
    ind
        Flat index of a tile inside one parity group.
    count1
        Number of tiles of this parity group along the second axis; used as
        the row length when unflattening ``ind``.
    parity0, parity1
        Offset (0 or 1 in the visible callers) of the group along each axis.

    Returns
    -------
    tuple[int, int]
        ``(ix, iy)`` tile indices; consecutive tiles of one group are two
        grid steps apart on each axis.
    """
    ix = parity0 + 2 * (ind // count1)
    iy = parity1 + 2 * (ind % count1)
    return ix, iy


def _tile_bounds(
    ix: int,
    iy: int,
    wx: int,
    wy: int,
    stride0: int,
    stride1: int,
    shape0: int,
    shape1: int,
) -> tuple[int, int, int, int]:
    """
    Return padded-array origin and valid tile shape for one window.

    The window starts at ``(ix * stride0, iy * stride1)`` and is clipped to
    ``(shape0, shape1)``, so the returned extent may be smaller than
    ``(wx, wy)`` for tiles touching the far edges.

    Returns
    -------
    tuple[int, int, int, int]
        ``(beg0, beg1, n0, n1)``: start indices and clipped extents.
    """
    beg0 = ix * stride0
    beg1 = iy * stride1
    # Clip the window end so it never runs past the array.
    end0 = min(beg0 + wx, shape0)
    end1 = min(beg1 + wy, shape1)
    return beg0, beg1, end0 - beg0, end1 - beg1


def _copy_padded_tile(
    padded: np.ndarray,
    tile: np.ndarray,
    beg0: int,
    beg1: int,
    n0: int,
    n1: int,
) -> None:
    """
    Copy the valid padded-array region into a fixed-shape zeroed tile.

    Only ``tile[:n0, :n1]`` is written; this function does not clear the
    rest of ``tile``, so the caller must supply it already zeroed.
    """
    # Explicit element loops: this same function object is also compiled by
    # Numba below.
    for i in range(n0):
        for j in range(n1):
            tile[i, j] = padded[beg0 + i, beg1 + j]


def _complex_power(value: complex) -> np.float32:
    """
    Return the magnitude of a complex FFT coefficient as ``float32``.

    Despite the name, the result is ``sqrt(re**2 + im**2)`` (the modulus),
    not the squared modulus.
    """
    return np.float32((value.real * value.real + value.imag * value.imag) ** 0.5)


def _max_spectral_power(spec: np.ndarray) -> np.float32:
    """
    Return the maximum spectral magnitude in one tile.

    Scans every element of the 2D array ``spec``; returns ``float32(0)``
    when no magnitude exceeds zero (including an empty array).
    """
    max_power = np.float32(0.0)
    for i in range(spec.shape[0]):
        for j in range(spec.shape[1]):
            power = _complex_power(spec[i, j])
            if power > max_power:
                max_power = power
    return max_power


def _apply_spectral_weight(
    spec: np.ndarray,
    exponent: float,
    normalize_power: bool,
) -> None:
    """
    Apply adaptive magnitude weighting to one tile spectrum in place.

    Each coefficient is multiplied by ``magnitude ** exponent``. When
    ``normalize_power`` is true the magnitude is first divided by the tile's
    maximum magnitude; an all-zero tile then uses a magnitude of zero.

    Parameters
    ----------
    spec
        2D complex spectrum of one tile; modified in place.
    exponent
        Power applied to the (optionally normalized) magnitude.
    normalize_power
        Whether to normalize magnitudes by the per-tile maximum.
    """
    max_power = np.float32(0.0)
    if normalize_power:
        max_power = _max_spectral_power(spec)

    for i in range(spec.shape[0]):
        for j in range(spec.shape[1]):
            power = _complex_power(spec[i, j])
            if normalize_power:
                # Guard the division: an all-zero tile has no maximum.
                if max_power != 0.0:
                    power = power / max_power
                else:
                    power = np.float32(0.0)
            weight = np.float32(power**exponent)
            spec[i, j] *= weight


def _overlap_add_tile(
    filtered: np.ndarray,
    tile: np.ndarray,
    taper: np.ndarray,
    beg0: int,
    beg1: int,
    n0: int,
    n1: int,
) -> None:
    """
    Accumulate the valid region of one filtered tile into the output.

    Adds ``tile[i, j] * taper[i, j]`` to ``filtered[beg0 + i, beg1 + j]`` for
    the ``n0`` by ``n1`` valid region; ``filtered`` is modified in place.
    """
    for i in range(n0):
        for j in range(n1):
            filtered[beg0 + i, beg1 + j] += tile[i, j] * taper[i, j]


# Compiled counterparts of the pure-Python helpers above. They wrap the same
# function objects, so the Python and Numba variants share one source.
_tile_indices_from_parity_index_numba = nb.njit(cache=True, inline="always")(
    _tile_indices_from_parity_index
)
_tile_bounds_numba = nb.njit(cache=True, inline="always")(_tile_bounds)
_copy_padded_tile_numba = nb.njit(cache=True, inline="always")(_copy_padded_tile)
_complex_power_numba = nb.njit(cache=True, inline="always")(_complex_power)


def _max_spectral_power_numba_impl(spec: np.ndarray) -> np.float32:
    """
    Return the maximum spectral magnitude using compiled helpers.

    Same loop as ``_max_spectral_power`` but calling ``_complex_power_numba``
    so that the function can itself be compiled.
    """
    max_power = np.float32(0.0)
    for i in range(spec.shape[0]):
        for j in range(spec.shape[1]):
            power = _complex_power_numba(spec[i, j])
            if power > max_power:
                max_power = power
    return max_power


def _apply_spectral_weight_numba_impl(
    spec: np.ndarray,
    exponent: float,
    normalize_power: bool,
) -> None:
    """
    Apply adaptive magnitude weighting using compiled helpers.

    Same logic as ``_apply_spectral_weight`` but calling the compiled
    magnitude and maximum helpers; ``spec`` is modified in place.
    """
    max_power = np.float32(0.0)
    if normalize_power:
        max_power = _max_spectral_power_numba(spec)

    for i in range(spec.shape[0]):
        for j in range(spec.shape[1]):
            power = _complex_power_numba(spec[i, j])
            if normalize_power:
                # Guard the division: an all-zero tile has no maximum.
                if max_power != 0.0:
                    power = power / max_power
                else:
                    power = np.float32(0.0)
            weight = np.float32(power**exponent)
            spec[i, j] *= weight


# NOTE: ``_apply_spectral_weight_numba_impl`` refers to
# ``_max_spectral_power_numba`` by name; the binding below exists before the
# weighting kernel is first called.
_max_spectral_power_numba = nb.njit(cache=True, inline="always")(
    _max_spectral_power_numba_impl
)
_apply_spectral_weight_numba = nb.njit(cache=True, inline="always")(
    _apply_spectral_weight_numba_impl
)
def _process_tile_group_python(
    padded: np.ndarray,
    filtered: np.ndarray,
    taper: np.ndarray,
    wx: int,
    wy: int,
    stride0: int,
    stride1: int,
    nx: int,
    ny: int,
    parity0: int,
    parity1: int,
    exponent: float,
    normalize_power: bool,
) -> None:
    """
    Filter every tile of one parity group with plain NumPy FFTs.

    A parity group holds the tiles whose grid indices are ``parity0`` and
    ``parity1`` plus a multiple of two along the first and second axis. Each
    tile is copied into a zeroed ``(wx, wy)`` window, transformed with
    ``rfft2``, optionally weighted, transformed back and overlap-added into
    ``filtered`` (modified in place).
    """
    rows_in_group = (nx - parity0 + 1) // 2
    cols_in_group = (ny - parity1 + 1) // 2
    pad_rows, pad_cols = padded.shape[0], padded.shape[1]
    weighted = exponent != 0.0

    for flat in range(rows_in_group * cols_in_group):
        # Unflatten the group-local index, then step two grid cells per tile.
        row_step, col_step = divmod(flat, cols_in_group)
        ix = parity0 + 2 * row_step
        iy = parity1 + 2 * col_step
        beg0, beg1, n0, n1 = _tile_bounds(
            ix, iy, wx, wy, stride0, stride1, pad_rows, pad_cols
        )

        # A fresh zeroed window keeps clipped edge tiles zero-padded.
        window = np.zeros((wx, wy), dtype=np.float32)
        _copy_padded_tile(padded, window, beg0, beg1, n0, n1)

        spectrum = np.fft.rfft2(window)
        if weighted:
            _apply_spectral_weight(spectrum, exponent, normalize_power)

        restored = np.fft.irfft2(spectrum, s=(wx, wy))
        _overlap_add_tile(filtered, restored, taper, beg0, beg1, n0, n1)
def _adaptive_spectral_filter_numba(
    data: np.ndarray,
    *,
    window_size: tuple[int, int],
    overlap: tuple[int, int],
    exponent: float = 0.3,
    normalize_power: bool = False,
) -> np.ndarray:
    """
    Filter a 2D array with the optional Numba/rocket-fft implementation.

    Parameters
    ----------
    data
        Two-dimensional input array. The filter computes in ``float32``.
    window_size
        Two power-of-two window lengths, one per array axis. Values must be
        greater than 4.
    overlap
        Number of samples each neighboring window overlaps on each axis. Each
        value must be non-negative and smaller than half the matching window.
    exponent
        Spectral magnitude exponent used as the adaptive weighting power. ``0``
        leaves the spectrum unweighted before overlap-add reconstruction.
    normalize_power
        If ``True``, normalize each tile's spectral magnitudes by that tile's
        maximum magnitude before applying ``exponent``.

    Returns
    -------
    numpy.ndarray
        The filtered array with the same shape as ``data``. Floating input
        dtypes are restored; non-floating inputs return ``float32`` output.

    Raises
    ------
    ValueError
        If ``data`` is not two-dimensional, ``exponent`` is not finite,
        ``window_size`` and ``overlap`` do not contain exactly two integer
        values, any window size is not a power of two greater than 4, or any
        overlap is negative or at least half the matching window size.

    Notes
    -----
    This implementation uses Numba-compiled loops and rocket-fft-backed NumPy
    FFT calls. It is selected by
    :func:`dascore.proc.adaptive_spectral_filter.adaptive_spectral_filter` for
    two selected dimensions when ``engine="numba"`` or when ``engine="auto"``
    and optional dependencies are installed.
    """
    data = np.asarray(data)
    # The shared validator also accepts 1D input (for the SciPy engine), so
    # reject it here with a clear message instead of failing later while
    # unpacking ``window_size``.
    if data.ndim != 2:
        msg = (
            "the Numba adaptive spectral filter requires 2D input; "
            f"got {data.ndim}D."
        )
        raise ValueError(msg)
    exponent = float(exponent)
    _validate_filter_inputs(
        data, window_size=window_size, overlap=overlap, exponent=exponent
    )
    wx, wy = window_size
    working, original_dtype, stride, taper, padded, filtered, n_tiles = (
        _prepare_work_arrays(data, window_size=window_size, overlap=overlap)
    )
    # Tiles are processed in four parity groups; the module documentation
    # states this keeps neighboring writes from overlapping inside each
    # parallel loop.
    for parity0 in range(2):
        for parity1 in range(2):
            _process_tile_group_numba(
                padded,
                filtered,
                taper,
                wx,
                wy,
                stride[0],
                stride[1],
                n_tiles[0],
                n_tiles[1],
                parity0,
                parity1,
                exponent,
                bool(normalize_power),
            )
    return _finalize_output(filtered, working, original_dtype, stride)
Coherent plane-wave energy +tends to concentrate in the frequency-wavenumber spectrum, so the weighting +emphasizes locally coherent arrivals relative to diffuse or randomly +distributed energy. + +This module exposes a single public patch method, +:func:`adaptive_spectral_filter`. The public function resolves one or two +DASCore dimensions, converts window and overlap values to sample counts, moves +those dimensions to the array tail, and processes every remaining leading index +as an independent batch. The lower-level SciPy and Numba implementations are +private because they operate on raw arrays and do not perform DASCore +coordinate handling. + +The SciPy engine handles one- and two-dimensional selected windows using +``rfftn``/``irfftn``. The optional Numba/rocket-fft engine currently handles +the two-dimensional case only, using parity-separated tile groups so neighboring +writes do not overlap within each parallel loop. Both engines share validation, +padding, tapering, and dtype-restoration logic so two-dimensional outputs remain +directly comparable. 
+""" + +from __future__ import annotations + +from collections.abc import Callable, Mapping +from itertools import product +from math import prod +from typing import Any, Literal + +import numpy as np +from scipy import fft as sp_fft + +from dascore.constants import PatchType +from dascore.exceptions import MissingOptionalDependencyError, ParameterError +from dascore.utils.patch import get_dim_axis_value, patch_function +from dascore.utils.signal import _triangular_taper + +_AdaptiveSpectralEngine = Literal["auto", "numba", "scipy"] +__all__ = ("adaptive_spectral_filter",) + + +def _is_power_of_two(value: int) -> bool: + """Return ``True`` when *value* is a positive power of two.""" + return value > 0 and (value & (value - 1) == 0) + + +def _validate_filter_inputs( + data: np.ndarray, + *, + window_size: tuple[int, ...], + overlap: tuple[int, ...], + exponent: float, +) -> None: + """Validate direct array-filter inputs before entering FFT kernels.""" + if data.ndim not in {1, 2}: + msg = ( + f"adaptive spectral array filters require 1D or 2D input; got {data.ndim}D." + ) + raise ValueError(msg) + if len(window_size) != data.ndim or len(overlap) != data.ndim: + msg = "window_size and overlap must match the input dimensionality." + raise ValueError(msg) + if not np.isfinite(exponent): + msg = "exponent must be finite." + raise ValueError(msg) + + for axis, (window, axis_overlap) in enumerate(zip(window_size, overlap)): + if not isinstance(window, (int, np.integer)): + msg = f"window_size[{axis}] must be an integer; got {window!r}." + raise ValueError(msg) + if not isinstance(axis_overlap, (int, np.integer)): + msg = f"overlap[{axis}] must be an integer; got {axis_overlap!r}." + raise ValueError(msg) + + window = int(window) + axis_overlap = int(axis_overlap) + if window <= 4 or not _is_power_of_two(window): + msg = ( + f"window_size[{axis}] must be a power of two greater than 4; " + f"got {window!r}." 
+ ) + raise ValueError(msg) + if axis_overlap < 0: + msg = f"overlap[{axis}] must be non-negative; got {axis_overlap!r}." + raise ValueError(msg) + if axis_overlap >= window / 2: + msg = f"overlap[{axis}] is too large; maximum is {window // 2 - 1} samples." + raise ValueError(msg) + + +def _prepare_work_arrays( + data: np.ndarray, + *, + window_size: tuple[int, ...], + overlap: tuple[int, ...], +) -> tuple[ + np.ndarray, + np.dtype, + tuple[int, ...], + np.ndarray, + np.ndarray, + np.ndarray, + tuple[int, ...], +]: + """Prepare ``float32`` padded arrays shared by filter implementations.""" + data = np.asarray(data) + original_dtype = data.dtype + working = np.ascontiguousarray(data, dtype=np.float32) + stride = tuple(win - over for win, over in zip(window_size, overlap)) + plateau = tuple(win - 2 * over for win, over in zip(window_size, overlap)) + taper = _triangular_taper(window_size, plateau) + + padded_shape = tuple( + length + 2 * step for length, step in zip(working.shape, stride) + ) + padded = np.zeros(padded_shape, dtype=np.float32) + inner_slices = tuple( + slice(step, length + step) for length, step in zip(working.shape, stride) + ) + padded[inner_slices] = working + filtered = np.zeros_like(padded) + n_tiles = tuple(pad_len // step for pad_len, step in zip(padded.shape, stride)) + return working, original_dtype, stride, taper, padded, filtered, n_tiles + + +def _finalize_output( + filtered: np.ndarray, + working: np.ndarray, + original_dtype: np.dtype, + stride: tuple[int, ...], +) -> np.ndarray: + """Crop padded output and restore floating dtypes where possible.""" + slices = tuple( + slice(step, length + step) for length, step in zip(working.shape, stride) + ) + out = filtered[slices] + if np.issubdtype(original_dtype, np.floating): + return out.astype(original_dtype, copy=False) + return out + + +def _extract_tiles_python( + padded: np.ndarray, + window_size: tuple[int, ...], + stride: tuple[int, ...], + n_tiles: tuple[int, ...], +) -> 
tuple[np.ndarray, np.ndarray, np.ndarray]: + """Extract padded windows into a dense tile stack for batched SciPy FFTs.""" + ndim = len(window_size) + tiles = np.zeros((prod(n_tiles), *window_size), dtype=np.float32) + begins = np.zeros((*n_tiles, ndim), dtype=np.int64) + sizes = np.zeros((*n_tiles, ndim), dtype=np.int64) + for tile_index, tile_inds in enumerate(product(*(range(num) for num in n_tiles))): + beg = tuple(ind * step for ind, step in zip(tile_inds, stride)) + end = tuple( + min(start + win, size) + for start, win, size in zip(beg, window_size, padded.shape) + ) + valid_shape = tuple(stop - start for start, stop in zip(beg, end)) + begins[tile_inds] = beg + sizes[tile_inds] = valid_shape + data_slices = tuple(slice(start, stop) for start, stop in zip(beg, end)) + tile_slices = tuple(slice(0, size) for size in valid_shape) + tiles[(tile_index, *tile_slices)] = padded[data_slices] + return tiles, begins, sizes + + +def _overlap_add_tiles_python( + out: np.ndarray, + tiles: np.ndarray, + taper: np.ndarray, + begins: np.ndarray, + sizes: np.ndarray, +) -> None: + """Apply tapered overlap-add reconstruction from a dense tile stack.""" + grid_shape = begins.shape[:-1] + for tile_index, tile_inds in enumerate( + product(*(range(num) for num in grid_shape)) + ): + beg = tuple(begins[tile_inds]) + valid_shape = tuple(sizes[tile_inds]) + out_slices = tuple( + slice(start, start + size) for start, size in zip(beg, valid_shape) + ) + tile_slices = tuple(slice(0, size) for size in valid_shape) + out[out_slices] += tiles[(tile_index, *tile_slices)] * taper[tile_slices] + + +def _adaptive_spectral_filter_scipy( + data: np.ndarray, + *, + window_size: tuple[int, ...], + overlap: tuple[int, ...], + exponent: float = 0.3, + normalize_power: bool = False, +) -> np.ndarray: + """ + Filter a 1D or 2D array with the SciPy adaptive spectral implementation. + + Parameters + ---------- + data + One- or two-dimensional input array. The filter computes in ``float32``. 
+ window_size + Power-of-two window lengths, one per array axis. Values must be greater + than 4. + overlap + Number of samples each neighboring window overlaps on each axis. Values + must be non-negative and smaller than half the matching window. + exponent + Spectral magnitude exponent used as the adaptive weighting power. ``0`` + leaves the spectrum unweighted before overlap-add reconstruction. + normalize_power + If ``True``, normalize each tile's spectral magnitudes by that tile's + maximum magnitude before applying ``exponent``. + + Returns + ------- + numpy.ndarray + The filtered array with the same shape as ``data``. Floating input + dtypes are restored; non-floating inputs return ``float32`` output. + + Raises + ------ + ValueError + If ``data`` is not one- or two-dimensional, ``exponent`` is not finite, + ``window_size`` and ``overlap`` do not match ``data.ndim``, any window + size is not a power of two greater than 4, or any overlap is negative or + at least half the matching window size. 
+ """ + data = np.asarray(data) + _validate_filter_inputs( + data, window_size=window_size, overlap=overlap, exponent=float(exponent) + ) + working, original_dtype, stride, taper, padded, filtered, n_tiles = ( + _prepare_work_arrays(data, window_size=window_size, overlap=overlap) + ) + tiles, begins, sizes = _extract_tiles_python(padded, window_size, stride, n_tiles) + axes = tuple(range(-data.ndim, 0)) + + spec = sp_fft.rfftn(tiles, s=window_size, axes=axes, workers=-1) + if exponent != 0.0: + power = np.abs(spec).astype(np.float32, copy=False) + if normalize_power: + max_power = power.max(axis=axes, keepdims=True) + power = np.divide( + power, max_power, out=np.zeros_like(power), where=max_power != 0 + ) + spec *= power**exponent + tiles = sp_fft.irfftn(spec, s=window_size, axes=axes, workers=-1).astype( + np.float32, copy=False + ) + _overlap_add_tiles_python(filtered, tiles, taper, begins, sizes) + return _finalize_output(filtered, working, original_dtype, stride) + + +def _get_dim_axis_values(patch: PatchType, kwargs: Mapping[str, Any]): + """Resolve DASCore dimension keyword arguments into dim/axis values.""" + if len(kwargs) not in {1, 2}: + msg = ( + "adaptive_spectral_filter requires one or two dimension window kwargs, " + "e.g. patch.adaptive_spectral_filter(time=32, samples=True)." 
+ ) + raise ParameterError(msg) + return get_dim_axis_value(patch, kwargs=dict(kwargs), allow_multiple=True) + + +def _dim_values_to_samples( + patch: PatchType, + dim_axis_values, + *, + samples: bool, + name: str, + force_sample_dims: frozenset[str] = frozenset(), +) -> tuple[int, ...]: + """Convert DASCore dimension values from units or samples into sample counts.""" + out: list[int] = [] + for dim, _, value in dim_axis_values: + if samples or dim in force_sample_dims: + count = int(value) + else: + coord = patch.get_coord(dim, require_evenly_sampled=True) + count = coord.get_sample_count(value, samples=False) + invalid = count < 0 if name == "overlap" else count <= 0 + if invalid: + requirement = "non-negative" if name == "overlap" else "positive" + msg = f"{name} for dimension {dim!r} must be {requirement}." + raise ParameterError(msg) + out.append(count) + return tuple(out) + + +def _normalize_overlap( + overlap: int | Mapping[str, Any] | None, + dims: tuple[str, ...], + windows: tuple[int, ...], +) -> tuple[dict[str, Any], frozenset[str]]: + """Return per-dimension overlap values and internally defaulted dimensions.""" + defaults = {dim: max(window // 2 - 2, 0) for dim, window in zip(dims, windows)} + if overlap is None: + return defaults, frozenset(dims) + if isinstance(overlap, Mapping): + extra = set(overlap) - set(dims) + if extra: + msg = f"overlap contains dimensions not being filtered: {sorted(extra)}" + raise ParameterError(msg) + return defaults | dict(overlap), frozenset(set(dims) - set(overlap)) + return {dim: int(overlap) for dim in dims}, frozenset() + + +def _validate_window_and_overlap( + dims: tuple[str, ...], + windows: tuple[int, ...], + overlaps: tuple[int, ...], + exponent: float, +) -> None: + """Validate public DASCore window and overlap settings.""" + if not np.isfinite(exponent): + msg = "exponent must be finite." 
def _get_engine(engine: _AdaptiveSpectralEngine, selected_ndim: int) -> Callable:
    """
    Pick the array-level filter implementation for an engine name.

    ``"scipy"`` always selects the SciPy implementation, and so does
    ``"auto"`` with one selected dimension. Otherwise the Numba
    implementation is imported lazily; if that import fails, ``"auto"`` falls
    back to SciPy while ``"numba"`` raises.

    Raises
    ------
    ParameterError
        If ``engine`` is not a known name, or the Numba path is reached with
        a number of selected dimensions other than two.
    MissingOptionalDependencyError
        If ``engine="numba"`` and the Numba engine cannot be imported.
    """
    wants_scipy = engine == "scipy" or (engine == "auto" and selected_ndim == 1)
    if wants_scipy:
        return _adaptive_spectral_filter_scipy

    if engine not in {"auto", "numba"}:
        raise ParameterError("engine must be one of 'auto', 'numba', or 'scipy'.")
    if selected_ndim != 2:
        raise ParameterError(
            "engine='numba' currently supports exactly two selected dimensions."
        )

    try:
        # Imported lazily so numba and rocket-fft stay optional.
        from dascore.proc._adaptive_spectral_filter_numba import (
            _adaptive_spectral_filter_numba as numba_filter,
        )
    except ImportError as exc:
        if engine != "numba":
            # "auto": silently use the SciPy engine instead.
            return _adaptive_spectral_filter_scipy
        msg = (
            "engine='numba' requires optional dependencies numba and "
            "rocket-fft to be installed."
        )
        raise MissingOptionalDependencyError(msg) from exc
    return numba_filter
A single value applies to all selected dimensions; a mapping + can specify dimensions independently. When omitted, each dimension + defaults to ``window // 2 - 2`` samples. + exponent + Spectral magnitude exponent used as the adaptive weighting power. ``0`` + leaves the spectrum unweighted before overlap-add reconstruction. + normalize_power + If ``True``, normalize each tile's spectral magnitudes by that tile's + maximum magnitude before applying ``exponent``. + samples + If ``True``, dimension kwargs and overlap values are interpreted as + sample counts. If ``False``, values are converted through evenly sampled + patch coordinates. + engine + ``"auto"`` uses SciPy for one selected dimension and the optional + Numba/rocket-fft implementation for two selected dimensions when + available. ``"numba"`` requires two selected dimensions and the optional + fast engine. ``"scipy"`` always uses the SciPy FFT implementation. + **kwargs + One or two dimension names and their window sizes, such as ``time=32`` + or ``time=32, distance=32``. + + Returns + ------- + Patch + A new patch with filtered data and original dimensions and coordinates. + + Raises + ------ + ParameterError + If one or two dimensions are not selected, if selected window or overlap + values are invalid, if ``exponent`` is not finite, or if an invalid + engine name is requested. + MissingOptionalDependencyError + If ``engine="numba"`` is requested for two selected dimensions but the + optional fast-engine dependencies are not installed. + + Examples + -------- + >>> import dascore as dc + >>> patch = dc.get_example_patch() + >>> filtered_1d = patch.adaptive_spectral_filter(time=32, samples=True) + >>> filtered_2d = patch.adaptive_spectral_filter( + ... time=32, distance=32, samples=True + ... 
) + >>> filtered_1d.shape == filtered_2d.shape == patch.shape + True + + Notes + ----- + With two selected dimensions, this method is equivalent to the adaptive + frequency-wavenumber (f-k) filter described in @isken2022denoising and + follows the behavior exposed by Pyrocko + [Lightguide](https://github.com/pyrocko/lightguide). + """ + dim_axis_values = _get_dim_axis_values(patch, kwargs) + dims = tuple(x.dim for x in dim_axis_values) + axes = tuple(x.axis for x in dim_axis_values) + windows = _dim_values_to_samples( + patch, dim_axis_values, samples=samples, name="window" + ) + overlap_values, default_overlap_dims = _normalize_overlap(overlap, dims, windows) + overlap_dim_axis_values = get_dim_axis_value( + patch, kwargs=overlap_values, allow_multiple=True + ) + overlaps = _dim_values_to_samples( + patch, + samples=samples, + dim_axis_values=overlap_dim_axis_values, + name="overlap", + force_sample_dims=default_overlap_dims, + ) + _validate_window_and_overlap(dims, windows, overlaps, float(exponent)) + + data = np.asarray(patch.data) + selected_ndim = len(axes) + moved = np.moveaxis(data, axes, tuple(range(-selected_ndim, 0))) + batch_shape = moved.shape[:-selected_ndim] + selected_shape = moved.shape[-selected_ndim:] + working = moved.reshape((-1, *selected_shape)) + filtered = np.empty_like(working, dtype=np.float32) + engine_func = _get_engine(engine, selected_ndim) + for ind, array in enumerate(working): + filtered[ind] = engine_func( + array, + window_size=windows, + overlap=overlaps, + exponent=float(exponent), + normalize_power=bool(normalize_power), + ) + filtered = filtered.reshape((*batch_shape, *selected_shape)) + filtered = np.moveaxis(filtered, tuple(range(-selected_ndim, 0)), axes) + if np.issubdtype(data.dtype, np.floating): + filtered = filtered.astype(data.dtype, copy=False) + return patch.update(data=filtered) diff --git a/dascore/utils/signal.py b/dascore/utils/signal.py index 0d2bbddfe..a2837b20e 100644 --- a/dascore/utils/signal.py +++ 
b/dascore/utils/signal.py @@ -2,6 +2,9 @@ Utilities for signal processing. """ +from functools import lru_cache + +import numpy as np from scipy.signal import windows from dascore.exceptions import ParameterError @@ -34,3 +37,86 @@ def _get_window_function(window_type): raise ParameterError(msg) func = WINDOW_FUNCTIONS[window_type] return func + + +def _triangular_taper_1d(size: int, plateau: int) -> np.ndarray: + """Return a one-dimensional triangular plateau taper.""" + ramp_size = (size - plateau) // 2 + taper = np.ones(size, dtype=np.float32) + if ramp_size: + ramp = _get_window_function("triang")(ramp_size * 2 + 1)[:ramp_size] + taper[:ramp_size] = ramp + taper[size - ramp_size :] = ramp[::-1] + return taper + + +@lru_cache(maxsize=64) +def _cached_triangular_taper( + window_size: tuple[int, ...], plateau: tuple[int, ...] +) -> np.ndarray: + """Build the cached taper array used by :func:`_triangular_taper`.""" + if len(window_size) != len(plateau): + msg = "window_size and plateau must have the same length." + raise ValueError(msg) + if len(window_size) not in {1, 2}: + msg = "Only one- and two-dimensional tapers are supported." + raise ValueError(msg) + if any(plat > win for win, plat in zip(window_size, plateau)): + msg = "Plateau cannot be larger than window size." + raise ValueError(msg) + if any(plat < 0 for plat in plateau): + msg = "Plateau sizes must be non-negative." + raise ValueError(msg) + if any(win % 2 for win in window_size): + msg = "Window sizes must be even." + raise ValueError(msg) + tapers = [ + _triangular_taper_1d(win, plat) for win, plat in zip(window_size, plateau) + ] + if len(tapers) == 1: + return tapers[0].astype(np.float32) + return (tapers[0][:, None] * tapers[1][None, :]).astype(np.float32) + + +def _triangular_taper( + window_size: tuple[int, ...], plateau: tuple[int, ...] +) -> np.ndarray: + """ + Return a one- or two-dimensional triangular plateau taper. 
+ + Parameters + ---------- + window_size + Number of samples in the window dimensions. Values must be even. + plateau + Number of central samples with unit weight in each dimension. Values + must be non-negative and no larger than the corresponding window size. + + Returns + ------- + numpy.ndarray + A ``float32`` array with shape ``window_size``. The returned array is a + copy, so callers can mutate it without corrupting the internal cache. + + Raises + ------ + ValueError + If the inputs have different lengths, if the taper dimensionality is + not one or two, if any plateau is negative, if any plateau is greater + than the corresponding window size, or if any window size is odd. + + Notes + ----- + The taper is separable in 2D. Each axis contains a central unit-weight + plateau and triangular ramps on both sides generated from DASCore's + registered ``"triang"`` window function. When plateau equals window size, + the result is all ones along that axis. + + Examples + -------- + >>> from dascore.utils.signal import _triangular_taper + >>> taper = _triangular_taper((8, 8), (2, 2)) + >>> taper.shape + (8, 8) + """ + return _cached_triangular_taper(window_size, plateau).copy() diff --git a/docs/references.bib b/docs/references.bib index b633bf422..2a4428042 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -67,3 +67,12 @@ @article{schimmel1997noise year={1997}, publisher={Blackwell Publishing Ltd Oxford, UK} } + +@article{isken2022denoising, + title={De-noising distributed acoustic sensing data using an adaptive frequency-wavenumber filter}, + author={Isken, Marius Paul and Vasyura-Bathke, Hannes and Dahm, Torsten and Heimann, Sebastian}, + journal={Geophysical Journal International}, + pages={ggac229}, + year={2022}, + doi={10.1093/gji/ggac229} +} diff --git a/pyproject.toml b/pyproject.toml index 764a87bf8..9c3eb2fef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ extras = [ "findiff", "obspy", "numba", + "rocket-fft", "segyio", 
"bottleneck", ] diff --git a/tests/test_proc/test_adaptive_spectral_filter.py b/tests/test_proc/test_adaptive_spectral_filter.py new file mode 100644 index 000000000..b91948bd6 --- /dev/null +++ b/tests/test_proc/test_adaptive_spectral_filter.py @@ -0,0 +1,783 @@ +"""Tests for adaptive spectral filtering.""" + +from __future__ import annotations + +import builtins +from typing import Any + +import numpy as np +import pytest + +import dascore as dc +from dascore.exceptions import ( + CoordError, + MissingOptionalDependencyError, + ParameterError, + PatchCoordinateError, +) +from dascore.proc import adaptive_spectral_filter as adaptive_spectral_filter_func +from dascore.proc.adaptive_spectral_filter import ( + _adaptive_spectral_filter_scipy, + _get_engine, + _validate_window_and_overlap, +) +from dascore.utils.signal import _triangular_taper + + +def _patch( + shape: tuple[int, ...], + dims: tuple[str, ...], + *, + dtype=np.float32, + time_step=np.timedelta64(4, "ms"), + distance_step=1.0, +) -> dc.Patch: + """Return a deterministic patch for adaptive spectral tests.""" + rng = np.random.default_rng(20260508) + data = rng.normal(size=shape).astype(dtype) + coords = {} + for dim, length in zip(dims, shape, strict=True): + if dim == "time": + coords[dim] = np.datetime64("2020-01-01") + np.arange(length) * time_step + elif dim == "distance": + coords[dim] = np.arange(length, dtype=float) * distance_step + else: + coords[dim] = np.arange(length, dtype=float) + return dc.Patch(data=data, coords=coords, dims=dims) + + +class TestAdaptiveSpectralFilter: + """Tests for the adaptive spectral patch method.""" + + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_dtype_shape_dims_and_coords_preserved(self, dtype) -> None: + """Adaptive spectral should preserve patch structure and floating dtype.""" + patch = _patch((64, 64), ("distance", "time"), dtype=dtype) + + out = patch.adaptive_spectral_filter( + distance=16, + time=16, + overlap={"distance": 7, 
"time": 7}, + samples=True, + engine="scipy", + ) + + assert np.asarray(out.data).dtype == np.asarray(patch.data).dtype + assert out.shape == patch.shape + assert out.dims == patch.dims + assert out.coords == patch.coords + assert out.attrs.history[-1].startswith("adaptive_spectral_filter") + + def test_time_distance_reversed_dims_are_supported(self) -> None: + """Selected dimensions need not be in a fixed order.""" + patch = _patch((64, 80), ("time", "distance"), dtype=np.float32) + + out = patch.adaptive_spectral_filter( + time=16, + distance=32, + overlap={"time": 7, "distance": 14}, + samples=True, + engine="scipy", + ) + + assert out.shape == patch.shape + assert out.dims == ("time", "distance") + assert np.isfinite(out.data).all() + + def test_arbitrary_2d_dimension_names_are_supported(self) -> None: + """Adaptive spectral should work with any two patch dimensions.""" + patch = _patch((64, 64), ("channel", "sample"), dtype=np.float32) + + out = patch.adaptive_spectral_filter( + channel=16, + sample=16, + overlap={"channel": 7, "sample": 7}, + samples=True, + engine="scipy", + ) + + assert out.shape == patch.shape + assert out.dims == ("channel", "sample") + + def test_requires_explicit_dimension_kwargs(self) -> None: + """At least one dimension window is required.""" + patch = _patch((64, 64), ("channel", "sample"), dtype=np.float32) + + with pytest.raises(ParameterError, match="one or two dimension window kwargs"): + patch.adaptive_spectral_filter(samples=True, engine="scipy") + + def test_rejects_non_positive_window(self) -> None: + """Window sizes must resolve to positive sample counts.""" + patch = _patch((64, 64), ("distance", "time"), dtype=np.float32) + + with pytest.raises(ParameterError, match=r"window.*must be positive"): + patch.adaptive_spectral_filter( + distance=0, time=16, samples=True, engine="scipy" + ) + + def test_one_dimension_filter_is_supported(self) -> None: + """A single selected dimension should run the 1D spectral path.""" + patch = 
_patch((8, 64), ("distance", "time"), dtype=np.float32) + + out = patch.adaptive_spectral_filter( + time=16, + overlap={"time": 7}, + samples=True, + engine="scipy", + ) + + assert out.shape == patch.shape + assert out.dims == patch.dims + assert out.coords == patch.coords + assert np.isfinite(out.data).all() + + def test_one_dimension_filter_batches_other_dims(self) -> None: + """Unselected dimensions should be batches for 1D filtering.""" + patch = _patch((3, 8, 64), ("shot", "distance", "time"), dtype=np.float32) + + out = patch.adaptive_spectral_filter(time=16, samples=True, engine="scipy") + expected = np.stack( + [ + dc.Patch( + data=np.asarray(patch.data)[ind], + coords={ + "distance": patch.get_array("distance"), + "time": patch.get_array("time"), + }, + dims=("distance", "time"), + ) + .adaptive_spectral_filter(time=16, samples=True, engine="scipy") + .data + for ind in range(patch.shape[0]) + ] + ) + + np.testing.assert_allclose(out.data, expected, rtol=1e-5, atol=1e-5) + + def test_coordinate_unit_window_and_overlap_conversion(self) -> None: + """Coordinate units and sample counts should resolve identically.""" + patch = _patch( + (64, 64), + ("distance", "time"), + dtype=np.float32, + distance_step=2.0, + time_step=np.timedelta64(4, "ms"), + ) + + by_units = patch.adaptive_spectral_filter( + distance=32.0, + time=np.timedelta64(64, "ms"), + overlap={"distance": 14.0, "time": np.timedelta64(28, "ms")}, + samples=False, + engine="scipy", + ) + by_samples = patch.adaptive_spectral_filter( + distance=16, + time=16, + overlap={"distance": 7, "time": 7}, + samples=True, + engine="scipy", + ) + + np.testing.assert_allclose(by_units.data, by_samples.data, rtol=1e-5, atol=1e-5) + + def test_default_overlap_stays_in_samples_when_windows_use_units(self) -> None: + """Computed overlap defaults should not be interpreted as coordinate units.""" + patch = _patch( + (64, 64), + ("distance", "time"), + dtype=np.float32, + distance_step=2.0, + time_step=np.timedelta64(4, 
"ms"), + ) + + by_units = patch.adaptive_spectral_filter( + distance=32.0, + time=np.timedelta64(64, "ms"), + samples=False, + engine="scipy", + ) + by_samples = patch.adaptive_spectral_filter( + distance=16, + time=16, + overlap={"distance": 6, "time": 6}, + samples=True, + engine="scipy", + ) + + np.testing.assert_allclose(by_units.data, by_samples.data, rtol=1e-5, atol=1e-5) + + def test_partial_overlap_defaults_stay_in_samples_with_units(self) -> None: + """Missing overlap mapping entries should stay sample-count defaults.""" + patch = _patch( + (64, 64), + ("distance", "time"), + dtype=np.float32, + distance_step=2.0, + time_step=np.timedelta64(4, "ms"), + ) + + by_units = patch.adaptive_spectral_filter( + distance=32.0, + time=np.timedelta64(64, "ms"), + overlap={"time": np.timedelta64(24, "ms")}, + samples=False, + engine="scipy", + ) + by_samples = patch.adaptive_spectral_filter( + distance=16, + time=16, + overlap={"distance": 6, "time": 6}, + samples=True, + engine="scipy", + ) + + np.testing.assert_allclose(by_units.data, by_samples.data, rtol=1e-5, atol=1e-5) + + def test_1d_default_overlap_stays_in_samples_with_units(self) -> None: + """Computed 1D overlap defaults should stay in sample counts.""" + patch = _patch( + (8, 64), + ("distance", "time"), + dtype=np.float32, + time_step=np.timedelta64(4, "ms"), + ) + + by_units = patch.adaptive_spectral_filter( + time=np.timedelta64(64, "ms"), + samples=False, + engine="scipy", + ) + by_samples = patch.adaptive_spectral_filter( + time=16, + overlap=6, + samples=True, + engine="scipy", + ) + + np.testing.assert_allclose(by_units.data, by_samples.data, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize( + "shape,dims,kwargs", + [ + ((3, 64, 64), ("shot", "distance", "time"), {"distance": 16, "time": 16}), + ( + (2, 3, 64, 64), + ("component", "shot", "distance", "time"), + {"distance": 16, "time": 16}, + ), + ((64, 3, 64), ("distance", "shot", "time"), {"distance": 16, "time": 16}), + ], + ) + def 
test_batches_over_non_selected_dimensions( + self, + shape: tuple[int, ...], + dims: tuple[str, ...], + kwargs: dict[str, int], + ) -> None: + """Extra dimensions should be processed as independent 2D batches.""" + patch = _patch(shape, dims, dtype=np.float32) + + out = patch.adaptive_spectral_filter( + **kwargs, + overlap={dim: max(value // 2 - 2, 0) for dim, value in kwargs.items()}, + samples=True, + engine="scipy", + ) + + assert out.shape == patch.shape + assert out.dims == patch.dims + assert out.coords == patch.coords + assert np.isfinite(out.data).all() + + def test_batched_output_matches_independent_2d_calls(self) -> None: + """Batched filtering should match independent 2D patch calls.""" + patch = _patch((4, 32, 32), ("depth", "distance", "time"), dtype=np.float32) + + out = patch.adaptive_spectral_filter( + distance=16, time=16, samples=True, engine="scipy" + ) + expected = np.stack( + [ + dc.Patch( + data=np.asarray(patch.data)[ind], + coords={ + "distance": patch.get_array("distance"), + "time": patch.get_array("time"), + }, + dims=("distance", "time"), + ) + .adaptive_spectral_filter( + distance=16, time=16, samples=True, engine="scipy" + ) + .data + for ind in range(patch.shape[0]) + ] + ) + + np.testing.assert_allclose(out.data, expected, rtol=1e-5, atol=1e-5) + + def test_rejects_unknown_overlap_dimension(self) -> None: + """Overlap mappings may only name selected dimensions.""" + patch = _patch((64, 64), ("distance", "time"), dtype=np.float32) + + with pytest.raises(ParameterError, match="overlap contains dimensions"): + patch.adaptive_spectral_filter( + distance=16, + time=16, + overlap={"distance": 7, "bad": 7}, + samples=True, + engine="scipy", + ) + + def test_scalar_overlap_applies_to_both_dimensions(self) -> None: + """A scalar overlap should apply to each selected dimension.""" + patch = _patch((64, 64), ("distance", "time"), dtype=np.float32) + + out = patch.adaptive_spectral_filter( + distance=16, + time=16, + overlap=7, + samples=True, + 
engine="scipy", + ) + + assert out.shape == patch.shape + + def test_scalar_overlap_applies_to_one_dimension(self) -> None: + """A scalar overlap should also work for 1D filtering.""" + patch = _patch((8, 64), ("distance", "time"), dtype=np.float32) + + out = patch.adaptive_spectral_filter( + time=16, + overlap=7, + samples=True, + engine="scipy", + ) + + assert out.shape == patch.shape + + def test_zero_overlap_is_supported(self) -> None: + """Zero overlap should use an all-ones reconstruction taper.""" + patch = _patch((64, 64), ("distance", "time"), dtype=np.float32) + + out = patch.adaptive_spectral_filter( + distance=16, + time=16, + overlap=0, + samples=True, + engine="scipy", + ) + + assert out.shape == patch.shape + assert np.isfinite(out.data).all() + + def test_one_dimension_auto_uses_scipy(self) -> None: + """Auto mode should use SciPy for 1D filtering.""" + patch = _patch((8, 64), ("distance", "time"), dtype=np.float32) + + out = patch.adaptive_spectral_filter(time=16, samples=True, engine="auto") + + assert out.shape == patch.shape + + def test_numba_rejects_one_dimension(self) -> None: + """The optional numba engine is intentionally 2D-only.""" + patch = _patch((8, 64), ("distance", "time"), dtype=np.float32) + + with pytest.raises(ParameterError, match="two selected dimensions"): + patch.adaptive_spectral_filter(time=16, samples=True, engine="numba") + + @pytest.mark.parametrize( + "kwargs,match", + [ + ({"exponent": np.nan}, "exponent must be finite"), + ({"distance": 15}, "power of two"), + ({"overlap": {"distance": 8}}, "too large"), + ({"overlap": {"distance": -1}}, "non-negative"), + ], + ) + def test_patch_validation_branches( + self, kwargs: dict[str, Any], match: str + ) -> None: + """Patch-level validation should raise ParameterError.""" + patch = _patch((64, 64), ("distance", "time"), dtype=np.float32) + call_kwargs = {"distance": 16, "time": 16, "samples": True, "engine": "scipy"} + call_kwargs.update(kwargs) + + with 
pytest.raises(ParameterError, match=match): + patch.adaptive_spectral_filter(**call_kwargs) + + def test_rejects_missing_dimension_kwarg(self) -> None: + """Unknown dimensions should raise the normal patch coordinate error.""" + patch = _patch((64, 64), ("distance", "time"), dtype=np.float32) + + with pytest.raises(PatchCoordinateError, match="not found"): + patch.adaptive_spectral_filter( + distance=16, missing=16, samples=True, engine="scipy" + ) + + def test_uneven_coordinate_conversion_raises_when_not_samples(self) -> None: + """Coordinate-unit windows require evenly sampled coordinates.""" + patch = dc.get_example_patch("wacky_dim_coords_patch") + + with pytest.raises(CoordError): + patch.adaptive_spectral_filter( + distance=16, time=16, samples=False, engine="scipy" + ) + + def test_nan_values_remain_supported(self) -> None: + """NaNs may propagate but should not produce infinities.""" + patch = dc.get_example_patch("patch_with_null", shape=(64, 64)) + + out = patch.adaptive_spectral_filter( + distance=16, time=16, samples=True, engine="scipy" + ) + out_data = np.asarray(out.data) + + assert out.shape == patch.shape + assert np.isnan(out_data).any() + assert not np.isinf(out_data).any() + + def test_invalid_engine_raises(self) -> None: + """Engine values should be constrained.""" + patch = _patch((64, 64), ("distance", "time"), dtype=np.float32) + + with pytest.raises(ParameterError, match="engine"): + patch.adaptive_spectral_filter( + distance=16, + time=16, + samples=True, + engine="bad", # type: ignore[arg-type] + ) + + def test_proc_export_is_function(self) -> None: + """The processing module should expose the direct patch function.""" + patch = _patch((64, 64), ("distance", "time"), dtype=np.float32) + + out = adaptive_spectral_filter_func( + patch, distance=16, time=16, samples=True, engine="scipy" + ) + + assert out.shape == patch.shape + + +class TestAdaptiveSpectralCore: + """Tests for plain array adaptive spectral helpers.""" + + def 
test_triangular_taper_values(self) -> None: + """The shared taper should match the overlap-add ramp geometry.""" + taper = _triangular_taper((8, 8), (2, 2)) + expected_1d = np.array([0.25, 0.5, 0.75, 1.0, 1.0, 0.75, 0.5, 0.25]) + + np.testing.assert_allclose(taper, expected_1d[:, None] * expected_1d[None, :]) + assert taper.dtype == np.float32 + + def test_triangular_taper_all_ones_when_plateau_matches_window(self) -> None: + """A full-window plateau should support zero-overlap reconstruction.""" + taper = _triangular_taper((8, 8), (8, 8)) + + np.testing.assert_array_equal(taper, np.ones((8, 8), dtype=np.float32)) + + def test_triangular_taper_one_dimensional(self) -> None: + """The shared taper should also support 1D reconstruction.""" + taper = _triangular_taper((8,), (2,)) + expected = np.array([0.25, 0.5, 0.75, 1.0, 1.0, 0.75, 0.5, 0.25]) + + np.testing.assert_allclose(taper, expected) + assert taper.dtype == np.float32 + + def test_triangular_taper_cache_is_not_mutated_by_callers(self) -> None: + """Callers should receive a copy of the cached taper.""" + taper = _triangular_taper((16, 16), (2, 2)) + expected = taper.copy() + + taper[...] 
= -1.0 + actual = _triangular_taper((16, 16), (2, 2)) + + np.testing.assert_array_equal(actual, expected) + + @pytest.mark.parametrize( + "window_size,plateau,match", + [ + ((16, 16), (17, 2), "Plateau cannot"), + ((16, 16), (-1, 2), "non-negative"), + ((15, 16), (2, 2), "Window sizes must be even"), + ((16, 16), (2,), "same length"), + ((16, 16, 16), (2, 2, 2), "one- and two-dimensional"), + ], + ) + def test_triangular_taper_rejects_invalid_geometry( + self, + window_size: tuple[int, int], + plateau: tuple[int, int], + match: str, + ) -> None: + """Invalid taper geometry should raise.""" + with pytest.raises(ValueError, match=match): + _triangular_taper(window_size, plateau) + + @pytest.mark.parametrize( + "window_size,overlap,match", + [ + ((15, 16), (7, 7), "power of two"), + ((4, 16), (1, 7), "greater than 4"), + ((16, 16), (-1, 7), "non-negative"), + ((16, 16), (8, 7), "too large"), + ((16.0, 16), (7, 7), "must be an integer"), + ((16, 16), (7.0, 7), "must be an integer"), + ((16,), (7, 7), "match the input dimensionality"), + ], + ) + def test_core_rejects_invalid_window_and_overlap( + self, + window_size: tuple[Any, Any], + overlap: tuple[Any, Any], + match: str, + ) -> None: + """Direct array API should validate window geometry.""" + data = np.ones((32, 32), dtype=np.float32) + + with pytest.raises(ValueError, match=match): + _adaptive_spectral_filter_scipy( + data, window_size=window_size, overlap=overlap + ) + + def test_core_rejects_non_1d_or_2d_input(self) -> None: + """Direct array API is 1D or 2D only.""" + data = np.ones((2, 16, 16), dtype=np.float32) + + with pytest.raises(ValueError, match="1D or 2D input"): + _adaptive_spectral_filter_scipy(data, window_size=(16, 16), overlap=(7, 7)) + + def test_core_rejects_non_finite_exponent(self) -> None: + """Exponent must be finite.""" + data = np.ones((32, 32), dtype=np.float32) + + with pytest.raises(ValueError, match="exponent must be finite"): + _adaptive_spectral_filter_scipy( + data, window_size=(16, 
16), overlap=(7, 7), exponent=np.nan + ) + + def test_direct_array_api_returns_float32_for_integer_inputs(self) -> None: + """Non-floating array inputs should return float32 outputs.""" + data = np.ones((32, 32), dtype=np.int16) + + out = _adaptive_spectral_filter_scipy( + data, window_size=(16, 16), overlap=(7, 7) + ) + + assert out.dtype == np.float32 + + def test_direct_array_api_normalizes_power(self) -> None: + """The SciPy path should run power normalization.""" + data = np.ones((32, 32), dtype=np.float32) + + out = _adaptive_spectral_filter_scipy( + data, + window_size=(16, 16), + overlap=(7, 7), + exponent=0.5, + normalize_power=True, + ) + + assert out.shape == data.shape + assert np.isfinite(out).all() + + def test_direct_array_api_filters_one_dimensional_data(self) -> None: + """The SciPy helper should support 1D arrays.""" + data = np.ones(32, dtype=np.float32) + + out = _adaptive_spectral_filter_scipy( + data, + window_size=(16,), + overlap=(7,), + exponent=0.5, + normalize_power=True, + ) + + assert out.shape == data.shape + assert np.isfinite(out).all() + + def test_direct_array_api_supports_zero_overlap(self) -> None: + """The direct SciPy helper should accept non-overlapping windows.""" + data = np.ones((32, 32), dtype=np.float32) + + out = _adaptive_spectral_filter_scipy( + data, window_size=(16, 16), overlap=(0, 0) + ) + + assert out.shape == data.shape + assert np.isfinite(out).all() + + def test_auto_engine_falls_back_when_numba_missing(self, monkeypatch) -> None: + """Auto engine should fall back to SciPy when optional deps are absent.""" + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "dascore.proc._adaptive_spectral_filter_numba": + raise ImportError("simulated missing numba engine") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + assert _get_engine("auto", 2) is _adaptive_spectral_filter_scipy + + def 
test_numba_engine_raises_when_missing(self, monkeypatch) -> None: + """Explicit numba engine should raise when optional deps are absent.""" + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "dascore.proc._adaptive_spectral_filter_numba": + raise ImportError("simulated missing numba engine") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + with pytest.raises(MissingOptionalDependencyError, match="engine='numba'"): + _get_engine("numba", 2) + + def test_get_engine_uses_scipy_for_one_dimensional_auto(self) -> None: + """Auto mode should use SciPy when one dimension is selected.""" + assert _get_engine("auto", 1) is _adaptive_spectral_filter_scipy + + def test_get_engine_rejects_numba_for_one_dimension(self) -> None: + """Numba is intentionally unavailable for the 1D helper path.""" + with pytest.raises(ParameterError, match="two selected dimensions"): + _get_engine("numba", 1) + + def test_private_window_overlap_validator_rejects_negative_overlap(self) -> None: + """The private window validator should guard negative overlaps.""" + with pytest.raises(ParameterError, match="non-negative"): + _validate_window_and_overlap(("distance", "time"), (16, 16), (-1, 7), 0.3) + + @pytest.mark.parametrize("exponent", [0.0, 0.3]) + def test_numba_and_scipy_match_when_numba_available(self, exponent) -> None: + """The optional Numba 2D path should match the SciPy implementation.""" + numba_mod = pytest.importorskip("dascore.proc._adaptive_spectral_filter_numba") + rng = np.random.default_rng(20260511) + data = rng.normal(size=(32, 32)).astype(np.float32) + + numba = numba_mod._adaptive_spectral_filter_numba( + data, + window_size=(16, 16), + overlap=(7, 7), + exponent=exponent, + normalize_power=True, + ) + scipy = _adaptive_spectral_filter_scipy( + data, + window_size=(16, 16), + overlap=(7, 7), + exponent=exponent, + normalize_power=True, + ) + + np.testing.assert_allclose(scipy, 
numba, rtol=1e-5, atol=1e-5) + + def test_auto_engine_uses_numba_when_available(self) -> None: + """Auto engine should use the Numba 2D path when optional deps import.""" + numba_mod = pytest.importorskip("dascore.proc._adaptive_spectral_filter_numba") + + assert _get_engine("auto", 2) is numba_mod._adaptive_spectral_filter_numba + + def test_numba_private_helpers_run_in_python(self) -> None: + """The fast-engine helpers should be directly testable in Python.""" + numba_mod = pytest.importorskip("dascore.proc._adaptive_spectral_filter_numba") + padded = np.arange(16, dtype=np.float32).reshape(4, 4) + tile = np.zeros((2, 2), dtype=np.float32) + + assert numba_mod._tile_indices_from_parity_index(3, 2, 1, 0) == (3, 2) + assert numba_mod._tile_bounds(1, 1, 2, 2, 1, 1, 4, 4) == (1, 1, 2, 2) + numba_mod._copy_padded_tile(padded, tile, 1, 1, 2, 2) + np.testing.assert_array_equal(tile, padded[1:3, 1:3]) + assert numba_mod._complex_power(3 + 4j) == np.float32(5.0) + + spec = np.array([[3 + 4j, 0j]], dtype=np.complex64) + assert numba_mod._max_spectral_power(spec) == np.float32(5.0) + assert numba_mod._max_spectral_power_numba_impl(spec) == np.float32(5.0) + weighted = spec.copy() + numba_mod._apply_spectral_weight(weighted, 1.0, False) + np.testing.assert_allclose(weighted[0, 0], spec[0, 0] * 5.0) + + weighted = spec.copy() + numba_mod._apply_spectral_weight(weighted, 0.3, True) + assert np.isfinite(weighted).all() + + weighted = spec.copy() + numba_mod._apply_spectral_weight_numba_impl(weighted, 0.3, True) + assert np.isfinite(weighted).all() + + weighted = spec.copy() + numba_mod._apply_spectral_weight_numba_impl(weighted, 1.0, False) + np.testing.assert_allclose(weighted[0, 0], spec[0, 0] * 5.0) + + zeros = np.array([[0j]], dtype=np.complex64) + numba_mod._apply_spectral_weight(zeros, 0.3, True) + assert zeros[0, 0] == 0j + + zeros = np.array([[0j]], dtype=np.complex64) + numba_mod._apply_spectral_weight_numba_impl(zeros, 0.3, True) + assert zeros[0, 0] == 0j + + 
filtered = np.zeros_like(padded) + taper = np.ones((2, 2), dtype=np.float32) + numba_mod._overlap_add_tile(filtered, tile, taper, 1, 1, 2, 2) + np.testing.assert_array_equal(filtered[1:3, 1:3], tile) + + def test_numba_private_tile_group_runs_in_python(self) -> None: + """The tile group algorithm should run without JIT for coverage.""" + numba_mod = pytest.importorskip("dascore.proc._adaptive_spectral_filter_numba") + data = np.ones((8, 8), dtype=np.float32) + working, _, stride, taper, padded, filtered, n_tiles = ( + numba_mod._prepare_work_arrays(data, window_size=(8, 8), overlap=(3, 3)) + ) + + numba_mod._process_tile_group_python( + padded, + filtered, + taper, + 8, + 8, + stride[0], + stride[1], + n_tiles[0], + n_tiles[1], + 0, + 0, + 0.0, + False, + ) + numba_mod._process_tile_group_python( + padded, + filtered, + taper, + 8, + 8, + stride[0], + stride[1], + n_tiles[0], + n_tiles[1], + 0, + 0, + 0.5, + True, + ) + numba_mod._process_tile_group_numba_impl( + padded, + filtered, + taper, + 8, + 8, + stride[0], + stride[1], + n_tiles[0], + n_tiles[1], + 0, + 0, + 0.5, + True, + ) + out = numba_mod._finalize_output(filtered, working, data.dtype, stride) + + assert out.shape == data.shape + assert np.isfinite(out).all() diff --git a/tests/test_proc/test_taper.py b/tests/test_proc/test_taper.py index 62ff61527..051335aea 100644 --- a/tests/test_proc/test_taper.py +++ b/tests/test_proc/test_taper.py @@ -245,7 +245,7 @@ def test_poorly_shaped_sequence_raises(self, random_patch): def test_bad_use_of_none(self, random_patch): """Ensure bad use of None raises.""" - with pytest.raises(ParameterError, match="Cannot use ... or None"): + with pytest.raises(ParameterError, match=r"Cannot use \.\.\. or None"): random_patch.taper_range(time=(1, None), relative=True) def test_use_none(self, random_patch):