Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dascore/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def map(self, func, iterables, **kwargs):
"strain",
"temperature",
"temperature_gradient",
"frequency_band_energy",
"kurtosis",
"stalta",
)

# Valid categories (of instruments)
Expand All @@ -86,7 +89,7 @@ def map(self, func, iterables, **kwargs):
"file_version": 9,
"experiment_id": 50,
"instrument_id": 50,
"data_type": 20,
"data_type": 21,
"data_category": 4,
}

Expand Down
10 changes: 3 additions & 7 deletions dascore/transform/fbe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dascore.utils.time import to_float


@patch_function()
@patch_function(data_type="frequency_band_energy")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Restore FBE default colormap after renaming data_type

For FBE outputs, this new frequency_band_energy value no longer matches the existing DEFAULT_COLORMAPS key ("frequency-band energy"), while waterfall(cmap="default") lowercases the patch data_type and looks it up directly before falling back to "bwr". As a result, any FBE patch plotted with the default waterfall colormap now silently loses its prior Spectral_r coloring; please update the colormap key alongside the data_type rename.

Useful? React with 👍 / 👎.

def fbe(
patch: PatchType,
window: float,
Expand Down Expand Up @@ -87,13 +87,9 @@ def fbe(

patch = patch.pass_filter(**kwargs)

fbe = ((patch**2).rolling(**{dim: window, "step": step}).mean() ** 0.5).update(
attrs={"data_type": "Frequency-Band Energy"}
)
fbe = (patch**2).rolling(**{dim: window, "step": step}).mean() ** 0.5

if db:
fbe = (10 * fbe.log10()).update(
attrs={"data_type": "Frequency-Band Energy", "data_units": "dB"}
)
fbe = (10 * fbe.log10()).update(attrs={"data_units": "dB"})

return fbe
8 changes: 2 additions & 6 deletions dascore/transform/kurtosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _recursive_kurtosis(
return out


@patch_function()
@patch_function(data_type="kurtosis")
def kurtosis(
patch: PatchType,
winlen: float,
Expand Down Expand Up @@ -227,8 +227,4 @@ def kurtosis(

out = out.reshape(orig_shape)

return (
patch_t.new(data=out)
.transpose(*orig_dims)
.update(attrs={"data_type": "Kurtosis", "data_units": ""})
)
return patch_t.new(data=out).transpose(*orig_dims).update(attrs={"data_units": ""})
4 changes: 2 additions & 2 deletions dascore/transform/stalta.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dascore.utils.patch import patch_function


@patch_function()
@patch_function(data_type="stalta")
def stalta(
patch: PatchType,
**kwargs,
Expand Down Expand Up @@ -53,4 +53,4 @@ def stalta(
sta_data = patch.rolling(**{dim: sta}).mean()
lta_data = patch.rolling(**{dim: lta}).mean()

return (sta_data / lta_data).update(attrs={"data_type": "STALTA", "data_units": ""})
return (sta_data / lta_data).update(attrs={"data_units": ""})
7 changes: 3 additions & 4 deletions dascore/transform/strain.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
@patch_function(
required_dims=("distance",),
required_attrs={"data_type": "velocity"},
data_type="strain_rate",
)
def velocity_to_strain_rate(
patch: PatchType,
Expand Down Expand Up @@ -115,15 +116,14 @@ def velocity_to_strain_rate(
patch = differentiate.func(
patch, dim="distance", order=order, step=step_multiple // 2
)
new_attrs = patch.attrs.update(
data_type="strain_rate", gauge_length=step * step_multiple
)
new_attrs = patch.attrs.update(gauge_length=step * step_multiple)
return patch.update(attrs=new_attrs)


@patch_function(
required_dims=("distance",),
required_attrs={"data_type": "velocity"},
data_type="strain_rate",
)
def velocity_to_strain_rate_edgeless(
patch: PatchType,
Expand Down Expand Up @@ -202,7 +202,6 @@ def velocity_to_strain_rate_edgeless(
new_data_units = data_units / dist_units

new_attrs = patch.attrs.update(
data_type="strain_rate",
gauge_length=distance_step * step_multiple,
data_units=new_data_units,
)
Expand Down
15 changes: 14 additions & 1 deletion dascore/utils/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def patch_function(
required_attrs: attr_type = None,
history: Literal["full", "method_name", None] = "full",
validate_call: bool = False,
data_type: str | None = None,
):
"""
Decorator to mark a function as a patch method.
Expand All @@ -208,6 +209,8 @@ def patch_function(
If True, use pydantic to validate the function call. This can save
quite a lot of code in validation checks, but does have some overhead.
See [validate_call](https://docs.pydantic.dev/latest/api/validate_call/).
data_type
If not None, set the output patch's data_type attr to this value.

Examples
--------
Expand Down Expand Up @@ -236,6 +239,11 @@ def patch_function(
... option: Literal["min", "max", None] = None,
... ):
... ...
>>>
>>> # 4. A patch method which sets the output data_type.
>>> @dc.patch_function(data_type="strain_rate")
... def do_strain_rate(patch):
... ...

Notes
-----
Expand Down Expand Up @@ -263,6 +271,9 @@ def _func(patch, *args, **kwargs):
)
check_patch_attrs(patch, required_attrs)
out: PatchType = func(patch, *args, **kwargs)
attr_updates = {}
if data_type is not None:
attr_updates["data_type"] = data_type
# attach history string. Need to consider something a bit less hacky.
if out is not patch and hasattr(out, "attrs"):
hist_str = _get_history_str(
Expand All @@ -271,7 +282,9 @@ def _func(patch, *args, **kwargs):
if hist_str:
hist = list(out.attrs.history)
hist.append(hist_str)
out = out.update_attrs(history=hist)
attr_updates["history"] = hist
if attr_updates and hasattr(out, "attrs"):
out = out.update_attrs(**attr_updates)
return out

# Attach original function. Although we want to encourage raw_function
Expand Down
12 changes: 6 additions & 6 deletions tests/test_transform/test_fbe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def test_runs_time_filter(self, random_patch):
out = random_patch.fbe(time=(10, 100), window=0.01, step=0.01, db=False)

assert out.dims == random_patch.dims
assert out.attrs.data_type == "Frequency-Band Energy"
assert out.attrs.data_type == "frequency_band_energy"

def test_runs_distance_filter(self, random_patch):
"""Ensure FBE runs along the distance dimension."""
out = random_patch.fbe(distance=(0.01, 0.05), window=5, step=1, db=False)

assert out.dims == random_patch.dims
assert out.attrs.data_type == "Frequency-Band Energy"
assert out.attrs.data_type == "frequency_band_energy"

def test_db_false_matches_expected_rms(self, random_patch):
"""Ensure db=False returns filtered rolling RMS."""
Expand Down Expand Up @@ -53,13 +53,13 @@ def test_attrs_when_not_db(self, random_patch):
"""Ensure non-db output metadata are set."""
out = random_patch.fbe(time=(10, 100), window=0.01, step=0.01, db=False)

assert out.attrs.data_type == "Frequency-Band Energy"
assert out.attrs.data_type == "frequency_band_energy"

def test_attrs_when_db(self, random_patch):
"""Ensure db output metadata are set."""
out = random_patch.fbe(time=(10, 100), window=0.01, step=0.01, db=True)

assert out.attrs.data_type == "Frequency-Band Energy"
assert out.attrs.data_type == "frequency_band_energy"
assert out.attrs.data_units == ureg.dB

def test_step_defaults_to_inverse_sampling_rate(self, random_patch):
Expand All @@ -77,13 +77,13 @@ def test_open_ended_lowpass_filter(self, random_patch):
"""Ensure open-ended lowpass filters are accepted."""
out = random_patch.fbe(time=(None, 100), window=0.01, step=0.01, db=False)

assert out.attrs.data_type == "Frequency-Band Energy"
assert out.attrs.data_type == "frequency_band_energy"

def test_open_ended_highpass_filter(self, random_patch):
"""Ensure open-ended highpass filters are accepted."""
out = random_patch.fbe(time=(10, None), window=0.01, step=0.01, db=False)

assert out.attrs.data_type == "Frequency-Band Energy"
assert out.attrs.data_type == "frequency_band_energy"

def test_invalid_frequency_range_raises(self, random_patch):
"""Ensure invalid filter ranges raise."""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_transform/test_kurtosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_windowed_runs(self, random_patch):

assert out.dims == random_patch.dims
assert out.data.shape == random_patch.data.shape
assert out.attrs.data_type == "Kurtosis"
assert out.attrs.data_type == "kurtosis"
assert out.attrs.data_units is None

def test_recursive_runs(self, random_patch):
Expand All @@ -99,7 +99,7 @@ def test_recursive_runs(self, random_patch):

assert out.dims == random_patch.dims
assert out.data.shape == random_patch.data.shape
assert out.attrs.data_type == "Kurtosis"
assert out.attrs.data_type == "kurtosis"
assert out.attrs.data_units is None

def test_restores_original_dimension_order(self, random_patch):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_transform/test_stalta.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_runs_time_dimension(self, random_patch):

assert out.dims == random_patch.dims
assert out.data.shape == random_patch.data.shape
assert out.attrs.data_type == "STALTA"
assert out.attrs.data_type == "stalta"
assert out.attrs.data_units is None

def test_runs_distance_dimension(self, random_patch):
Expand All @@ -24,7 +24,7 @@ def test_runs_distance_dimension(self, random_patch):

assert out.dims == random_patch.dims
assert out.data.shape == random_patch.data.shape
assert out.attrs.data_type == "STALTA"
assert out.attrs.data_type == "stalta"
assert out.attrs.data_units is None

def test_matches_expected_time_ratio(self, random_patch):
Expand All @@ -51,7 +51,7 @@ def test_attrs_are_set(self, random_patch):
"""Ensure output metadata are set."""
out = random_patch.stalta(time=(0.01, 0.05))

assert out.attrs.data_type == "STALTA"
assert out.attrs.data_type == "stalta"
assert out.attrs.data_units is None

def test_missing_dimension_kwargs_raise(self, random_patch):
Expand Down
58 changes: 58 additions & 0 deletions tests/test_utils/test_patch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,64 @@ def some_func(
with pytest.raises(pydantic.ValidationError):
some_func(patch, some_int=10, specific_float=20.0)

def test_data_type(self, random_patch):
"""Ensure the decorator can set the output data_type."""

@patch_function(data_type="strain_rate")
def some_func(patch):
"""A test function for setting the output data_type."""
return patch.new(data=patch.data + 1)

out = some_func(random_patch)

assert out.attrs.data_type == "strain_rate"

def test_data_type_overwrites_returned_patch_attr(self, random_patch):
"""Ensure the decorator data_type takes precedence."""

@patch_function(data_type="strain_rate")
def some_func(patch):
"""A test function with a conflicting output data_type."""
return patch.new(data=patch.data + 1, attrs={"data_type": "velocity"})

out = some_func(random_patch)

assert out.attrs.data_type == "strain_rate"

def test_data_type_none_preserves_existing_behavior(self, random_patch):
"""Ensure the default data_type argument leaves attrs unchanged."""

@patch_function()
def some_func(patch):
"""A test function without decorator-managed data_type."""
return patch.new(data=patch.data + 1, attrs={"data_type": "velocity"})

out = some_func(random_patch)

assert out.attrs.data_type == "velocity"

def test_data_type_and_history_use_one_attr_update(self, random_patch, monkeypatch):
"""Ensure decorator-managed attrs are updated together."""
update_count = 0
original_update_attrs = dc.Patch.update_attrs

def update_attrs(self, **attrs):
nonlocal update_count
update_count += 1
return original_update_attrs(self, **attrs)

@patch_function(data_type="strain_rate")
def some_func(patch):
"""A test function for setting data_type and history."""
return patch.new(data=patch.data + 1)

monkeypatch.setattr(dc.Patch, "update_attrs", update_attrs)
out = some_func(random_patch)

assert out.attrs.data_type == "strain_rate"
assert len(out.attrs.history) == len(random_patch.attrs.history) + 1
assert update_count == 1


class TestHistory:
"""Tests for tracking patch processing history."""
Expand Down
Loading