diff --git a/src/aimbat/config.py b/src/aimbat/config.py index 52bc89b..97d68f2 100644 --- a/src/aimbat/config.py +++ b/src/aimbat/config.py @@ -1,14 +1,16 @@ """Global configuration options for the AIMBAT application.""" -from pydantic import Field +from aimbat.lib._validators import EventParametersValidatorMixin +from pydantic import Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from pathlib import Path from datetime import timedelta from pysmo.tools.iccs._defaults import ICCS_DEFAULTS +from typing import Self import numpy as np -class Settings(BaseSettings): +class Settings(EventParametersValidatorMixin, BaseSettings): model_config = SettingsConfigDict(env_prefix="aimbat_", env_file=".env") project: Path = Field( @@ -18,10 +20,10 @@ class Settings(BaseSettings): """AIMBAT project file location.""" db_url: str = Field( - default_factory=lambda data: r"sqlite+pysqlite:///" + str(data["project"]), - description="AIMBAT database url (default value is derived from `project`.)", + default="", + description="AIMBAT database url (default value is derived from `project`).", ) - """AIMBAT database url.""" + """AIMBAT database url (default is derived from `project`).""" logfile: Path = Field(default=Path("aimbat.log"), description="Log file location.") """Log file location.""" @@ -58,6 +60,26 @@ class Settings(BaseSettings): ) """Initial minimum cross correlation coefficient.""" + bandpass_apply: bool = Field( + default=ICCS_DEFAULTS.bandpass_apply, + description="Whether to apply bandpass filter to seismograms.", + ) + """Whether to apply bandpass filter to seismograms.""" + + bandpass_fmin: float = Field( + default=ICCS_DEFAULTS.bandpass_fmin, + ge=0, + description="Minimum frequency for bandpass filter (ignored if `bandpass_apply` is False).", + ) + """Minimum frequency for bandpass filter (ignored if `bandpass_apply` is False).""" + + bandpass_fmax: float = Field( + default=ICCS_DEFAULTS.bandpass_fmax, + gt=0, + description="Maximum frequency for bandpass filter (ignored if `bandpass_apply` is False).", + ) + """Maximum frequency for bandpass filter (ignored if `bandpass_apply` is False).""" + sac_pick_header: str = Field( default="t0", description="SAC header field where initial pick is stored." ) @@ -80,6 +102,14 @@ class Settings(BaseSettings): ) """Minimum length of truncated UUID string.""" + @model_validator(mode="after") + def set_computed_defaults(self) -> Self: + """Sets defaults that depend on other fields.""" + # 1. Handle db_url dependency on project + if self.db_url == "": + self.db_url = f"sqlite+pysqlite:///{self.project}" + return self + settings = Settings() diff --git a/src/aimbat/lib/_validators.py b/src/aimbat/lib/_validators.py new file mode 100644 index 0000000..2ffd102 --- /dev/null +++ b/src/aimbat/lib/_validators.py @@ -0,0 +1,13 @@ +from typing import Self +from pydantic import BaseModel, model_validator + + +class EventParametersValidatorMixin(BaseModel): + bandpass_fmin: float + bandpass_fmax: float + + @model_validator(mode="after") + def check_freq_range(self) -> Self: + if self.bandpass_fmax <= self.bandpass_fmin: + raise ValueError("bandpass_fmax must be > bandpass_fmin") + return self diff --git a/src/aimbat/lib/iccs.py b/src/aimbat/lib/iccs.py index 97734a8..5042355 100644 --- a/src/aimbat/lib/iccs.py +++ b/src/aimbat/lib/iccs.py @@ -32,6 +32,9 @@ def create_iccs_instance(session: Session) -> ICCS: seismograms=active_event.seismograms, window_pre=active_event.parameters.window_pre, window_post=active_event.parameters.window_post, + bandpass_apply=active_event.parameters.bandpass_apply, + bandpass_fmin=active_event.parameters.bandpass_fmin, + bandpass_fmax=active_event.parameters.bandpass_fmax, min_ccnorm=active_event.parameters.min_ccnorm, context_width=settings.context_width, ) diff --git a/src/aimbat/lib/models.py b/src/aimbat/lib/models.py index 14569b2..cdafdfc 100644 --- a/src/aimbat/lib/models.py +++ b/src/aimbat/lib/models.py @@ -5,6 +5,7 @@ """ from aimbat.config import settings +from aimbat.lib._validators import EventParametersValidatorMixin from datetime import datetime, timedelta, timezone from sqlmodel import Relationship, SQLModel, Field from sqlalchemy.types import DateTime, TypeDecorator @@ -94,17 +95,30 @@ class AimbatEventParametersBase(SQLModel): completed: bool = False "Mark an event as completed." - min_ccnorm: float = Field(ge=0.0, le=1.0, default=settings.min_ccnorm) + min_ccnorm: float = Field( + ge=0.0, le=1.0, default_factory=lambda: settings.min_ccnorm + ) "Minimum cross-correlation used when automatically de-selecting seismograms." - window_pre: timedelta = Field(lt=0, default=settings.window_pre) + window_pre: timedelta = Field(lt=0, default_factory=lambda: settings.window_pre) "Pre-pick window length." - window_post: timedelta = Field(gt=0, default=settings.window_post) + window_post: timedelta = Field(gt=0, default_factory=lambda: settings.window_post) "Post-pick window length." + bandpass_apply: bool = Field(default_factory=lambda: settings.bandpass_apply) + "Whether to apply bandpass filter to seismograms." + + bandpass_fmin: float = Field(default_factory=lambda: settings.bandpass_fmin, ge=0) + "Minimum frequency for bandpass filter (ignored if `bandpass_apply` is False)." + + bandpass_fmax: float = Field(default_factory=lambda: settings.bandpass_fmax, gt=0) + "Maximum frequency for bandpass filter (ignored if `bandpass_apply` is False)." + -class AimbatEventParameters(AimbatEventParametersBase, table=True): +class AimbatEventParameters( + AimbatEventParametersBase, EventParametersValidatorMixin, table=True +): """Processing parameters common to all seismograms of a particular event.""" id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) diff --git a/src/aimbat/lib/typing.py b/src/aimbat/lib/typing.py index fb434ed..e3eed23 100644 --- a/src/aimbat/lib/typing.py +++ b/src/aimbat/lib/typing.py @@ -1,6 +1,6 @@ """Custom types used in AIMBAT.""" -from typing import Literal, TypeAlias +from typing import Literal from enum import StrEnum, auto @@ -15,15 +15,24 @@ class EventParameter(StrEnum): MIN_CCNORM = auto() WINDOW_PRE = auto() WINDOW_POST = auto() + BANDPASS_APPLY = auto() + BANDPASS_FMIN = auto() + BANDPASS_FMAX = auto() -EventParameterBool: TypeAlias = Literal[EventParameter.COMPLETED] +type EventParameterBool = Literal[ + EventParameter.COMPLETED, EventParameter.BANDPASS_APPLY +] "[`TypeAlias`][typing.TypeAlias] for [`AimbatEvent`][aimbat.lib.models.AimbatEvent] attributes with [`bool`][bool] values." -EventParameterFloat: TypeAlias = Literal[EventParameter.MIN_CCNORM] +type EventParameterFloat = Literal[ + EventParameter.MIN_CCNORM, + EventParameter.BANDPASS_FMIN, + EventParameter.BANDPASS_FMAX, +] "[`TypeAlias`][typing.TypeAlias] for [`AimbatEvent`][aimbat.lib.models.AimbatEvent] attributes with [`float`][float] values." -EventParameterTimedelta: TypeAlias = Literal[ +type EventParameterTimedelta = Literal[ EventParameter.WINDOW_PRE, EventParameter.WINDOW_POST ] "[`TypeAlias`][typing.TypeAlias] for [`AimbatEvent`][aimbat.lib.models.AimbatEvent] attributes with [`timedelta`][datetime.timedelta] values." @@ -41,7 +50,7 @@ class SeismogramParameter(StrEnum): T1 = auto() -SeismogramParameterBool: TypeAlias = Literal[ +type SeismogramParameterBool = Literal[ SeismogramParameter.SELECT, SeismogramParameter.FLIP ] -SeismogramParameterDatetime: TypeAlias = Literal[SeismogramParameter.T1] +type SeismogramParameterDatetime = Literal[SeismogramParameter.T1] diff --git a/tests/test_typing.py b/tests/test_typing.py index f576221..194740e 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -1,6 +1,6 @@ from sqlmodel import SQLModel from enum import StrEnum -from typing import TypeAlias, get_args +from typing import get_args, TypeAliasType from aimbat.lib.models import AimbatEventParametersBase, AimbatSeismogramParametersBase from aimbat.lib.typing import ( EventParameter, @@ -13,21 +13,23 @@ ) -def set_from_basemodel(obj: SQLModel) -> set[str]: +def set_from_basemodel(obj: type[SQLModel]) -> set[str]: """Returns a set from the basemodel fields and remove "id" from it.""" - my_set = set(obj.model_fields) + my_set: set[str] = set(obj.model_fields) my_set.discard("id") + return my_set -def set_from_strenum(enum: StrEnum) -> set[str]: - return set([member.value for member in enum]) # type: ignore +def set_from_strenum(enum: type[StrEnum]) -> set[str]: + + return set([member.value for member in enum]) -def set_from_typealiases(*aliases: list[TypeAlias]) -> set[str]: +def set_from_typealiases(*aliases: TypeAliasType) -> set[str]: my_list = [] for alias in aliases: - my_list.extend([v for v in get_args(alias)]) + my_list.extend([v for v in get_args(alias.__value__)]) return set(my_list) @@ -36,20 +38,20 @@ class TestLibTypes: """Ensure Default models and types are consistent.""" def test_event_parameter_types(self) -> None: - assert set_from_basemodel(AimbatEventParametersBase) == set_from_strenum( # type: ignore - EventParameter # type: ignore + assert set_from_basemodel(AimbatEventParametersBase) == set_from_strenum( + EventParameter ) - assert set_from_strenum(EventParameter) == set_from_typealiases( # type: ignore - EventParameterBool, # type: ignore - EventParameterFloat, # type: ignore - EventParameterTimedelta, # type: ignore + assert set_from_strenum(EventParameter) == set_from_typealiases( + EventParameterBool, + EventParameterFloat, + EventParameterTimedelta, ) def test_seismogram_parameter_types(self) -> None: - assert set_from_basemodel(AimbatSeismogramParametersBase) == set_from_strenum( # type: ignore - SeismogramParameter # type: ignore + assert set_from_basemodel(AimbatSeismogramParametersBase) == set_from_strenum( + SeismogramParameter ) - assert set_from_strenum(SeismogramParameter) == set_from_typealiases( # type: ignore - SeismogramParameterBool, # type: ignore - SeismogramParameterDatetime, # type: ignore + assert set_from_strenum(SeismogramParameter) == set_from_typealiases( + SeismogramParameterBool, + SeismogramParameterDatetime, ) diff --git a/uv.lock b/uv.lock index ce183b9..3491c36 100644 --- a/uv.lock +++ b/uv.lock @@ -357,7 +357,7 @@ wheels = [ [[package]] name = "cyclopts" -version = "4.5.2" +version = "4.5.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, @@ -365,9 +365,9 @@ dependencies = [ { name = "rich" }, { name = "rich-rst" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/50/cd/1fd03921a95113182e6fdf84af5d47f07aa91c00c03ac074c192b0d4672c/cyclopts-4.5.2.tar.gz", hash = "sha256:7fe01b2d184c55c4555e06a0397602b319d87faa5b086b41913eaeaea52fae16", size = 162381, upload-time = "2026-02-11T16:30:46.051Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a5/16/06e35c217334930ff7c476ce1c8e74ed786fa3ef6742e59a1458e2412290/cyclopts-4.5.3.tar.gz", hash = "sha256:35fa70971204c450d9668646a6ca372eb5fa3070fbe8dd51c5b4b31e65198f2d", size = 162437, upload-time = "2026-02-16T15:07:11.96Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2b/03/f906829bcfcbb945f19d6a64240ffb66a31d69ca5533e95882f0efc9c13c/cyclopts-4.5.2-py3-none-any.whl", hash = "sha256:ee56ee23c2c81abc34b66b5aa8fd2698ca699740054e84e534449ec3eb7f944d", size = 200165, upload-time = "2026-02-11T16:30:46.942Z" }, + { url = "https://files.pythonhosted.org/packages/3a/1f/d8bce383a90d8a6a11033327777afa4d4d611ec11869284adb6f48152906/cyclopts-4.5.3-py3-none-any.whl", hash = "sha256:50af3085bb15d4a6f2582dd383dad5e4ba6a0d4d4c64ee63326d881a752a6919", size = 200231, upload-time = "2026-02-16T15:07:13.045Z" }, ] [[package]] @@ -399,11 +399,11 @@ wheels = [ [[package]] name = "filelock" -version = "3.24.0" +version = "3.24.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/00/cd/fa3ab025a8f9772e8a9146d8fd8eef6d62649274d231ca84249f54a0de4a/filelock-3.24.0.tar.gz", hash = "sha256:aeeab479339ddf463a1cdd1f15a6e6894db976071e5883efc94d22ed5139044b", size = 37166, upload-time = "2026-02-14T16:05:28.723Z" } +sdist = { url = "https://files.pythonhosted.org/packages/02/a8/dae62680be63cbb3ff87cfa2f51cf766269514ea5488479d42fec5aa6f3a/filelock-3.24.2.tar.gz", hash = "sha256:c22803117490f156e59fafce621f0550a7a853e2bbf4f87f112b11d469b6c81b", size = 37601, upload-time = "2026-02-16T02:50:45.614Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/dd/d7e7f4f49180e8591c9e1281d15ecf8e7f25eb2c829771d9682f1f9fe0c8/filelock-3.24.0-py3-none-any.whl", hash = "sha256:eebebb403d78363ef7be8e236b63cc6760b0004c7464dceaba3fd0afbd637ced", size = 23977, upload-time = "2026-02-14T16:05:27.578Z" }, + { url = "https://files.pythonhosted.org/packages/e7/04/a94ebfb4eaaa08db56725a40de2887e95de4e8641b9e902c311bfa00aa39/filelock-3.24.2-py3-none-any.whl", hash = "sha256:667d7dc0b7d1e1064dd5f8f8e80bdac157a6482e8d2e02cd16fd3b6b33bd6556", size = 24152, upload-time = "2026-02-16T02:50:44Z" }, ] [[package]] @@ -1220,11 +1220,11 @@ wheels = [ [[package]] name = "platformdirs" -version = "4.9.1" +version = "4.9.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6c/d5/763666321efaded11112de8b7a7f2273dd8d1e205168e73c334e54b0ab9a/platformdirs-4.9.1.tar.gz", hash = "sha256:f310f16e89c4e29117805d8328f7c10876eeff36c94eac879532812110f7d39f", size = 28392, upload-time = "2026-02-14T21:02:44.973Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/04/fea538adf7dbbd6d186f551d595961e564a3b6715bdf276b477460858672/platformdirs-4.9.2.tar.gz", hash = "sha256:9a33809944b9db043ad67ca0db94b14bf452cc6aeaac46a88ea55b26e2e9d291", size = 28394, upload-time = "2026-02-16T03:56:10.574Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/70/77/e8c95e95f1d4cdd88c90a96e31980df7e709e51059fac150046ad67fac63/platformdirs-4.9.1-py3-none-any.whl", hash = "sha256:61d8b967d34791c162d30d60737369cbbd77debad5b981c4bfda1842e71e0d66", size = 21307, upload-time = "2026-02-14T21:02:43.492Z" }, + { url = "https://files.pythonhosted.org/packages/48/31/05e764397056194206169869b50cf2fee4dbbbc71b344705b9c0d878d4d8/platformdirs-4.9.2-py3-none-any.whl", hash = "sha256:9170634f126f8efdae22fb58ae8a0eaa86f38365bc57897a6c4f781d1f5875bd", size = 21168, upload-time = "2026-02-16T03:56:08.891Z" }, ] [[package]] @@ -1485,8 +1485,8 @@ wheels = [ [[package]] name = "pysmo" -version = "1.0.0.dev1" -source = { git = "https://github.com/pysmo/pysmo?rev=master#d4742c7ef7d936bc355f70eecdf5587959aff571" } +version = "1.0.0.dev1+g0e21071c3" +source = { git = "https://github.com/pysmo/pysmo?rev=master#0e21071c386ac58561ee3460ac4e5fcaa8ba43c5" } dependencies = [ { name = "attrs" }, { name = "attrs-strict" }, @@ -1903,15 +1903,15 @@ wheels = [ [[package]] name = "sqlmodel" -version = "0.0.33" +version = "0.0.34" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "sqlalchemy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c4/62/22c287122598e61d07d005eec0b4eb97e6bde9a1b051bcd66c2bca846ea8/sqlmodel-0.0.33.tar.gz", hash = "sha256:b473544ed5fc2097894d89033049e569e1f138363dd3ec2ed4b6932cc9f29f5f", size = 95578, upload-time = "2026-02-11T15:23:39.504Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3b/6a/b1b26d589063e53a08c10a2d7bc624cba63dec045a312758d68f550a4ea1/sqlmodel-0.0.34.tar.gz", hash = "sha256:577e4aae1ba96ee5038e03d8b1404c642dad1a92e628988cdf4ce68d27abe982", size = 96236, upload-time = "2026-02-16T19:06:34.275Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/63/39/13891bae4658133b489a4d8b6a2f193d56110e392289560f312748e796dc/sqlmodel-0.0.33-py3-none-any.whl", hash = "sha256:9045bb4d97d2ba099c5a068ee9525af2d106972dda1ff8488e187ce50556bf73", size = 27444, upload-time = "2026-02-11T15:23:38.678Z" }, + { url = "https://files.pythonhosted.org/packages/eb/ee/1910f4eee41af4268b0d8cd688a05fb8ea23e9e6c64b8710592df24a8c66/sqlmodel-0.0.34-py3-none-any.whl", hash = "sha256:aeabc8f0de32076a0ed9216e88568459d737fca1e7133bfc6d1c657920789a2d", size = 27445, upload-time = "2026-02-16T19:06:35.709Z" }, ] [[package]]