diff --git a/.github/workflows/pre-commit-workflow.yaml b/.github/workflows/pre-commit-workflow.yaml index a484436..07d904d 100644 --- a/.github/workflows/pre-commit-workflow.yaml +++ b/.github/workflows/pre-commit-workflow.yaml @@ -12,18 +12,17 @@ on: jobs: pre_commit: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: - python-version: "3.9.0" + python-version: "3.12.0" - name: Install dependencies run: | - apt-get get update && apt-get install cmake - make install_precommit + pip install pre-commit - name: Pre-commit tests run: | - make run_precommit + pre-commit run -a diff --git a/.github/workflows/tests-workflow.yaml b/.github/workflows/tests-workflow.yaml index ca38326..ed3e3c2 100644 --- a/.github/workflows/tests-workflow.yaml +++ b/.github/workflows/tests-workflow.yaml @@ -1,24 +1,39 @@ name: Tests - on: pull_request: - branches: - - main + branches: [main] push: - branches: - - main - + branches: [main] jobs: tests: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 + env: + QT_QPA_PLATFORM: offscreen + LIBGL_ALWAYS_INDIRECT: "1" steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Install package - run: | - make docker_build - - name: Tests - run: | - make docker_tests \ No newline at end of file + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: System deps for PyQt5 (xcb/offscreen) + run: | + sudo apt-get update + sudo apt-get install -y \ + cmake \ + libgl1 libglib2.0-0 libsm6 libxext6 libxrender1 \ + libx11-6 libxcb1 libxkbcommon0 libxkbcommon-x11-0 \ + libfontconfig1 libfreetype6 + + - name: Install project + run: | + python -m pip install --upgrade pip + pip install -e . + + - name: Tests + run: make run_tests diff --git a/.github/workflows/tomls.yaml b/.github/workflows/tomls.yaml index ff1b372..2a3efa5 100644 --- a/.github/workflows/tomls.yaml +++ b/.github/workflows/tomls.yaml @@ -12,14 +12,14 @@ on: jobs: tests: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: - python-version: "3.8.0" + python-version: "3.12.0" - name: Install lib run: | pip install tomli diff --git a/.gitignore b/.gitignore index 03e120d..6018d04 100644 --- a/.gitignore +++ b/.gitignore @@ -143,3 +143,4 @@ dmypy.json fb.onnx local_dev +/debug/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7379794..ef4e9d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,10 @@ default_language_version: - python: python3.9 + python: python3.12 files: \.py$|tests/.*\.py$ -default_stages: -- commit repos: # general hooks to verify or beautify code - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v6.0.0 hooks: - id: check-added-large-files args: [--maxkb=5000] @@ -23,7 +21,7 @@ repos: # autoformat code with black formatter - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 25.9.0 hooks: - id: black files: first_breaks|tests @@ -32,7 +30,7 @@ repos: # beautify and sort imports - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 6.1.0 hooks: - id: isort files: first_breaks|tests @@ -41,7 +39,7 @@ repos: # check code style - repo: https://github.com/pycqa/flake8 - rev: 3.8.4 + rev: 7.3.0 hooks: - id: flake8 files: first_breaks|tests @@ -50,7 +48,7 @@ repos: # static type checking - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.5.1 + rev: v1.18.2 hooks: - id: mypy files: first_breaks|tests diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 7b31b1e..0000000 --- a/Dockerfile +++ /dev/null @@ -1,27 +0,0 @@ -FROM ubuntu:20.04 - -ENV DEBIAN_FRONTEND=noninteractive -ENV LIBGL_ALWAYS_INDIRECT=1 -ENV QT_QPA_PLATFORM=offscreen - -RUN apt-get -y update -RUN apt-get install -y git wget cmake -RUN apt-get install -y git libsm6 libxext6 libfontconfig1 libxrender1 libgl1-mesa-glx libglib2.0-0 libgtk2.0 qt5-qmake - -RUN apt install -y python3.8 python3.8-distutils python3-pip -RUN ln -s /usr/bin/python3.8 /usr/bin/python \ - && ln -sf /usr/bin/python3.8 /usr/bin/python3 - -ENV LD_LIBRARY_PATH "/usr/local/lib:$LD_LIBRARY_PATH" -ENV LC_ALL=C.UTF-8 -ENV LANG=C.UTF-8 - -RUN pip install --upgrade pip - -WORKDIR /first-breaks-picking -COPY pyproject.toml . -RUN pip install .[dependencies] - -COPY . /first-breaks-picking -RUN pip install -e . - diff --git a/Makefile b/Makefile index c983ec0..87a9dd3 100644 --- a/Makefile +++ b/Makefile @@ -1,28 +1,11 @@ -IMAGE_NAME ?= first-breaks-picking:latest - - -.PHONY: install_precommit -install_precommit: - python -m pip install --upgrade pre-commit==3.5.0 - .PHONY: run_precommit run_precommit: - pre-commit install && pre-commit run -a + pre-commit run -a .PHONY: run_tests run_tests: pytest -sv --disable-warnings tests - -.PHONY: docker_build -docker_build: - DOCKER_BUILDKIT=1 docker build -t $(IMAGE_NAME) . - -.PHONY: docker_tests -docker_tests: docker_build - docker run -t $(IMAGE_NAME) make run_tests - - .PHONY: build_wheel build_wheel: python -m pip install --upgrade pip diff --git a/README.md b/README.md index bcc9442..b54d53b 100644 --- a/README.md +++ b/README.md @@ -476,7 +476,7 @@ picks_ms = np.random.uniform(low=0, size=sgy.ntr) picks = Picks(values=picks_ms, unit="ms", dt_mcs=sgy.dt_mcs, color=(0, 100, 100)) export_image(sgy, image_filename, - picks=picks) + picks_list=[picks]) ``` [code-block-end]:plot-sgy-custom-picks diff --git a/first_breaks/_pytorch/picker_torch.py b/first_breaks/_pytorch/picker_torch.py index 504b5d8..c12993b 100644 --- a/first_breaks/_pytorch/picker_torch.py +++ b/first_breaks/_pytorch/picker_torch.py @@ -80,7 +80,7 @@ def change_settings( # type: ignore device: Optional[str] = None, segmentation_hw: Optional[Tuple[int, int]] = None, num_workers: Optional[int] = None, - batch_size: Optional[int] = None + batch_size: Optional[int] = None, ) -> None: if args: raise ValueError("Use named arguments instead of positional") diff --git a/first_breaks/const.py b/first_breaks/const.py index 6966ea6..cc507f4 100644 --- a/first_breaks/const.py +++ b/first_breaks/const.py @@ -28,6 +28,10 @@ def get_cache_folder() -> Path: MODEL_ONNX_PATH = CACHE_FOLDER / "fb.onnx" MODEL_ONNX_URL = "https://oml.daloroserver.com/download/seis/fb.onnx" MODEL_ONNX_HASH = "7e39e017b01325180e36885eccaeb17a" +MODEL_ONNX_HASHES = [ + MODEL_ONNX_HASH, + "afc03594f49b88ea61b5cf6ba8245be4", # model with heatmap +] TIMEOUT = 60 diff --git a/first_breaks/data_models/independent.py b/first_breaks/data_models/independent.py index f164a14..163652e 100644 --- a/first_breaks/data_models/independent.py +++ b/first_breaks/data_models/independent.py @@ -5,7 +5,7 @@ import numpy as np from pydantic import UUID4, BaseModel, Field, field_validator -TColor = Union[Tuple[int, int, int, int], Tuple[int, int, int]] +TColor = Union[Tuple[int, int, int], Tuple[int, int, int, int], Tuple[int, ...]] TNormalize = Union[Literal["trace", "gather"], float, int, np.ndarray, None] diff --git a/first_breaks/desktop/combobox_with_mapping.py b/first_breaks/desktop/combobox_with_mapping.py index 542da90..81cf200 100644 --- a/first_breaks/desktop/combobox_with_mapping.py +++ b/first_breaks/desktop/combobox_with_mapping.py @@ -30,7 +30,7 @@ def __init__( current_label: Optional[str] = None, current_value: Optional[str] = None, *args: Any, - **kwargs: Any + **kwargs: Any, ): super().__init__(*args, **kwargs) diff --git a/first_breaks/desktop/graph.py b/first_breaks/desktop/graph.py index 7944a65..d55d9a7 100644 --- a/first_breaks/desktop/graph.py +++ b/first_breaks/desktop/graph.py @@ -2,7 +2,7 @@ import os import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import pyqtgraph as pg @@ -423,7 +423,7 @@ def export( fill_black: Optional[str] = DEFAULTS.fill_black, time_window: Optional[Tuple[float, float]] = None, traces_window: Optional[Tuple[float, float]] = None, - picks: Optional[Picks] = None, + picks_list: Optional[Sequence[Picks]] = None, task: Optional[Task] = None, show_processing_region: bool = True, contour_color: TColor = DEFAULTS.region_contour_color, @@ -443,9 +443,6 @@ def export( if args: raise need_kwargs_exception - if picks is not None and task is not None: - raise ValueError("'picks' and 'task' are mutually exclusive. Use only one of them or none") - if width is None: if traces_window is None: num_traces = sgy.num_traces @@ -467,10 +464,12 @@ def export( refresh_view=True, ) - if task: + if task and task.picks is not None: self.plot_picks(task.picks) - elif picks is not None: - self.plot_picks(picks) + if picks_list: + assert all(len(picks) == sgy.num_traces for picks in picks_list) + for picks in picks_list: + self.plot_picks(picks) if task is not None and show_processing_region: self.plot_processing_region( @@ -537,7 +536,7 @@ def export_image( fill_black: Optional[str] = DEFAULTS.fill_black, time_window: Optional[Tuple[float, float]] = None, traces_window: Optional[Tuple[float, float]] = None, - picks: Optional[Picks] = None, + picks_list: Optional[Sequence[Picks]] = None, show_processing_region: bool = True, contour_color: TColor = DEFAULTS.region_contour_color, poly_color: TColor = DEFAULTS.region_poly_color, @@ -586,7 +585,7 @@ def export_image( fill_black=fill_black, time_window=time_window, traces_window=traces_window, - picks=picks, + picks_list=picks_list, task=task, show_processing_region=show_processing_region, contour_color=contour_color, diff --git a/first_breaks/desktop/main_gui.py b/first_breaks/desktop/main_gui.py index d46498c..a81497a 100644 --- a/first_breaks/desktop/main_gui.py +++ b/first_breaks/desktop/main_gui.py @@ -1,7 +1,7 @@ import sys import warnings from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from PyQt5.QtCore import QSize, Qt, QThreadPool, pyqtSignal from PyQt5.QtGui import QCloseEvent @@ -19,7 +19,12 @@ QWidget, ) -from first_breaks.const import DEMO_SGY_PATH, HIGH_DPI, MODEL_ONNX_HASH, MODEL_ONNX_PATH +from first_breaks.const import ( + DEMO_SGY_PATH, + HIGH_DPI, + MODEL_ONNX_HASHES, + MODEL_ONNX_PATH, +) from first_breaks.data_models.independent import ExceptionOptional from first_breaks.desktop.graph import GraphWidget from first_breaks.desktop.last_folder_manager import last_folder_manager @@ -48,11 +53,11 @@ class FileState: file_changed = 2 @classmethod - def get_file_state(cls, fname: Union[str, Path], fhash: str) -> int: + def get_file_state(cls, fname: Union[str, Path], fhashes: List[str]) -> int: if not Path(fname).is_file(): return cls.file_not_exists else: - return cls.valid_file if calc_hash(fname) == fhash else cls.file_changed + return cls.valid_file if calc_hash(fname) in fhashes else cls.file_changed class ReadyToProcess: @@ -195,7 +200,7 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore self.settings: Optional[Dict[str, Any]] = None self.last_folder: Optional[Union[str, Path]] = None self.picks_from_file_in_ms: Optional[Tuple[Union[int, float], ...]] = None - self.picker_hash = MODEL_ONNX_HASH + self.picker_hashes = MODEL_ONNX_HASHES if show: self.show() @@ -255,7 +260,10 @@ def run_processing_region(self) -> None: def show_processing_region(self) -> None: for picks in self.picks_manager.picks_mapping.values(): if picks.created_by_nn and picks.active: - tps, max_time = picks.picking_parameters.traces_per_gather, picks.picking_parameters.maximum_time + tps, max_time = ( + picks.picking_parameters.traces_per_gather, + picks.picking_parameters.maximum_time, + ) self.graph.plot_processing_region(tps, max_time) break @@ -290,11 +298,14 @@ def load_nn(self, filename: Optional[Union[str, Path]] = None) -> None: if not filename: options = QFileDialog.Options() filename, _ = QFileDialog.getOpenFileName( - self, "Select file with NN weights", directory=last_folder_manager.get_last_folder(), options=options + self, + "Select file with NN weights", + directory=last_folder_manager.get_last_folder(), + options=options, ) if filename: - if FileState.get_file_state(filename, self.picker_hash) == FileState.valid_file: + if FileState.get_file_state(filename, self.picker_hashes) == FileState.valid_file: self.nn_manager.init_net(weights=filename) self.button_load_nn.setEnabled(False) self.ready_to_process.model_loaded = True diff --git a/first_breaks/desktop/picks_manager_widget.py b/first_breaks/desktop/picks_manager_widget.py index 7a9f9dc..1e002c3 100644 --- a/first_breaks/desktop/picks_manager_widget.py +++ b/first_breaks/desktop/picks_manager_widget.py @@ -303,7 +303,7 @@ def __init__(self) -> None: self.remove_button = QPushButton("-", self) self.remove_button.clicked.connect(self.remove_items) # renamed for clarity # self.properties_button = QPushButton("\u2699", self) # Unicode character for gear - self.properties_button = QPushButton("\U0001F4BE", self) # "\U0001F4BE" or "\U0001F5AB" + self.properties_button = QPushButton("\U0001f4be", self) # "\U0001F4BE" or "\U0001F5AB" self.properties_button.setFont(self.font()) # to increase the size of the button a bit self.properties_button.clicked.connect(self.open_properties) button_layout.addWidget(self.add_button) diff --git a/first_breaks/desktop/radioset_widget.py b/first_breaks/desktop/radioset_widget.py index aa21ff9..e73e029 100644 --- a/first_breaks/desktop/radioset_widget.py +++ b/first_breaks/desktop/radioset_widget.py @@ -30,7 +30,7 @@ def __init__( orientation: str = "horizontal", margins: Optional[int] = None, *args: Any, - **kwargs: Any + **kwargs: Any, ): super().__init__(*args, **kwargs) assert orientation in ["horizontal", "vertical"] diff --git a/first_breaks/picking/picker_onnx.py b/first_breaks/picking/picker_onnx.py index ad13ff8..e80a68e 100644 --- a/first_breaks/picking/picker_onnx.py +++ b/first_breaks/picking/picker_onnx.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, Generator, Optional, Tuple, Union +from typing import Any, Dict, Generator, Optional, Union import numpy as np import onnxruntime as ort @@ -75,6 +75,10 @@ def get_batch_generator(self, batch_size: int = 1) -> Generator[Dict[str, np.nda class PickerONNX(IPicker): + OUTPUT_PICKS_KEY = "picks" + OUTPUT_CONFS_KEY = "confs" + OUTPUT_HEATMAP_KEY = "heatmap" + def __init__( self, model_path: Optional[Union[str, Path]] = None, @@ -96,6 +100,17 @@ def __init__( self.model: Optional[ort.InferenceSession] = None self.init_model() + self._available_outputs = sorted(o.name for o in self.model.get_outputs()) # type: ignore + self._input_name = self.model.get_inputs()[0].name # type: ignore + + for mandatory_key in [self.OUTPUT_PICKS_KEY, self.OUTPUT_CONFS_KEY]: + assert ( + mandatory_key in self._available_outputs + ), f"`mandatory_key` not found in model outputs `{self._available_outputs}`" + + def is_heatmap_available(self) -> bool: + return self.OUTPUT_HEATMAP_KEY in self._available_outputs + def init_model(self) -> None: sess_opt = ort.SessionOptions() sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL @@ -124,16 +139,21 @@ def change_settings( # type: ignore return self - def pick_batch_of_gathers(self, gather: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + def pick_batch_of_gathers(self, gather: np.ndarray) -> Dict[str, np.ndarray]: assert gather.ndim == 4 assert all(dim > 0 for dim in gather.shape) - outputs = self.model.run(None, {"input": gather}) - return outputs[0], outputs[1] + outputs_list = self.model.run(self._available_outputs, {self._input_name: gather}) + return dict(zip(self._available_outputs, outputs_list)) def process_task(self, task: Task) -> Task: task_picks_in_sample = np.zeros(task.sgy.num_traces) task_confidence = np.zeros(task.sgy.num_traces) + if self.is_heatmap_available(): + task_heatmap = np.zeros((task.sgy.num_samples, task.sgy.num_traces)) + else: + task_heatmap = None + task_iterator = IteratorOfTask(task) counter_step_finished = 0 self.callback_processing_started(len(task_iterator)) @@ -141,11 +161,21 @@ def process_task(self, task: Task) -> Task: for idx, batch in enumerate(task_iterator.get_batch_generator(batch_size=self.batch_size)): self.interrupt_if_need() data = batch["gather"].astype(np.float32) - picks, confidence = self.pick_batch_of_gathers(data) + results = self.pick_batch_of_gathers(data) + + picks = results[self.OUTPUT_PICKS_KEY] + confidence = results[self.OUTPUT_CONFS_KEY] indices = batch["gather_ids"] - task_picks_in_sample[indices.flatten()] = picks.flatten() - task_confidence[indices.flatten()] = confidence.flatten() + + task_picks_in_sample[indices.flatten()] = picks.flatten() # (B, W) -> (BxW) + task_confidence[indices.flatten()] = confidence.flatten() # (B, W) -> (BxW) + + if self.is_heatmap_available(): + heatmap = results[self.OUTPUT_HEATMAP_KEY] + heatmap = heatmap.transpose(1, 0, 2) # (B, H, W) -> (H, B, W) + heatmap = heatmap.reshape(heatmap.shape[0], -1) # (H, B, W) -> (H, BxW) + task_heatmap[: len(heatmap), indices.flatten()] = heatmap counter_step_finished += len(data) self.callback_step_finished(counter_step_finished) @@ -154,13 +184,14 @@ def process_task(self, task: Task) -> Task: picks = Picks( values=task_picks_in_sample.astype(int).tolist(), + heatmap=task_heatmap, unit="sample", dt_mcs=task.sgy.dt_mcs, confidence=task_confidence.tolist(), modified_manually=False, created_manually=False, created_by_nn=True, - picks_color=generate_color(), + color=generate_color(), picking_parameters=task.picking_parameters, ) diff --git a/first_breaks/picking/picks.py b/first_breaks/picking/picks.py index b157b1e..a5ab6e5 100644 --- a/first_breaks/picking/picks.py +++ b/first_breaks/picking/picks.py @@ -1,5 +1,5 @@ import uuid -from typing import List, Literal, Optional, Union +from typing import Annotated, List, Literal, Optional, Union import numpy as np from pydantic import UUID4, Field, model_validator @@ -43,12 +43,13 @@ class Picks(DefaultModel): dt_mcs: Optional[float] = None confidence: Optional[Union[np.ndarray, List[Union[int, float]]]] = None + heatmap: Optional[np.ndarray] = None created_by_nn: Optional[bool] = None created_manually: Optional[bool] = None modified_manually: Optional[bool] = None picking_parameters: Optional[PickingParameters] = None color: TColor = Field(default_factory=generate_color, description="Color for picks") # type: ignore - width: float = Field(DEFAULT_PICKS_WIDTH, description="Width of pick line") + width: Annotated[float, Field(description="Width of pick line")] = DEFAULT_PICKS_WIDTH active: Optional[bool] = None @@ -152,15 +153,17 @@ def from_samples(self, values: TValues) -> None: def create_duplicate(self, keep_color: bool = False) -> "Picks": values = self.values.copy() confidence = self.confidence.copy() if self.confidence is not None else None + heatmap = self.heatmap.copy() if self.heatmap is not None else None return Picks( values=values, confidence=confidence, + heatmap=heatmap, dt_mcs=self.dt_mcs, unit=self.unit, created_manually=True, created_by_nn=self.created_by_nn, modified_manually=True, picking_parameters=self.picking_parameters, - picks_color=self.color if keep_color else generate_color(), + color=self.color if keep_color else generate_color(), ) diff --git a/first_breaks/picking/refiner.py b/first_breaks/picking/refiner.py new file mode 100644 index 0000000..23d7ca2 --- /dev/null +++ b/first_breaks/picking/refiner.py @@ -0,0 +1,160 @@ +from typing import Dict, Tuple + +import numpy as np + +from first_breaks.picking.picks import Picks +from first_breaks.sgy.reader import SGY +from first_breaks.utils.filtering import apply_savgol_filter + + +class Refiner: + def refine(self, sgy: SGY, picks: Picks) -> Picks: + raise NotImplementedError + + +def find_extrema_mask(data: np.ndarray, neighbor_range: int = 1) -> np.ndarray: + assert data.ndim == 2 + + maxima_mask = np.ones(data.shape, dtype=bool) + minima_mask = np.ones(data.shape, dtype=bool) + ids = np.arange(len(data)) + + for shift in range(1, neighbor_range + 1): + shifted_data_left = data.take(ids + shift, mode="clip", axis=0) + shifted_data_right = data.take(ids - shift, mode="clip", axis=0) + + maxima_mask &= (data > shifted_data_left) & (data > shifted_data_right) + minima_mask &= (data < shifted_data_left) & (data < shifted_data_right) + + extrema_mask = maxima_mask | minima_mask + + return extrema_mask + + +def calc_intersection(data: np.ndarray, data_derivative: np.ndarray, tangent_points: np.ndarray) -> np.ndarray: + slope = data_derivative[tangent_points] + intercept = data[tangent_points] - slope * tangent_points + return -intercept / slope + + +def calc_intersection_vectorized( + data: np.ndarray, data_derivative: np.ndarray, extrema_mask: np.ndarray +) -> Dict[int, Tuple[np.ndarray, ...]]: + assert all(arr.ndim == 2 for arr in [data, data_derivative, extrema_mask]) + assert extrema_mask.dtype == np.bool_ + + extrema_indices = np.where(extrema_mask) + row_indices, col_indices = extrema_indices + + sorted_indices = np.lexsort((row_indices, col_indices)) + sorted_row_indices = row_indices[sorted_indices] + sorted_col_indices = col_indices[sorted_indices] + + slope = data_derivative[sorted_row_indices, sorted_col_indices] + intercept = data[sorted_row_indices, sorted_col_indices] - slope * sorted_row_indices + intersection = -intercept / slope + + to_keep = (intersection >= 0) & (intersection < (len(data) - 1)) & (intersection != np.inf) + + intersection = intersection[to_keep] + sorted_col_indices = sorted_col_indices[to_keep] + + unique_cols, start_indices = np.unique(sorted_col_indices, return_index=True) + intersections = dict(zip(unique_cols, np.split(intersection, start_indices[1:]))) + + return intersections + + +def get_band_mask( + data: np.ndarray, band_ids: np.ndarray, width_before: int, width_after: int +) -> Tuple[np.ndarray, np.ndarray]: + num_rows, num_cols = data.shape + row_indices = np.arange(-width_before, width_after + 1).reshape(-1, 1) + band_ids + row_indices_clipped = np.clip(row_indices, 0, num_rows - 1) + return row_indices_clipped, np.arange(num_cols) + + +def refine_picks( # type: ignore + raw_picks: np.ndarray, + probability_heatmap: np.ndarray, + traces2intersections, + minimum_probability_to_refine: float = 0.9, +) -> np.ndarray: + refined_picks = raw_picks.copy() + + for trace, intersections in traces2intersections.items(): + intersections_int = intersections.astype(int) + prob = probability_heatmap[intersections_int, trace] + best_candidate = np.argmax(prob) + if prob[best_candidate] > minimum_probability_to_refine: + refined_picks[trace] = intersections_int[best_candidate] + + return refined_picks + + +class MinimalPhaseRefiner(Refiner): + def __init__( + self, + analyse_window_before: int = 5, + analyse_window_after: int = 15, + smooth_window: int = 11, + smooth_polyorder: int = 3, + extrema_window: int = 3, + min_probability_to_refine: float = 0.9, + ): + self.analyse_window_before = analyse_window_before + self.analyse_window_after = analyse_window_after + self.smooth_window = smooth_window + self.smooth_polyorder = smooth_polyorder + self.extrema_window = extrema_window + self.min_probability_to_refine = min_probability_to_refine + + def refine(self, sgy: SGY, picks: Picks) -> Picks: + assert picks.heatmap is not None + # intersection point doesn't depend on scale of data, so we don't process raw data + raw = sgy.read() + picks_in_samples = picks.picks_in_samples.copy() + + filtered = apply_savgol_filter( + data=raw, + polyorder=self.smooth_polyorder, + window_length=self.smooth_window, + deriv=0, + ) + first_derivateive = apply_savgol_filter( + data=raw, + polyorder=self.smooth_polyorder, + window_length=self.smooth_window, + deriv=1, + ) + + band_mask = get_band_mask( + data=raw, + band_ids=picks_in_samples, + width_before=self.analyse_window_before, + width_after=self.analyse_window_after, + ) + + extrema = find_extrema_mask(data=first_derivateive[band_mask], neighbor_range=self.extrema_window) + + tr2intersections = calc_intersection_vectorized( + data=filtered[band_mask], + data_derivative=-first_derivateive[band_mask], + extrema_mask=extrema, + ) + + # previous intersections obtained based on band data. We need to map band intersections to initail coordintates + tr2intersections = { + tr: band_mask[0][inter.round().astype(int), tr] for tr, inter in tr2intersections.items() # type: ignore + } + + refined_picks = refine_picks( + raw_picks=picks_in_samples, + probability_heatmap=picks.heatmap, + traces2intersections=tr2intersections, + minimum_probability_to_refine=self.min_probability_to_refine, + ) + + picks.from_samples(refined_picks) + + return picks diff --git a/first_breaks/utils/filtering.py b/first_breaks/utils/filtering.py new file mode 100644 index 0000000..bc7b2b1 --- /dev/null +++ b/first_breaks/utils/filtering.py @@ -0,0 +1,39 @@ +import numpy as np + + +def savgol_coeffs(window_length: int, polyorder: int, deriv: int = 0) -> np.ndarray: + assert window_length % 2 == 1, "The value must be odd" + assert polyorder % 2 == 1, "The value must be odd" + assert polyorder < window_length + + half_window = (window_length - 1) // 2 + a = np.zeros((window_length, polyorder + 1)) + for i in range(window_length): + for j in range(polyorder + 1): + a[i, j] = (i - half_window) ** j + atr_a = np.dot(a.T, a) + atr = np.linalg.pinv(atr_a) + b = np.dot(atr, a.T) + coeffs = b[deriv] + return coeffs + + +def apply_savgol_filter(data: np.ndarray, window_length: int, polyorder: int, deriv: int = 0) -> np.ndarray: + """ + https://en.wikipedia.org/wiki/Savitzky%E2%80%93Golay_filter + """ + assert 1 <= data.ndim <= 2 + coeffs = savgol_coeffs(window_length, polyorder, deriv) + + half_window = (window_length - 1) // 2 + pad_mode = "reflect" + + if data.ndim == 1: + padding = (half_window, half_window) + else: + padding = ((half_window, half_window), (0, 0)) # type: ignore + + padded_data = np.pad(data, padding, mode=pad_mode) + + filtered_data = np.apply_along_axis(lambda m: np.convolve(m, coeffs, mode="valid"), axis=0, arr=padded_data) + return filtered_data diff --git a/first_breaks/utils/utils.py b/first_breaks/utils/utils.py index 1f1cb60..3ce284d 100644 --- a/first_breaks/utils/utils.py +++ b/first_breaks/utils/utils.py @@ -10,6 +10,7 @@ import numpy as np import pandas as pd import requests +from tqdm.auto import tqdm from first_breaks.const import ( DEMO_SGY_HASH, @@ -34,8 +35,10 @@ def chunk_iterable(it: Iterable[Any], size: int) -> List[Tuple[Any, ...]]: return list(iter(lambda: tuple(islice(it, size)), ())) -def get_io(source: Union[Path, str, bytes], mode: str = "r") -> Union[io.BytesIO, io.FileIO]: - if isinstance(source, (Path, str)): +def get_io(source: Union[Path, str, bytes, io.BytesIO, io.FileIO], mode: str = "r") -> Union[io.BytesIO, io.FileIO]: + if isinstance(source, (io.BytesIO, io.FileIO)): + return source + elif isinstance(source, (Path, str)): source = Path(source).resolve() if "r" in mode: if not source.exists(): @@ -50,24 +53,42 @@ def get_io(source: Union[Path, str, bytes], mode: str = "r") -> Union[io.BytesIO def calc_hash(source: Union[Path, str, bytes, io.BytesIO, io.FileIO]) -> str: hash_md5 = hashlib.md5() - if not isinstance(source, (io.BytesIO, io.FileIO)): - source = get_io(source, mode="rb") + source = get_io(source, mode="rb") + source.seek(0) for chunk in iter(lambda: source.read(4096), b""): # type: ignore hash_md5.update(chunk) + source.close() return hash_md5.hexdigest() -def download_by_url(url: str, fname: Optional[Union[str, Path]], timeout: float = TIMEOUT) -> Optional[bytes]: - response = requests.get(url, timeout=timeout) - if response.status_code != 200: - response.raise_for_status() - return None - else: +def download_by_url(url: str, fname: Optional[Union[str, Path]], timeout: float = TIMEOUT) -> bytes: + response = requests.get(url, stream=True, timeout=timeout) + response.raise_for_status() + total_size = int(response.headers.get("content-length", 0)) + block_size = 1024 + buffer = io.BytesIO() + + with tqdm( + desc=f"Downloading {url}", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in response.iter_content(block_size): + buffer.write(data) + bar.update(len(data)) + + buffer.seek(0) + content = buffer.getvalue() + if fname: Path(fname).parent.mkdir(exist_ok=True, parents=True) with open(fname, "wb+") as f: - f.write(response.content) - return response.content + f.write(content) + bar.set_description(f"File {url} saved to '{Path(fname).resolve()}'") + + return content def download_and_validate_file( @@ -82,13 +103,17 @@ def download_and_validate_file( def download_demo_sgy( - fname: Union[str, Path] = DEMO_SGY_PATH, url: str = DEMO_SGY_URL, md5: str = DEMO_SGY_HASH + fname: Union[str, Path] = DEMO_SGY_PATH, + url: str = DEMO_SGY_URL, + md5: str = DEMO_SGY_HASH, ) -> Union[str, Path]: return download_and_validate_file(fname=fname, url=url, md5=md5) def download_model_onnx( - fname: Union[str, Path] = MODEL_ONNX_PATH, url: str = MODEL_ONNX_URL, md5: str = MODEL_ONNX_HASH + fname: Union[str, Path] = MODEL_ONNX_PATH, + url: str = MODEL_ONNX_URL, + md5: str = MODEL_ONNX_HASH, ) -> Union[str, Path]: return download_and_validate_file(fname=fname, url=url, md5=md5) @@ -153,19 +178,19 @@ def remove_unused_kwargs(kwargs: Dict[str, Any], constructor: Any) -> Dict[str, return {k: v for k, v in kwargs.items() if k in inspect.signature(constructor).parameters} -def _color_generator() -> Generator[List[int], None, None]: +def _color_generator() -> Generator[Tuple[int, ...], None, None]: golden_ratio = 0.618033988749895 hue = random.random() # start from a random position while True: hue += golden_ratio hue %= 1 - yield [int(255 * v) for v in colorsys.hsv_to_rgb(hue, 0.5, 0.95)] + yield tuple(int(255 * v) for v in colorsys.hsv_to_rgb(hue, 0.5, 0.95)) cgen = _color_generator() -def generate_color() -> List[int]: +def generate_color() -> Tuple[int, ...]: return next(cgen) diff --git a/first_breaks/utils/visualizations.py b/first_breaks/utils/visualizations.py deleted file mode 100644 index 60afcc4..0000000 --- a/first_breaks/utils/visualizations.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import Optional, Tuple, Union - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.patches import Polygon - - -def plotseis( - data: np.ndarray, - picking: Optional[np.ndarray] = None, - add_picking: Optional[np.ndarray] = None, - normalizing: Optional[Union[str, int]] = "indiv", - clip: float = 0.9, - ampl: float = 1.0, - patch: bool = True, - colorseis: bool = False, - wiggle: bool = True, - background: Optional[np.ndarray] = None, - colorbar: bool = False, - dt: float = 1.0, - show: bool = True, - dpi: int = 300, - figsize: Tuple[int, int] = (5, 5), -) -> matplotlib.figure.Figure: - - num_time, num_trace = np.shape(data) - - if normalizing == "indiv": - norm_factor = np.mean(np.abs(data), axis=0) - norm_factor[np.abs(norm_factor) < 1e-9 * np.max(np.abs(norm_factor))] = 1 - elif normalizing == "entire": - norm_factor = np.tile(np.mean(np.abs(data)), (1, num_trace)) - elif np.size(normalizing) == 1 and normalizing is not None: - norm_factor = np.tile(normalizing, (1, num_trace)) - elif np.size(normalizing) == num_trace: - norm_factor = np.reshape(normalizing, (1, num_trace)) - elif normalizing is None: - norm_factor = np.ones(data.shape[1]) - else: - raise ValueError('Wrong value of "normalizing"') - - data = data / norm_factor * ampl - - mask_overflow = np.abs(data) > clip - data[mask_overflow] = np.sign(data[mask_overflow]) * clip - - data_time = np.tile((np.arange(num_time) + 1)[:, np.newaxis], (1, num_trace)) * dt - - fig, ax = plt.subplots(figsize=figsize) - fig.set_dpi(dpi) - - plt.xlim((0, num_trace + 1)) - plt.ylim((0, num_time * dt)) - ax.invert_yaxis() - ax.xaxis.tick_top() - - if wiggle: - data_to_wiggle = data + (np.arange(num_trace) + 1)[np.newaxis, :] - - ax.plot(data_to_wiggle, data_time, color=(0, 0, 0)) - - if colorseis: - if not (wiggle or patch): - ax.imshow( - data, - aspect="auto", - interpolation="bilinear", - alpha=1, - extent=(1, num_trace, (num_time - 0.5) * dt, -0.5 * dt), - cmap="gray", - ) - else: - ax.imshow( - data, - aspect="auto", - interpolation="bilinear", - alpha=1, - extent=(-0.5, num_trace + 2 - 0.5, (num_time - 0.5) * dt, -0.5 * dt), - cmap="gray", - ) - - if patch: - data_to_patch = data - data_to_patch[data_to_patch < 0] = 0 - - for k_trace in range(num_trace): - patch_data = ( - (data_to_patch[:, k_trace] + k_trace + 1)[:, np.newaxis], - data_time[:, k_trace][:, np.newaxis], - ) - patch_data = np.hstack(patch_data) - - head = np.array((k_trace + 1, 0))[np.newaxis, :] - tail = np.array((k_trace + 1, num_time * dt))[np.newaxis, :] - patch_data = np.vstack((head, patch_data, tail)) - - polygon = Polygon(patch_data, closed=True, facecolor="black", edgecolor=None) - ax.add_patch(polygon) - - if picking is not None: - picking = np.array(picking) - ax.plot(np.arange(num_trace) + 1, picking * dt, linewidth=1, color="blue") - - if add_picking is not None: - add_picking = np.array(add_picking) - ax.plot(np.arange(num_trace) + 1, add_picking * dt, linewidth=1, color="green") - - if background is not None: - bg = ax.imshow( - background, - aspect="auto", - extent=(0.5, num_trace + 1 - 0.5, (num_time - 0.5) * dt, -0.5 * dt), - cmap="YlOrRd", - ) - - if colorbar: - plt.colorbar(mappable=bg) - - if show: - plt.show() - return fig diff --git a/pyproject.toml b/pyproject.toml index 5abd745..0b3959f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "Project is devoted to pick waves that are the first to be detected on a seismogram with neural network" readme = {file = "README.md", content-type = "text/markdown"} -requires-python = ">=3.8" +requires-python = ">=3.12" authors = [ {name = "Aleksei Tarasov", email = "aleksei.v.tarasov@gmail.com"}, {name = "Aleksei Tarasov"}, @@ -33,15 +33,14 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Image Recognition", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ] entry-points = {console_scripts = {first-breaks-picking = "first_breaks.cli:cli_commands"}} dependencies = [ "requests>=2.28.2", - "numpy>=1.24.2", + "numpy (>=1.24.2,<2.0.0)", "pandas>=2.0.0", "PyQt5>=5.15.9", "pyqtgraph>=0.13.3", diff --git a/pyproject_gpu.toml b/pyproject_gpu.toml index 93b3d8b..369f79d 100644 --- a/pyproject_gpu.toml +++ b/pyproject_gpu.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "Project is devoted to pick waves that are the first to be detected on a seismogram with neural network (CUDA accelerated)" readme = {file = "README.md", content-type = "text/markdown"} -requires-python = ">=3.8" +requires-python = ">=3.12" authors = [ {name = "Aleksei Tarasov", email = "aleksei.v.tarasov@gmail.com"}, {name = "Aleksei Tarasov"}, @@ -33,15 +33,14 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Image Recognition", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ] entry-points = {console_scripts = {first-breaks-picking = "first_breaks.cli:cli_commands"}} dependencies = [ "requests>=2.28.2", - "numpy>=1.24.2", + "numpy (>=1.24.2,<2.0.0)", "pandas>=2.0.0", "PyQt5>=5.15.9", "pyqtgraph>=0.13.3", diff --git a/tests/test_common/test_readme_examples.py b/tests/test_common/test_readme_examples.py index f4c89df..140b3ed 100644 --- a/tests/test_common/test_readme_examples.py +++ b/tests/test_common/test_readme_examples.py @@ -49,7 +49,7 @@ def test_code_blocks_in_readme(block_name: str, demo_sgy: Path, logs_dir_for_tes code = find_code_block(PROJECT_ROOT / "README.md", start_indicator, end_indicator) assert code - tmp_fname = "tmp.py" + tmp_fname = f"{block_name}.py" with open(tmp_fname, "w") as f: f.write(code)