diff --git a/dascore/constants.py b/dascore/constants.py index 58c15a3ce..2681398da 100644 --- a/dascore/constants.py +++ b/dascore/constants.py @@ -72,6 +72,18 @@ def map(self, func, iterables, **kwargs): "temperature", "temperature_gradient", "brillouin_spectrum", + "fourier_transform", + "amplitude_spectrum", + "power_spectrum", + "power_spectral_density", + "frequency_band_energy", + "stalta", + "kurtosis", + "envelope", + "correlation", + "tau_p", + "dispersion", + "phase_weighted_stack", ) # Valid categories (of instruments) @@ -87,7 +99,7 @@ def map(self, func, iterables, **kwargs): "file_version": 9, "experiment_id": 50, "instrument_id": 50, - "data_type": 20, + "data_type": 32, "data_category": 4, } @@ -216,13 +228,18 @@ def map(self, func, iterables, **kwargs): DEFAULT_COLORMAPS = { - "frequency-band energy": "Spectral_r", + "frequency_band_energy": "Spectral_r", "stalta": "RdGy_r", "kurtosis": "gnuplot2", - "fourier transform": "magma", - "power spectral density": "turbo", - "power spectrum": "turbo", - "amplitude spectrum": "turbo", + "envelope": "viridis", + "correlation": "RdBu_r", + "tau_p": "magma", + "dispersion": "turbo", + "phase_weighted_stack": "viridis", + "fourier_transform": "magma", + "power_spectral_density": "turbo", + "power_spectrum": "turbo", + "amplitude_spectrum": "turbo", "strain_rate": "RdBu_r", "strain": "seismic", "velocity": "viridis", diff --git a/dascore/proc/aggregate.py b/dascore/proc/aggregate.py index 3b3e99ae0..0b0ad6e6a 100644 --- a/dascore/proc/aggregate.py +++ b/dascore/proc/aggregate.py @@ -260,7 +260,7 @@ def sum( return aggregate.func(patch, dim=dim, method=np.nansum, dim_reduce=dim_reduce) -@patch_function() +@patch_function(data_type="") @compose_docstring(params=AGG_DOC_STR, notes=AGG_NOTES) def any( patch: PatchType, @@ -279,7 +279,7 @@ def any( return aggregate.func(patch, dim=dim, method=np.any, dim_reduce=dim_reduce) -@patch_function() +@patch_function(data_type="") @compose_docstring(params=AGG_DOC_STR, notes=AGG_NOTES) def all( patch: PatchType, diff --git a/dascore/proc/basic.py b/dascore/proc/basic.py index fc20fcb0f..5c7a1fb55 100644 --- a/dascore/proc/basic.py +++ b/dascore/proc/basic.py @@ -311,7 +311,7 @@ def imag(patch: PatchType) -> PatchType: return patch.new(data=np.imag(patch.data)) -@patch_function() +@patch_function(data_type="") def angle(patch: PatchType) -> PatchType: """ Return a new patch with the phase angles from the data array. @@ -325,7 +325,7 @@ def angle(patch: PatchType) -> PatchType: return patch.new(data=np.angle(patch.data)) -@patch_function() +@patch_function(data_type="") def normalize( self: PatchType, dim: str, @@ -387,7 +387,7 @@ def normalize( return self.new(data=new_data) -@patch_function() +@patch_function(data_type="") def standardize( self: PatchType, dim: str, @@ -793,7 +793,7 @@ def flip(patch, *dims, flip_coords=True): return patch.new(data=data, coords=coords) -@patch_function() +@patch_function(data_type="") def full(patch, fill_value): """ Return an identical patch with the data replaced by fill_value. diff --git a/dascore/proc/correlate.py b/dascore/proc/correlate.py index 383ee292c..fb74f97a5 100644 --- a/dascore/proc/correlate.py +++ b/dascore/proc/correlate.py @@ -34,7 +34,7 @@ def _get_source_fft(patch, dim, source, source_axis, samples): return out -@patch_function() +@patch_function(data_type="correlation") def correlate_shift( patch: PatchType, dim: str, undo_weighting: bool = True ) -> PatchType: @@ -86,7 +86,7 @@ def correlate_shift( return out -@patch_function() +@patch_function(data_type="correlation") def correlate( patch: PatchType, samples: bool = False, diff --git a/dascore/transform/differentiate.py b/dascore/transform/differentiate.py index d9e800037..e9954dbe0 100644 --- a/dascore/transform/differentiate.py +++ b/dascore/transform/differentiate.py @@ -70,7 +70,7 @@ def _strided_diff(order, patch, axes, dx_or_spacing, step): return new_data -@patch_function() +@patch_function(data_type="") def differentiate( patch: PatchType, dim: str | Sequence[str] | None, diff --git a/dascore/transform/dispersion.py b/dascore/transform/dispersion.py index d0e27a2da..80e2c1538 100644 --- a/dascore/transform/dispersion.py +++ b/dascore/transform/dispersion.py @@ -12,7 +12,7 @@ from dascore.utils.patch import patch_function -@patch_function(required_dims=("time", "distance")) +@patch_function(required_dims=("time", "distance"), data_type="dispersion") def dispersion_phase_shift( patch: PatchType, phase_velocities: Sequence[float], diff --git a/dascore/transform/fbe.py b/dascore/transform/fbe.py index cb74f2cb8..439c1c6d1 100644 --- a/dascore/transform/fbe.py +++ b/dascore/transform/fbe.py @@ -88,12 +88,12 @@ def fbe( patch = patch.pass_filter(**kwargs) fbe = ((patch**2).rolling(**{dim: window, "step": step}).mean() ** 0.5).update( - attrs={"data_type": "Frequency-Band Energy"} + attrs={"data_type": "frequency_band_energy"} ) if db: fbe = (10 * fbe.log10()).update( - attrs={"data_type": "Frequency-Band Energy", "data_units": "dB"} + attrs={"data_type": "frequency_band_energy", "data_units": "dB"} ) return fbe diff --git a/dascore/transform/fft.py b/dascore/transform/fft.py index f153aeb8c..b3a4381be 100644 --- a/dascore/transform/fft.py +++ b/dascore/transform/fft.py @@ -18,7 +18,7 @@ from dascore.utils.transformatter import FourierTransformatter -@patch_function() +@patch_function(data_type="fourier_transform") @deprecate( info="The Patch transform rfft is deprecated. Use dft instead.", removed_in="0.2.0", diff --git a/dascore/transform/fourier.py b/dascore/transform/fourier.py index a3fc7b3d5..7b1b46feb 100644 --- a/dascore/transform/fourier.py +++ b/dascore/transform/fourier.py @@ -39,9 +39,9 @@ from dascore.utils.transformatter import FourierTransformatter DFT_OUTPUT_DATA_TYPE_MAP = { - "AS": "Amplitude Spectrum", - "PS": "Power Spectrum", - "PSD": "Spectral Density", + "AS": "amplitude_spectrum", + "PS": "power_spectrum", + "PSD": "power_spectral_density", } DFT_OUTPUT_TYPES = ("FFT", *DFT_OUTPUT_DATA_TYPE_MAP) @@ -127,7 +127,7 @@ def _get_dft_attrs(patch, dims, new_coords, pad=False, output="FFT"): new["dims"] = new_coords.dims new["data_units"] = _get_dft_data_units(patch, dims) new["_pre_dft_data_type"] = new.get("data_type") - new["data_type"] = "fourier transform" + new["data_type"] = "fourier_transform" new["_dft_output"] = output new["_dft_padded"] = pad return PatchAttrs(**new) @@ -498,7 +498,7 @@ def _get_stft_dims(dim, dims, axis): return out -@patch_function() +@patch_function(data_type="fourier_transform") def stft( patch: PatchType, taper_window: str | ndarray | tuple[str | Any, ...] = "hann", @@ -606,6 +606,7 @@ def stft( "_stft_fft_mode": fft_mode, "_stft_mfft": window_samples, "_stft_performed": True, + "_pre_stft_data_type": patch.attrs.get("data_type"), "data_units": _get_data_units_from_dims(patch, dim, mul), } attrs = patch.attrs.drop("coords").update(**new_attrs) @@ -701,7 +702,10 @@ def istft(patch) -> PatchType: new_data = data_untrimmed[index] assert new_data.shape == cm.shape # Re-assemble and return new patch. - new_attrs = {i: v for i, v in patch.attrs.items() if not i.startswith("_stft")} + patch_attrs = dict(patch.attrs) + new_attrs = {i: v for i, v in patch_attrs.items() if not i.startswith("_stft")} + if "_pre_stft_data_type" in patch_attrs: + new_attrs["data_type"] = new_attrs.pop("_pre_stft_data_type") dim = patch.dims[time_axis] new_attrs["data_units"] = _get_data_units_from_dims(patch, dim, truediv) attrs = dc.PatchAttrs(**new_attrs).drop("coords") diff --git a/dascore/transform/hilbert.py b/dascore/transform/hilbert.py index c03347f38..aee740512 100644 --- a/dascore/transform/hilbert.py +++ b/dascore/transform/hilbert.py @@ -15,7 +15,7 @@ from dascore.utils.patch import patch_function -@patch_function() +@patch_function(data_type="") def hilbert(patch: PatchType, dim: str) -> PatchType: """ Perform a Hilbert transform on a patch. @@ -57,7 +57,7 @@ def hilbert(patch: PatchType, dim: str) -> PatchType: return patch.new(data=analytic_signal) -@patch_function() +@patch_function(data_type="envelope") def envelope(patch: PatchType, dim: str) -> PatchType: """ Calculate the envelope of a signal using the Hilbert transform. @@ -112,7 +112,7 @@ def __infer_transform_dim(patch, stack_dim): return next(iter(dims)) -@patch_function() +@patch_function(data_type="phase_weighted_stack") @compose_docstring(dim_reduce=DIM_REDUCE_DOCS) def phase_weighted_stack( patch: PatchType, diff --git a/dascore/transform/integrate.py b/dascore/transform/integrate.py index 5c4c12fad..d8e3ba483 100644 --- a/dascore/transform/integrate.py +++ b/dascore/transform/integrate.py @@ -80,7 +80,7 @@ def _get_indefinite_integral(patch, dxs_or_vals, axes): return out, patch.coords # coords shouldn't change -@patch_function() +@patch_function(data_type="") def integrate( patch: PatchType, dim: Sequence[str] | str | None, diff --git a/dascore/transform/kurtosis.py b/dascore/transform/kurtosis.py index 4f71ffd39..454b30e82 100644 --- a/dascore/transform/kurtosis.py +++ b/dascore/transform/kurtosis.py @@ -15,7 +15,7 @@ def _validate_window(winlen: float, step: float) -> int: """Convert window length in seconds to samples and validate.""" if winlen <= 0: raise ValueError("winlen must be positive.") - nwin = int(round(winlen / step)) + nwin = round(winlen / step) if nwin < 2: raise ValueError("winlen is too small for the sampling interval.") return nwin @@ -230,5 +230,5 @@ def kurtosis( return ( patch_t.new(data=out) .transpose(*orig_dims) - .update(attrs={"data_type": "Kurtosis", "data_units": ""}) + .update(attrs={"data_type": "kurtosis", "data_units": ""}) ) diff --git a/dascore/transform/spectro.py b/dascore/transform/spectro.py index 767715eef..d082349c3 100644 --- a/dascore/transform/spectro.py +++ b/dascore/transform/spectro.py @@ -58,7 +58,7 @@ def _get_new_dims(patch, dim, new_coord_name): return tuple([*dims, dim]) -@patch_function() +@patch_function(data_type="fourier_transform") @deprecate( info="Use Patch.stft() instead.", since="0.1.11", diff --git a/dascore/transform/stalta.py b/dascore/transform/stalta.py index 631504b97..6f66906d1 100644 --- a/dascore/transform/stalta.py +++ b/dascore/transform/stalta.py @@ -53,4 +53,4 @@ def stalta( sta_data = patch.rolling(**{dim: sta}).mean() lta_data = patch.rolling(**{dim: lta}).mean() - return (sta_data / lta_data).update(attrs={"data_type": "STALTA", "data_units": ""}) + return (sta_data / lta_data).update(attrs={"data_type": "stalta", "data_units": ""}) diff --git a/dascore/transform/taup.py b/dascore/transform/taup.py index 2d065e7e1..12a7d1360 100644 --- a/dascore/transform/taup.py +++ b/dascore/transform/taup.py @@ -75,7 +75,7 @@ def _jit_taup_general(data, distance, dt, p_vals): return two_sided_p_vals, taup -@patch_function(required_dims=("time", "distance")) +@patch_function(required_dims=("time", "distance"), data_type="tau_p") def tau_p( patch: PatchType, velocities: NDArray[np.floating], diff --git a/dascore/utils/patch.py b/dascore/utils/patch.py index 6b26c2c59..5a5ae6d2c 100644 --- a/dascore/utils/patch.py +++ b/dascore/utils/patch.py @@ -186,6 +186,7 @@ def patch_function( required_attrs: attr_type = None, history: Literal["full", "method_name", None] = "full", validate_call: bool = False, + data_type: str | None = None, ): """ Decorator to mark a function as a patch method. @@ -208,6 +209,10 @@ def patch_function( If True, use pydantic to validate the function call. This can save quite a lot of code in validation checks, but does have some overhead. See [validate_call](https://docs.pydantic.dev/latest/api/validate_call/). + data_type + Controls the output patch's ``data_type`` attr. If None, leave the + returned patch's ``data_type`` unchanged. Otherwise, set to specified + value. Use an empty string ("") to clear. Examples -------- @@ -236,6 +241,16 @@ def patch_function( ... option: Literal["min", "max", None] = None, ... ): ... ... + >>> + >>> # 4. A patch method which sets the output data_type. + >>> @dc.patch_function(data_type="strain_rate") + ... def do_strain_rate(patch): + ... ... + >>> + >>> # 5. A patch method which clears the output data_type. + >>> @dc.patch_function(data_type="") + ... def do_unknown_quantity(patch): + ... ... Notes ----- @@ -263,6 +278,9 @@ def _func(patch, *args, **kwargs): ) check_patch_attrs(patch, required_attrs) out: PatchType = func(patch, *args, **kwargs) + attr_updates = {} + if data_type is not None: + attr_updates["data_type"] = data_type # attach history string. Need to consider something a bit less hacky. if out is not patch and hasattr(out, "attrs"): hist_str = _get_history_str( @@ -271,7 +289,9 @@ def _func(patch, *args, **kwargs): if hist_str: hist = list(out.attrs.history) hist.append(hist_str) - out = out.update_attrs(history=hist) + attr_updates["history"] = hist + if attr_updates and hasattr(out, "attrs"): + out = out.update_attrs(**attr_updates) return out # Attach original function. Although we want to encourage raw_function diff --git a/docs/notes/notes.qmd b/docs/notes/notes.qmd index bcf97e9da..735e78cda 100644 --- a/docs/notes/notes.qmd +++ b/docs/notes/notes.qmd @@ -3,3 +3,8 @@ title: Notes --- This section of the documentation provides understanding-oriented explanation for DASCore implementation and design decisions. + +- [PatchAttrs](patch_attrs.qmd) +- [Documentation Strategy](doc_strategy.qmd) +- [Fourier Transforms](dft_notes.qmd) +- [Velocity to Strain Rate](velocity_to_strain_rate.qmd) diff --git a/docs/notes/patch_attrs.qmd b/docs/notes/patch_attrs.qmd new file mode 100644 index 000000000..579b08542 --- /dev/null +++ b/docs/notes/patch_attrs.qmd @@ -0,0 +1,35 @@ +--- +title: PatchAttrs +--- + +[`PatchAttrs`](`dascore.core.attrs.PatchAttrs`) stores metadata about a [`Patch`](`dascore.Patch`). Some attributes describe identity or provenance, some summarize coordinates, and some describe the data array itself. + +This note explains how DASCore interprets a few important attributes, and its internal policies around these attributes. + +## `data_type` + +`data_type` is an optional label for the kind of data contained in a patch. It is useful for display defaults, plotting choices, grouping, and quick inspection, but it is not the canonical source of physical meaning, rather the data and coordinate units, as well as the patch history serve this purpose. + +:::{.callout-note} +A stale or misleading `data_type` is much worse than an empty one. +::: + +| Situation | `data_type` behavior | +|---|---| +| Output is still the same measured quantity, just filtered/resampled/selected/reordered | Preserve existing `data_type`. | +| Output is a known derived product with a stable meaning | Set a specific snake_case `data_type`. | +| Output changes physical meaning but no stable label is appropriate | Clear `data_type` to `""`. | + +DASCore-assigned `data_type` values should be snake_case and listed in `VALID_DATA_TYPES` in [`dascore.constants`](`dascore.constants`). Correctness-critical code should prefer units, coordinates, and explicit validation. + +### Patch functions + +[`patch_function`](`dascore.utils.patch.patch_function`) can manage output `data_type` for patch methods. + +| Decorator value | Behavior | +|---|---| +| `data_type=None` | Preserve the returned patch's `data_type`. This is the default for backward compatibility. | +| `data_type=""` | Clear the returned patch's `data_type`. | +| `data_type="some_value"` | Set the returned patch's `data_type` to that value. | + +Functions may still require a specific input label with `required_attrs`, for example `required_attrs={"data_type": "velocity"}`. This should only be used when the function's assumptions truly depend on that label and are documented. diff --git a/docs/tutorial/patch.qmd b/docs/tutorial/patch.qmd index bd77d0133..c0be96a37 100644 --- a/docs/tutorial/patch.qmd +++ b/docs/tutorial/patch.qmd @@ -202,6 +202,8 @@ Markdown(df_str) Specific data formats may also add attributes (e.g. "gauge_length", "pulse_width"), but this depends on the parser. +The `data_type` attribute is an optional label for the kind of data in the patch. It is useful for display defaults and quick inspection, but physical interpretation should come from `data_units`, coordinate units, and `history`. See the [PatchAttrs note](../notes/patch_attrs.qmd) for more detail. + ## String representation DASCore Patches have a useful string representation: diff --git a/scripts/_templates/_quarto.yml b/scripts/_templates/_quarto.yml index ae2824a21..783d080d5 100644 --- a/scripts/_templates/_quarto.yml +++ b/scripts/_templates/_quarto.yml @@ -202,6 +202,9 @@ website: - text: Documentation Strategy href: notes/doc_strategy.qmd + - text: PatchAttrs + href: notes/patch_attrs.qmd + - text: Fourier Transforms href: notes/dft_notes.qmd diff --git a/tests/test_core/test_attrs.py b/tests/test_core/test_attrs.py index 1fbaca817..554917d74 100644 --- a/tests/test_core/test_attrs.py +++ b/tests/test_core/test_attrs.py @@ -8,6 +8,7 @@ from pydantic import ValidationError import dascore as dc +from dascore.constants import VALID_DATA_TYPES, max_lens from dascore.core.attrs import ( PatchAttrs, ) @@ -149,6 +150,13 @@ def test_supports_extra_attrs(self): assert out.bob == "doesnt" assert out.bill_min == 12 + def test_valid_data_types_fit_max_length(self): + """Ensure supported data_type values fit the declared attr length.""" + max_len = max_lens["data_type"] + + for data_type in VALID_DATA_TYPES: + assert len(data_type) <= max_len + def test_flat_dump(self, more_coords_attrs): """Ensure flat dump flattens out the coords.""" out = more_coords_attrs.flat_dump() diff --git a/tests/test_transform/test_fbe.py b/tests/test_transform/test_fbe.py index 2f069e6d8..c422184f0 100644 --- a/tests/test_transform/test_fbe.py +++ b/tests/test_transform/test_fbe.py @@ -17,14 +17,14 @@ def test_runs_time_filter(self, random_patch): out = random_patch.fbe(time=(10, 100), window=0.01, step=0.01, db=False) assert out.dims == random_patch.dims - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" def test_runs_distance_filter(self, random_patch): """Ensure FBE runs along the distance dimension.""" out = random_patch.fbe(distance=(0.01, 0.05), window=5, step=1, db=False) assert out.dims == random_patch.dims - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" def test_db_false_matches_expected_rms(self, random_patch): """Ensure db=False returns filtered rolling RMS.""" @@ -53,13 +53,13 @@ def test_attrs_when_not_db(self, random_patch): """Ensure non-db output metadata are set.""" out = random_patch.fbe(time=(10, 100), window=0.01, step=0.01, db=False) - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" def test_attrs_when_db(self, random_patch): """Ensure db output metadata are set.""" out = random_patch.fbe(time=(10, 100), window=0.01, step=0.01, db=True) - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" assert out.attrs.data_units == ureg.dB def test_step_defaults_to_inverse_sampling_rate(self, random_patch): @@ -77,13 +77,13 @@ def test_open_ended_lowpass_filter(self, random_patch): """Ensure open-ended lowpass filters are accepted.""" out = random_patch.fbe(time=(None, 100), window=0.01, step=0.01, db=False) - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" def test_open_ended_highpass_filter(self, random_patch): """Ensure open-ended highpass filters are accepted.""" out = random_patch.fbe(time=(10, None), window=0.01, step=0.01, db=False) - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" def test_invalid_frequency_range_raises(self, random_patch): """Ensure invalid filter ranges raise.""" diff --git a/tests/test_transform/test_fourier.py b/tests/test_transform/test_fourier.py index 1e1bf4938..356536896 100644 --- a/tests/test_transform/test_fourier.py +++ b/tests/test_transform/test_fourier.py @@ -223,7 +223,7 @@ def test_transform_single_dim( def test_datatype_changed(self, fft_sin_patch_time, sin_patch): """Ensure the data_type attr is changed after transform.""" assert sin_patch.attrs.data_type == "strain_rate" - assert fft_sin_patch_time.attrs.data_type == "fourier transform" + assert fft_sin_patch_time.attrs.data_type == "fourier_transform" def test_dft_output_attr_set(self, fft_sin_patch_time): """Ensure the DFT output type is tracked.""" @@ -248,10 +248,10 @@ def test_display(self, fft_sin_patch_time): @pytest.mark.parametrize( ("output", "data_type"), [ - ("FFT", "fourier transform"), - ("AS", "Amplitude Spectrum"), - ("PS", "Power Spectrum"), - ("PSD", "Spectral Density"), + ("FFT", "fourier_transform"), + ("AS", "amplitude_spectrum"), + ("PS", "power_spectrum"), + ("PSD", "power_spectral_density"), ], ) def test_output_spectral_representations(self, sin_patch, output, data_type): @@ -484,7 +484,7 @@ class TestSTFT: def test_numeric_window_with_timedelta_coord(self): """ - stft with a numeric window length should work when the time + Stft with a numeric window length should work when the time coordinate is timedelta64 (not just datetime64); see #604. """ patch = dc.get_example_patch() diff --git a/tests/test_transform/test_kurtosis.py b/tests/test_transform/test_kurtosis.py index aee52d9a7..0819a6afa 100644 --- a/tests/test_transform/test_kurtosis.py +++ b/tests/test_transform/test_kurtosis.py @@ -90,7 +90,7 @@ def test_windowed_runs(self, random_patch): assert out.dims == random_patch.dims assert out.data.shape == random_patch.data.shape - assert out.attrs.data_type == "Kurtosis" + assert out.attrs.data_type == "kurtosis" assert out.attrs.data_units is None def test_recursive_runs(self, random_patch): @@ -99,7 +99,7 @@ def test_recursive_runs(self, random_patch): assert out.dims == random_patch.dims assert out.data.shape == random_patch.data.shape - assert out.attrs.data_type == "Kurtosis" + assert out.attrs.data_type == "kurtosis" assert out.attrs.data_units is None def test_restores_original_dimension_order(self, random_patch): diff --git a/tests/test_transform/test_stalta.py b/tests/test_transform/test_stalta.py index af4187224..ad53a4082 100644 --- a/tests/test_transform/test_stalta.py +++ b/tests/test_transform/test_stalta.py @@ -15,7 +15,7 @@ def test_runs_time_dimension(self, random_patch): assert out.dims == random_patch.dims assert out.data.shape == random_patch.data.shape - assert out.attrs.data_type == "STALTA" + assert out.attrs.data_type == "stalta" assert out.attrs.data_units is None def test_runs_distance_dimension(self, random_patch): @@ -24,7 +24,7 @@ def test_runs_distance_dimension(self, random_patch): assert out.dims == random_patch.dims assert out.data.shape == random_patch.data.shape - assert out.attrs.data_type == "STALTA" + assert out.attrs.data_type == "stalta" assert out.attrs.data_units is None def test_matches_expected_time_ratio(self, random_patch): @@ -51,7 +51,7 @@ def test_attrs_are_set(self, random_patch): """Ensure output metadata are set.""" out = random_patch.stalta(time=(0.01, 0.05)) - assert out.attrs.data_type == "STALTA" + assert out.attrs.data_type == "stalta" assert out.attrs.data_units is None def test_missing_dimension_kwargs_raise(self, random_patch): diff --git a/tests/test_utils/test_patch_utils.py b/tests/test_utils/test_patch_utils.py index 521d0b5ca..e516de8d6 100644 --- a/tests/test_utils/test_patch_utils.py +++ b/tests/test_utils/test_patch_utils.py @@ -161,6 +161,89 @@ def some_func( with pytest.raises(pydantic.ValidationError): some_func(patch, some_int=10, specific_float=20.0) + def test_data_type(self, random_patch): + """Ensure the decorator can set the output data_type.""" + + @patch_function(data_type="strain_rate") + def some_func(patch): + """A test function for setting the output data_type.""" + return patch.new(data=patch.data + 1) + + out = some_func(random_patch) + + assert out.attrs.data_type == "strain_rate" + + def test_data_type_overwrites_returned_patch_attr(self, random_patch): + """Ensure the decorator data_type takes precedence.""" + + @patch_function(data_type="strain_rate") + def some_func(patch): + """A test function with a conflicting output data_type.""" + return patch.new(data=patch.data + 1, attrs={"data_type": "velocity"}) + + out = some_func(random_patch) + + assert out.attrs.data_type == "strain_rate" + + def test_data_type_none_preserves_existing_behavior(self, random_patch): + """Ensure the default data_type argument leaves attrs unchanged.""" + + @patch_function() + def some_func(patch): + """A test function without decorator-managed data_type.""" + return patch.new(data=patch.data + 1, attrs={"data_type": "velocity"}) + + out = some_func(random_patch) + + assert out.attrs.data_type == "velocity" + + def test_data_type_none_preserves_inherited_data_type(self, random_patch): + """Ensure the default data_type argument preserves inherited attrs.""" + + @patch_function() + def some_func(patch): + """A test function without decorator-managed data_type.""" + return patch.new(data=patch.data + 1) + + patch = random_patch.update_attrs(data_type="velocity") + out = some_func(patch) + + assert out.attrs.data_type == "velocity" + + def test_empty_data_type_clears_returned_patch_attr(self, random_patch): + """Ensure the decorator can clear the output data_type.""" + + @patch_function(data_type="") + def some_func(patch): + """A test function for clearing data_type.""" + return patch.new(data=patch.data + 1, attrs={"data_type": "velocity"}) + + out = some_func(random_patch) + + assert out.attrs.data_type == "" + + def test_data_type_and_history_use_one_attr_update(self, random_patch, monkeypatch): + """Ensure decorator-managed attrs are updated together.""" + update_count = 0 + original_update_attrs = dc.Patch.update_attrs + + def update_attrs(self, **attrs): + nonlocal update_count + update_count += 1 + return original_update_attrs(self, **attrs) + + @patch_function(data_type="strain_rate") + def some_func(patch): + """A test function for setting data_type and history.""" + return patch.new(data=patch.data + 1) + + monkeypatch.setattr(dc.Patch, "update_attrs", update_attrs) + out = some_func(random_patch) + + assert out.attrs.data_type == "strain_rate" + assert len(out.attrs.history) == len(random_patch.attrs.history) + 1 + assert update_count == 1 + class TestHistory: """Tests for tracking patch processing history."""