Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions src/aimbat/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Global configuration options for the AIMBAT application."""

from pydantic import Field
from aimbat.lib._validators import EventParametersValidatorMixin
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pathlib import Path
from datetime import timedelta
from pysmo.tools.iccs._defaults import ICCS_DEFAULTS
from typing import Self
import numpy as np


class Settings(BaseSettings):
class Settings(EventParametersValidatorMixin, BaseSettings):
model_config = SettingsConfigDict(env_prefix="aimbat_", env_file=".env")

project: Path = Field(
Expand All @@ -18,10 +20,10 @@ class Settings(BaseSettings):
"""AIMBAT project file location."""

db_url: str = Field(
default_factory=lambda data: r"sqlite+pysqlite:///" + str(data["project"]),
description="AIMBAT database url (default value is derived from `project`.)",
default="",
description="AIMBAT database url (default value is derived from `project`).",
)
"""AIMBAT database url."""
"""AIMBAT database url (default is derived from `project`)."""

logfile: Path = Field(default=Path("aimbat.log"), description="Log file location.")
"""Log file location."""
Expand Down Expand Up @@ -58,6 +60,26 @@ class Settings(BaseSettings):
)
"""Initial minimum cross correlation coefficient."""

bandpass_apply: bool = Field(
default=ICCS_DEFAULTS.bandpass_apply,
description="Whether to apply bandpass filter to seismograms.",
)
"""Whether to apply bandpass filter to seismograms."""

bandpass_fmin: float = Field(
default=ICCS_DEFAULTS.bandpass_fmin,
ge=0,
description="Minimum frequency for bandpass filter (ignored if `bandpass_apply` is False).",
)
"""Minimum frequency for bandpass filter (ignored if `bandpass_apply` is False)."""

bandpass_fmax: float = Field(
default=ICCS_DEFAULTS.bandpass_fmax,
gt=0,
description="Maximum frequency for bandpass filter (ignored if `bandpass_apply` is False).",
)
"""Maximum frequency for bandpass filter (ignored if `bandpass_apply` is False)."""

sac_pick_header: str = Field(
default="t0", description="SAC header field where initial pick is stored."
)
Expand All @@ -80,6 +102,14 @@ class Settings(BaseSettings):
)
"""Minimum length of truncated UUID string."""

@model_validator(mode="after")
def set_computed_defaults(self) -> Self:
"""Sets defaults that depend on other fields."""
# 1. Handle db_url dependency on project
if self.db_url == "":
self.db_url = f"sqlite+pysqlite:///{self.project}"
return self


settings = Settings()

Expand Down
13 changes: 13 additions & 0 deletions src/aimbat/lib/_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Self
from pydantic import BaseModel, model_validator


class EventParametersValidatorMixin(BaseModel):
bandpass_fmin: float
bandpass_fmax: float

@model_validator(mode="after")
def check_freq_range(self) -> Self:
if self.bandpass_fmax <= self.bandpass_fmin:
raise ValueError("bandpass_fmax must be > bandpass_fmin")
Comment thread
smlloyd marked this conversation as resolved.
return self
3 changes: 3 additions & 0 deletions src/aimbat/lib/iccs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def create_iccs_instance(session: Session) -> ICCS:
seismograms=active_event.seismograms,
window_pre=active_event.parameters.window_pre,
window_post=active_event.parameters.window_post,
bandpass_apply=active_event.parameters.bandpass_apply,
bandpass_fmin=active_event.parameters.bandpass_fmin,
bandpass_fmax=active_event.parameters.bandpass_fmax,
min_ccnorm=active_event.parameters.min_ccnorm,
context_width=settings.context_width,
)
Expand Down
22 changes: 18 additions & 4 deletions src/aimbat/lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from aimbat.config import settings
from aimbat.lib._validators import EventParametersValidatorMixin
from datetime import datetime, timedelta, timezone
from sqlmodel import Relationship, SQLModel, Field
from sqlalchemy.types import DateTime, TypeDecorator
Expand Down Expand Up @@ -94,17 +95,30 @@ class AimbatEventParametersBase(SQLModel):
completed: bool = False
"Mark an event as completed."

min_ccnorm: float = Field(ge=0.0, le=1.0, default=settings.min_ccnorm)
min_ccnorm: float = Field(
ge=0.0, le=1.0, default_factory=lambda: settings.min_ccnorm
)
"Minimum cross-correlation used when automatically de-selecting seismograms."

window_pre: timedelta = Field(lt=0, default=settings.window_pre)
window_pre: timedelta = Field(lt=0, default_factory=lambda: settings.window_pre)
"Pre-pick window length."

window_post: timedelta = Field(gt=0, default=settings.window_post)
window_post: timedelta = Field(gt=0, default_factory=lambda: settings.window_post)
"Post-pick window length."

bandpass_apply: bool = Field(default_factory=lambda: settings.bandpass_apply)
"Whether to apply bandpass filter to seismograms."

bandpass_fmin: float = Field(default_factory=lambda: settings.bandpass_fmin, ge=0)
"Minimum frequency for bandpass filter (ignored if `bandpass_apply` is False)."

bandpass_fmax: float = Field(default_factory=lambda: settings.bandpass_fmax, gt=0)
"Maximum frequency for bandpass filter (ignored if `bandpass_apply` is False)."


class AimbatEventParameters(AimbatEventParametersBase, table=True):
class AimbatEventParameters(
AimbatEventParametersBase, EventParametersValidatorMixin, table=True
):
"""Processing parameters common to all seismograms of a particular event."""

id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
Expand Down
21 changes: 15 additions & 6 deletions src/aimbat/lib/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Custom types used in AIMBAT."""

from typing import Literal, TypeAlias
from typing import Literal
from enum import StrEnum, auto


Expand All @@ -15,15 +15,24 @@ class EventParameter(StrEnum):
MIN_CCNORM = auto()
WINDOW_PRE = auto()
WINDOW_POST = auto()
BANDPASS_APPLY = auto()
BANDPASS_FMIN = auto()
BANDPASS_FMAX = auto()


EventParameterBool: TypeAlias = Literal[EventParameter.COMPLETED]
type EventParameterBool = Literal[
EventParameter.COMPLETED, EventParameter.BANDPASS_APPLY
]
"[`TypeAlias`][typing.TypeAlias] for [`AimbatEvent`][aimbat.lib.models.AimbatEvent] attributes with [`bool`][bool] values."

EventParameterFloat: TypeAlias = Literal[EventParameter.MIN_CCNORM]
type EventParameterFloat = Literal[
EventParameter.MIN_CCNORM,
EventParameter.BANDPASS_FMIN,
EventParameter.BANDPASS_FMAX,
]
"[`TypeAlias`][typing.TypeAlias] for [`AimbatEvent`][aimbat.lib.models.AimbatEvent] attributes with [`float`][float] values."

EventParameterTimedelta: TypeAlias = Literal[
type EventParameterTimedelta = Literal[
EventParameter.WINDOW_PRE, EventParameter.WINDOW_POST
]
"[`TypeAlias`][typing.TypeAlias] for [`AimbatEvent`][aimbat.lib.models.AimbatEvent] attributes with [`timedelta`][datetime.timedelta] values."
Expand All @@ -41,7 +50,7 @@ class SeismogramParameter(StrEnum):
T1 = auto()


SeismogramParameterBool: TypeAlias = Literal[
type SeismogramParameterBool = Literal[
SeismogramParameter.SELECT, SeismogramParameter.FLIP
]
SeismogramParameterDatetime: TypeAlias = Literal[SeismogramParameter.T1]
type SeismogramParameterDatetime = Literal[SeismogramParameter.T1]
38 changes: 20 additions & 18 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlmodel import SQLModel
from enum import StrEnum
from typing import TypeAlias, get_args
from typing import get_args, TypeAliasType
from aimbat.lib.models import AimbatEventParametersBase, AimbatSeismogramParametersBase
from aimbat.lib.typing import (
EventParameter,
Expand All @@ -13,21 +13,23 @@
)


def set_from_basemodel(obj: SQLModel) -> set[str]:
def set_from_basemodel(obj: type[SQLModel]) -> set[str]:
"""Returns a set from the basemodel fields and remove "id" from it."""
my_set = set(obj.model_fields)
my_set: set[str] = set(obj.model_fields)
my_set.discard("id")

return my_set


def set_from_strenum(enum: StrEnum) -> set[str]:
return set([member.value for member in enum]) # type: ignore
def set_from_strenum(enum: type[StrEnum]) -> set[str]:

return set([member.value for member in enum])


def set_from_typealiases(*aliases: list[TypeAlias]) -> set[str]:
def set_from_typealiases(*aliases: TypeAliasType) -> set[str]:
my_list = []
for alias in aliases:
my_list.extend([v for v in get_args(alias)])
my_list.extend([v for v in get_args(alias.__value__)])

return set(my_list)

Expand All @@ -36,20 +38,20 @@ class TestLibTypes:
"""Ensure Default models and types are consistent."""

def test_event_parameter_types(self) -> None:
assert set_from_basemodel(AimbatEventParametersBase) == set_from_strenum( # type: ignore
EventParameter # type: ignore
assert set_from_basemodel(AimbatEventParametersBase) == set_from_strenum(
EventParameter
)
assert set_from_strenum(EventParameter) == set_from_typealiases( # type: ignore
EventParameterBool, # type: ignore
EventParameterFloat, # type: ignore
EventParameterTimedelta, # type: ignore
assert set_from_strenum(EventParameter) == set_from_typealiases(
EventParameterBool,
EventParameterFloat,
EventParameterTimedelta,
)

def test_seismogram_parameter_types(self) -> None:
assert set_from_basemodel(AimbatSeismogramParametersBase) == set_from_strenum( # type: ignore
SeismogramParameter # type: ignore
assert set_from_basemodel(AimbatSeismogramParametersBase) == set_from_strenum(
SeismogramParameter
)
assert set_from_strenum(SeismogramParameter) == set_from_typealiases( # type: ignore
SeismogramParameterBool, # type: ignore
SeismogramParameterDatetime, # type: ignore
assert set_from_strenum(SeismogramParameter) == set_from_typealiases(
SeismogramParameterBool,
SeismogramParameterDatetime,
)
28 changes: 14 additions & 14 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading