From f81df7d901bdc0c2a56f3e6cb78ae6bdf613dfc4 Mon Sep 17 00:00:00 2001 From: Derrick Chambers Date: Thu, 4 Jun 2026 12:47:03 +0200 Subject: [PATCH] add option to set data_type in patch_function decorator --- dascore/constants.py | 5 ++- dascore/transform/fbe.py | 10 ++--- dascore/transform/kurtosis.py | 8 +--- dascore/transform/stalta.py | 4 +- dascore/transform/strain.py | 7 ++-- dascore/utils/patch.py | 15 ++++++- tests/test_transform/test_fbe.py | 12 +++--- tests/test_transform/test_kurtosis.py | 4 +- tests/test_transform/test_stalta.py | 6 +-- tests/test_utils/test_patch_utils.py | 58 +++++++++++++++++++++++++++ 10 files changed, 97 insertions(+), 32 deletions(-) diff --git a/dascore/constants.py b/dascore/constants.py index b64339f12..3fe551ad7 100644 --- a/dascore/constants.py +++ b/dascore/constants.py @@ -71,6 +71,9 @@ def map(self, func, iterables, **kwargs): "strain", "temperature", "temperature_gradient", + "frequency_band_energy", + "kurtosis", + "stalta", ) # Valid categories (of instruments) @@ -86,7 +89,7 @@ def map(self, func, iterables, **kwargs): "file_version": 9, "experiment_id": 50, "instrument_id": 50, - "data_type": 20, + "data_type": 21, "data_category": 4, } diff --git a/dascore/transform/fbe.py b/dascore/transform/fbe.py index cb74f2cb8..a96885071 100644 --- a/dascore/transform/fbe.py +++ b/dascore/transform/fbe.py @@ -11,7 +11,7 @@ from dascore.utils.time import to_float -@patch_function() +@patch_function(data_type="frequency_band_energy") def fbe( patch: PatchType, window: float, @@ -87,13 +87,9 @@ def fbe( patch = patch.pass_filter(**kwargs) - fbe = ((patch**2).rolling(**{dim: window, "step": step}).mean() ** 0.5).update( - attrs={"data_type": "Frequency-Band Energy"} - ) + fbe = (patch**2).rolling(**{dim: window, "step": step}).mean() ** 0.5 if db: - fbe = (10 * fbe.log10()).update( - attrs={"data_type": "Frequency-Band Energy", "data_units": "dB"} - ) + fbe = (10 * fbe.log10()).update(attrs={"data_units": "dB"}) return fbe diff --git a/dascore/transform/kurtosis.py b/dascore/transform/kurtosis.py index 4f71ffd39..f80437639 100644 --- a/dascore/transform/kurtosis.py +++ b/dascore/transform/kurtosis.py @@ -110,7 +110,7 @@ def _recursive_kurtosis( return out -@patch_function() +@patch_function(data_type="kurtosis") def kurtosis( patch: PatchType, winlen: float, @@ -227,8 +227,4 @@ def kurtosis( out = out.reshape(orig_shape) - return ( - patch_t.new(data=out) - .transpose(*orig_dims) - .update(attrs={"data_type": "Kurtosis", "data_units": ""}) - ) + return patch_t.new(data=out).transpose(*orig_dims).update(attrs={"data_units": ""}) diff --git a/dascore/transform/stalta.py b/dascore/transform/stalta.py index 631504b97..adb47e2ff 100644 --- a/dascore/transform/stalta.py +++ b/dascore/transform/stalta.py @@ -9,7 +9,7 @@ from dascore.utils.patch import patch_function -@patch_function() +@patch_function(data_type="stalta") def stalta( patch: PatchType, **kwargs, @@ -53,4 +53,4 @@ def stalta( sta_data = patch.rolling(**{dim: sta}).mean() lta_data = patch.rolling(**{dim: lta}).mean() - return (sta_data / lta_data).update(attrs={"data_type": "STALTA", "data_units": ""}) + return (sta_data / lta_data).update(attrs={"data_units": ""}) diff --git a/dascore/transform/strain.py b/dascore/transform/strain.py index 45d090be7..03a01d9e6 100644 --- a/dascore/transform/strain.py +++ b/dascore/transform/strain.py @@ -17,6 +17,7 @@ @patch_function( required_dims=("distance",), required_attrs={"data_type": "velocity"}, + data_type="strain_rate", ) def velocity_to_strain_rate( patch: PatchType, @@ -115,15 +116,14 @@ def velocity_to_strain_rate( patch = differentiate.func( patch, dim="distance", order=order, step=step_multiple // 2 ) - new_attrs = patch.attrs.update( - data_type="strain_rate", gauge_length=step * step_multiple - ) + new_attrs = patch.attrs.update(gauge_length=step * step_multiple) return patch.update(attrs=new_attrs) @patch_function( required_dims=("distance",), required_attrs={"data_type": "velocity"}, + data_type="strain_rate", ) def velocity_to_strain_rate_edgeless( patch: PatchType, @@ -202,7 +202,6 @@ def velocity_to_strain_rate_edgeless( new_data_units = data_units / dist_units new_attrs = patch.attrs.update( - data_type="strain_rate", gauge_length=distance_step * step_multiple, data_units=new_data_units, ) diff --git a/dascore/utils/patch.py b/dascore/utils/patch.py index 6b26c2c59..8df01d661 100644 --- a/dascore/utils/patch.py +++ b/dascore/utils/patch.py @@ -186,6 +186,7 @@ def patch_function( required_attrs: attr_type = None, history: Literal["full", "method_name", None] = "full", validate_call: bool = False, + data_type: str | None = None, ): """ Decorator to mark a function as a patch method. @@ -208,6 +209,8 @@ def patch_function( If True, use pydantic to validate the function call. This can save quite a lot of code in validation checks, but does have some overhead. See [validate_call](https://docs.pydantic.dev/latest/api/validate_call/). + data_type + If not None, set the output patch's data_type attr to this value. Examples -------- @@ -236,6 +239,11 @@ def patch_function( ... option: Literal["min", "max", None] = None, ... ): ... ... + >>> + >>> # 4. A patch method which sets the output data_type. + >>> @dc.patch_function(data_type="strain_rate") + ... def do_strain_rate(patch): + ... ... Notes ----- @@ -263,6 +271,9 @@ def _func(patch, *args, **kwargs): ) check_patch_attrs(patch, required_attrs) out: PatchType = func(patch, *args, **kwargs) + attr_updates = {} + if data_type is not None: + attr_updates["data_type"] = data_type # attach history string. Need to consider something a bit less hacky. if out is not patch and hasattr(out, "attrs"): hist_str = _get_history_str( @@ -271,7 +282,9 @@ def _func(patch, *args, **kwargs): if hist_str: hist = list(out.attrs.history) hist.append(hist_str) - out = out.update_attrs(history=hist) + attr_updates["history"] = hist + if attr_updates and hasattr(out, "attrs"): + out = out.update_attrs(**attr_updates) return out # Attach original function. Although we want to encourage raw_function diff --git a/tests/test_transform/test_fbe.py b/tests/test_transform/test_fbe.py index 2f069e6d8..c422184f0 100644 --- a/tests/test_transform/test_fbe.py +++ b/tests/test_transform/test_fbe.py @@ -17,14 +17,14 @@ def test_runs_time_filter(self, random_patch): out = random_patch.fbe(time=(10, 100), window=0.01, step=0.01, db=False) assert out.dims == random_patch.dims - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" def test_runs_distance_filter(self, random_patch): """Ensure FBE runs along the distance dimension.""" out = random_patch.fbe(distance=(0.01, 0.05), window=5, step=1, db=False) assert out.dims == random_patch.dims - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" def test_db_false_matches_expected_rms(self, random_patch): """Ensure db=False returns filtered rolling RMS.""" @@ -53,13 +53,13 @@ def test_attrs_when_not_db(self, random_patch): """Ensure non-db output metadata are set.""" out = random_patch.fbe(time=(10, 100), window=0.01, step=0.01, db=False) - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" def test_attrs_when_db(self, random_patch): """Ensure db output metadata are set.""" out = random_patch.fbe(time=(10, 100), window=0.01, step=0.01, db=True) - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" assert out.attrs.data_units == ureg.dB def test_step_defaults_to_inverse_sampling_rate(self, random_patch): @@ -77,13 +77,13 @@ def test_open_ended_lowpass_filter(self, random_patch): """Ensure open-ended lowpass filters are accepted.""" out = random_patch.fbe(time=(None, 100), window=0.01, step=0.01, db=False) - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" def test_open_ended_highpass_filter(self, random_patch): """Ensure open-ended highpass filters are accepted.""" out = random_patch.fbe(time=(10, None), window=0.01, step=0.01, db=False) - assert out.attrs.data_type == "Frequency-Band Energy" + assert out.attrs.data_type == "frequency_band_energy" def test_invalid_frequency_range_raises(self, random_patch): """Ensure invalid filter ranges raise.""" diff --git a/tests/test_transform/test_kurtosis.py b/tests/test_transform/test_kurtosis.py index aee52d9a7..0819a6afa 100644 --- a/tests/test_transform/test_kurtosis.py +++ b/tests/test_transform/test_kurtosis.py @@ -90,7 +90,7 @@ def test_windowed_runs(self, random_patch): assert out.dims == random_patch.dims assert out.data.shape == random_patch.data.shape - assert out.attrs.data_type == "Kurtosis" + assert out.attrs.data_type == "kurtosis" assert out.attrs.data_units is None def test_recursive_runs(self, random_patch): @@ -99,7 +99,7 @@ def test_recursive_runs(self, random_patch): assert out.dims == random_patch.dims assert out.data.shape == random_patch.data.shape - assert out.attrs.data_type == "Kurtosis" + assert out.attrs.data_type == "kurtosis" assert out.attrs.data_units is None def test_restores_original_dimension_order(self, random_patch): diff --git a/tests/test_transform/test_stalta.py b/tests/test_transform/test_stalta.py index af4187224..ad53a4082 100644 --- a/tests/test_transform/test_stalta.py +++ b/tests/test_transform/test_stalta.py @@ -15,7 +15,7 @@ def test_runs_time_dimension(self, random_patch): assert out.dims == random_patch.dims assert out.data.shape == random_patch.data.shape - assert out.attrs.data_type == "STALTA" + assert out.attrs.data_type == "stalta" assert out.attrs.data_units is None def test_runs_distance_dimension(self, random_patch): @@ -24,7 +24,7 @@ def test_runs_distance_dimension(self, random_patch): assert out.dims == random_patch.dims assert out.data.shape == random_patch.data.shape - assert out.attrs.data_type == "STALTA" + assert out.attrs.data_type == "stalta" assert out.attrs.data_units is None def test_matches_expected_time_ratio(self, random_patch): @@ -51,7 +51,7 @@ def test_attrs_are_set(self, random_patch): """Ensure output metadata are set.""" out = random_patch.stalta(time=(0.01, 0.05)) - assert out.attrs.data_type == "STALTA" + assert out.attrs.data_type == "stalta" assert out.attrs.data_units is None def test_missing_dimension_kwargs_raise(self, random_patch): diff --git a/tests/test_utils/test_patch_utils.py b/tests/test_utils/test_patch_utils.py index 521d0b5ca..18ff238bc 100644 --- a/tests/test_utils/test_patch_utils.py +++ b/tests/test_utils/test_patch_utils.py @@ -161,6 +161,64 @@ def some_func( with pytest.raises(pydantic.ValidationError): some_func(patch, some_int=10, specific_float=20.0) + def test_data_type(self, random_patch): + """Ensure the decorator can set the output data_type.""" + + @patch_function(data_type="strain_rate") + def some_func(patch): + """A test function for setting the output data_type.""" + return patch.new(data=patch.data + 1) + + out = some_func(random_patch) + + assert out.attrs.data_type == "strain_rate" + + def test_data_type_overwrites_returned_patch_attr(self, random_patch): + """Ensure the decorator data_type takes precedence.""" + + @patch_function(data_type="strain_rate") + def some_func(patch): + """A test function with a conflicting output data_type.""" + return patch.new(data=patch.data + 1, attrs={"data_type": "velocity"}) + + out = some_func(random_patch) + + assert out.attrs.data_type == "strain_rate" + + def test_data_type_none_preserves_existing_behavior(self, random_patch): + """Ensure the default data_type argument leaves attrs unchanged.""" + + @patch_function() + def some_func(patch): + """A test function without decorator-managed data_type.""" + return patch.new(data=patch.data + 1, attrs={"data_type": "velocity"}) + + out = some_func(random_patch) + + assert out.attrs.data_type == "velocity" + + def test_data_type_and_history_use_one_attr_update(self, random_patch, monkeypatch): + """Ensure decorator-managed attrs are updated together.""" + update_count = 0 + original_update_attrs = dc.Patch.update_attrs + + def update_attrs(self, **attrs): + nonlocal update_count + update_count += 1 + return original_update_attrs(self, **attrs) + + @patch_function(data_type="strain_rate") + def some_func(patch): + """A test function for setting data_type and history.""" + return patch.new(data=patch.data + 1) + + monkeypatch.setattr(dc.Patch, "update_attrs", update_attrs) + out = some_func(random_patch) + + assert out.attrs.data_type == "strain_rate" + assert len(out.attrs.history) == len(random_patch.attrs.history) + 1 + assert update_count == 1 + class TestHistory: """Tests for tracking patch processing history."""