From e22e453c88afe8b2ee7afa327b04a15df694fd69 Mon Sep 17 00:00:00 2001 From: DaloroAT Date: Sat, 22 Jun 2024 18:12:16 +0200 Subject: [PATCH 1/9] refiner --- dev_refi.py | 247 ++++++++++++++ first_breaks/benchmark.py | 98 ++++++ first_breaks/const.py | 6 +- first_breaks/desktop/graph.py | 64 ++-- first_breaks/desktop/main_gui.py | 97 ++++-- first_breaks/desktop/picks_manager_widget.py | 69 +++- first_breaks/picking/picker_onnx.py | 38 ++- first_breaks/picking/picks.py | 8 +- first_breaks/picking/refiner.py | 324 +++++++++++++++++++ first_breaks/utils/filtering.py | 40 +++ first_breaks/utils/utils.py | 89 ++++- first_breaks/utils/visualizations.py | 44 +-- 12 files changed, 1016 insertions(+), 108 deletions(-) create mode 100644 dev_refi.py create mode 100644 first_breaks/benchmark.py create mode 100644 first_breaks/picking/refiner.py create mode 100644 first_breaks/utils/filtering.py diff --git a/dev_refi.py b/dev_refi.py new file mode 100644 index 0000000..7d74676 --- /dev/null +++ b/dev_refi.py @@ -0,0 +1,247 @@ +import numpy as np +import matplotlib.pyplot as plt + +from first_breaks.exports.export_picks import export_to_sgy +from first_breaks.picking.picker_onnx import PickerONNX +from first_breaks.picking.picks import Picks +from first_breaks.picking.task import Task +from first_breaks.picking.utils import preprocess_gather +from first_breaks.sgy.reader import SGY +from first_breaks.utils.utils import download_demo_sgy, generate_color +from first_breaks.utils.visualizations import plotseis + + +def savgol_coeffs(window_length, polyorder, deriv=0): + half_window = (window_length - 1) // 2 + A = np.zeros((window_length, polyorder + 1)) + for i in range(window_length): + for j in range(polyorder + 1): + A[i, j] = (i - half_window) ** j + ATA = np.dot(A.T, A) + AT = np.linalg.pinv(ATA) + B = np.dot(AT, A.T) + coeffs = B[deriv] + return coeffs + + +def apply_savgol_filter(data, window_length, polyorder, deriv=0): + coeffs = savgol_coeffs(window_length, polyorder, deriv) + + print(coeffs) + half_window = (window_length - 1) // 2 + pad_mode = "reflect" + + padded_data = np.pad(data, ((half_window, half_window), (0, 0)), mode=pad_mode) + + filtered_data = np.apply_along_axis( + lambda m: np.convolve(m, coeffs, mode="valid"), axis=0, arr=padded_data + ) + return filtered_data + + +# # Example seismogram (replace with your data) +# seismogram = np.array([ +# # Your seismic data here +# ]) +# +# # Parameters for Savitzky-Golay filter +# window_length = 11 # Choose an appropriate window length (must be odd) +# polyorder = 3 # Polynomial order +# +# # Apply Savitzky-Golay filter to smooth the data along the N dimension +# smoothed_seismogram = apply_savgol_filter(seismogram, window_length, polyorder) +# +# # Compute the first derivative using Savitzky-Golay filter +# first_derivative_seismogram = apply_savgol_filter(seismogram, window_length, polyorder, deriv=1) +# +# # Example initial picks from neural network (replace with your actual picks) +# initial_picks = np.array([/* your initial picks here */]) +# +# # Refine the picks using the tangent point and intersection method +# refined_picks = [] +# for trace in range(seismogram.shape[0]): +# trace_picks = [] +# for pick in initial_picks[trace]: +# tangent_point, intersection = find_tangent_and_intersection(seismogram[trace], first_derivative_seismogram[trace]) +# trace_picks.append(intersection) +# refined_picks.append(trace_picks) +# +# # Convert refined_picks to a 2D array +# refined_picks = np.array(refined_picks) +# +# # Visualize results for a specific trace +# trace_idx = 0 # Change as needed +# plt.figure(figsize=(12, 6)) +# plt.plot(seismogram[trace_idx], label='Original Seismic Trace') +# plt.plot(smoothed_seismogram[trace_idx], label='Smoothed Seismic Trace', linestyle='dashed') +# plt.plot(first_derivative_seismogram[trace_idx], label='First Derivative', linestyle='dotted') +# plt.scatter(initial_picks[trace_idx], seismogram[trace_idx][initial_picks[trace_idx]], color='red', label='Initial Picks') +# plt.scatter(refined_picks[trace_idx], seismogram[trace_idx][refined_picks[trace_idx].astype(int)], color='green', label='Refined Picks', marker='x') +# plt.legend() +# plt.xlabel('Sample Index') +# plt.ylabel('Amplitude') +# plt.title('Seismic Trace and Picks Refinement') +# plt.show() + + +# sgy = SGY(download_demo_sgy()) +# picker = PickerONNX(model_path="fb_heatmap_afc03594f49b88ea61b5cf6ba8245be4.onnx", batch_size=3) +# task = Task(source=sgy, traces_per_gather=12, maximum_time=100) +# task = picker.process_task(task) +# +# plt.imshow(task.picks.heatmap) +# plt.show() +# print(task.picks.picks_in_mcs) +# sgy = SGY("with_picks.sgy") +# picks = Picks( +# values=sgy.read_custom_trace_header(236, "i"), unit="mcs", dt_mcs=sgy.dt_mcs +# ) +# print(picks.picks_in_mcs) +# print(task.picks.heatmap.max(), task.picks.heatmap.min()) +# +# np.save("heatmap.npy", task.picks.heatmap) +# +# +# assert False +# export_to_sgy(sgy, filename="with_picks.sgy", picks=task.picks) + + +def find_candidates(first_derivative, window=None, neighbor_range=1): + """Find all extrema of the first derivative within the window using extended neighborhood checks.""" + if window is None: + start = 0 + end = len(first_derivative) + else: + start, end = window + + segment = first_derivative[start:end, ...] + + # Initialize masks for maxima and minima + maxima_mask = np.ones(segment.shape, dtype=bool) + minima_mask = np.ones(segment.shape, dtype=bool) + ids = np.arange(len(segment)) + + for shift in range(1, neighbor_range + 1): + shifted_segment_left = segment.take(ids + shift, mode="clip", axis=0) + shifted_segment_right = segment.take(ids - shift, mode="clip", axis=0) + + maxima_mask &= (segment > shifted_segment_left) & ( + segment > shifted_segment_right + ) + minima_mask &= (segment < shifted_segment_left) & ( + segment < shifted_segment_right + ) + + extrema_mask = maxima_mask | minima_mask + + extrema = np.where(extrema_mask)[0] + start + return extrema + + +def calculate_intersections(smoothed_trace, first_derivative, cand_points): + """Calculate the intersection of the tangent line at the tangent point with the time axis.""" + slope = first_derivative[cand_points] + intercept = smoothed_trace[cand_points] - slope * cand_points + return -intercept / slope + + # if slope != 0: + # intersection = -intercept / slope + # else: + # intersection = cand_points + # + # return intersection + + # slope = first_derivative[cand_points] + # intercept = smoothed_trace[cand_points] - slope * cand_points + # return intercept + + +sgy = SGY("with_picks.sgy") +picks = Picks( + values=sgy.read_custom_trace_header(236, "i"), unit="mcs", dt_mcs=sgy.dt_mcs +) + + +# plotseis(sgy.read(max_sample=500), picking=picks.picks_in_samples, normalizing="trace") + +data = -preprocess_gather(sgy.read(), clip=3, gain=1, normalize="trace") +heatmap = np.load("heatmap.npy") + +# trace_idx = 40 +# start_ms = 25 +# end_ms = 42 + +# trace_idx = 34 +# start_ms = 25 +# end_ms = 42 + +# trace_idx = 10 +# start_ms = 40 +# end_ms = 80 + +trace_idx = 93 +start_ms = 40 +end_ms = 60 + +start_sample = sgy.units_converter.ms2index(start_ms) +end_sample = sgy.units_converter.ms2index(end_ms) +# start_sample = 60 +# end_sample = 170 +sub = data[start_sample:end_sample, trace_idx] +# +plt.plot(sub, color="b") +plt.plot( + [ + picks.picks_in_samples[trace_idx] - start_sample, + picks.picks_in_samples[trace_idx] - start_sample, + ], + [min(sub), max(sub)], +) +# plt.show() + +window = 11 +order = 3 + +smoothed = apply_savgol_filter( + sub[:, None], window_length=window, polyorder=order, deriv=0 +) +plt.plot(smoothed[:, 0], color="k") + +first_deriv = apply_savgol_filter( + sub[:, None], window_length=window, polyorder=order, deriv=1 +) +plt.plot((2 * first_deriv[:, 0]), color="r") +cand = find_candidates(first_deriv[:, 0], neighbor_range=3) + +refined = calculate_intersection(smoothed[:, 0], -first_deriv[:, 0], cand) + +mask_refined = (refined < len(sub)) & (refined > 0) +refined = refined[mask_refined] +cand = cand[mask_refined] +print(cand.astype(int)) +print(refined.astype(int)) +print( + (100 * heatmap[start_sample + refined.astype(int), trace_idx]).astype(int), + "confidence", +) +print(picks.picks_in_samples[trace_idx] - start_sample) + +for i in range(len(cand)): + color_p = [c / 255 for c in generate_color()] + plt.scatter([cand[i]], [0], color=color_p) + plt.scatter([refined[i]], [0], marker="*", color=color_p) + +# second_deriv = apply_savgol_filter( +# sub[:, None], window_length=window, polyorder=order, deriv=2 +# ) + +plt.plot(heatmap[start_sample:end_sample, trace_idx], linestyle="--") + +# plt.plot( +# 3 * heatmap[start_sample:end_sample, trace_idx] * first_deriv[:, 0], +# linestyle="dashdot", +# ) + +# plt.plot(np.abs(4 * second_deriv[:, 0]), color="g") +plt.grid() +plt.show() diff --git a/first_breaks/benchmark.py b/first_breaks/benchmark.py new file mode 100644 index 0000000..3bc9135 --- /dev/null +++ b/first_breaks/benchmark.py @@ -0,0 +1,98 @@ +from pathlib import Path +from typing import Union, Literal + +from first_breaks.const import FIRST_BYTE +from first_breaks.desktop.graph import export_image +from first_breaks.picking.picker_onnx import PickerONNX +from first_breaks.picking.picks import Picks +from first_breaks.picking.task import Task +from first_breaks.sgy.reader import SGY +from first_breaks.utils.utils import ( + download_demo_sgy, + download_by_url, + download_and_validate_file, + calc_hash, +) +import numpy as np + +fname = download_demo_sgy() +gain = 1 +traces_per_gather = 12 +maximum_time = 100 + + +# def predict_picks( +# sgy: SGY, +# picker: PickerONNX, +# gain: float, +# traces_per_gather: int, +# maximum_time: float, +# ): +# task = Task( +# source=sgy, +# traces_per_gather=traces_per_gather, +# maximum_time=maximum_time, +# gain=gain, +# ) +# task = picker.process_task(task) +# picks = task.get_result() + + +def make_report( + filename: Union[str, Path], + sgy: SGY, + saved_picks_byte_position: int, + predicted_picks: Picks, + gain: float, + traces_per_gather: int, + maximum_time: float, +): + assert 1 <= saved_picks_byte_position <= 237 + saved_picks = Picks( + values=sgy.read_custom_trace_header(saved_picks_byte_position - FIRST_BYTE, "i"), + unit="mcs", + dt_mcs=sgy.dt_mcs, + ) + to_export = {} + + # ANONYMOUS DATA BLOCK + + # difference between manual picks and predicted picks is anonymous and expose nothing + difference = (np.array(saved_picks.picks_in_mcs) - np.array(predicted_picks.picks_in_mcs)).astype(int).tolist() + to_export["difference"] = difference + + # hash of traces allows me to inderstand if reports were created based on same data or different + # without direct access to file: if 2 reports have same `traces_hash` it means that they were calculated based on + # same traces, if not - files were different. + # So I can understand how different parameters affect specific file analysing `difference` metric for several + # files belongs to same `traces_hash` + traces_hash = calc_hash(sgy.read().tobytes(order="C")) + to_export["traces_hash"] = traces_hash + + channel_hashes = sgy.traces_headers["CHAN"].apply(lambda x: calc_hash(x)[:10]) + source_hashes = sgy.traces_headers["CHAN"].apply(lambda x: calc_hash(x)[:10]) + + # basic shape, not anonymsed + num_samples, num_traces = sgy.shape + + +def download_model_with_heatmap(destination: Union[str, Path]): + model_hash = "afc03594f49b88ea61b5cf6ba8245be4" + model_url = "https://oml.daloroserver.com/download/seis/fb_heatmap_afc03594f49b88ea61b5cf6ba8245be4.onnx" + download_and_validate_file(url=model_url, md5=model_hash, fname=destination) + + +if __name__ == "__main__": + destination = "model_with_heatmap.onnx" + download_model_with_heatmap(destination) + + # fname = Path(fname) + # assert fname.exists(), f"File {fname.resolve()} not found" + # + # sgy = SGY(fname) + # + # print( + # f"SGY: fname='{fname.resolve()}', num_traces={sgy.num_traces}, num_samples={sgy.num_samples}, dt_mcs={sgy.dt_mcs}" + # ) + # + # picker = PickerONNX() diff --git a/first_breaks/const.py b/first_breaks/const.py index 6966ea6..9a571b4 100644 --- a/first_breaks/const.py +++ b/first_breaks/const.py @@ -8,7 +8,10 @@ def get_cache_folder() -> Path: if is_linux(): - return Path(environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) / "first_breaks_picking" + return ( + Path(environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) + / "first_breaks_picking" + ) elif is_macos(): return Path.home() / "Library" / "Caches" / "first_breaks_picking" elif is_windows(): @@ -28,6 +31,7 @@ def get_cache_folder() -> Path: MODEL_ONNX_PATH = CACHE_FOLDER / "fb.onnx" MODEL_ONNX_URL = "https://oml.daloroserver.com/download/seis/fb.onnx" MODEL_ONNX_HASH = "7e39e017b01325180e36885eccaeb17a" +MODEL_ONNX_HASHES = [MODEL_ONNX_HASH, "afc03594f49b88ea61b5cf6ba8245be4"] TIMEOUT = 60 diff --git a/first_breaks/desktop/graph.py b/first_breaks/desktop/graph.py index 7944a65..80be383 100644 --- a/first_breaks/desktop/graph.py +++ b/first_breaks/desktop/graph.py @@ -2,7 +2,7 @@ import os import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Sequence import numpy as np import pyqtgraph as pg @@ -60,8 +60,12 @@ def __init__(self, use_open_gl: bool = True, *args: Any, **kwargs: Any): self.setup_axes() self.spectrum_roi_manager = RoiManager(viewbox=self.getViewBox()) - self.spectrum_window = SpectrumWindow(use_open_gl=use_open_gl, roi_manager=self.spectrum_roi_manager) - self.mouse_click_signal = pg.SignalProxy(self.sceneObj.sigMouseClicked, rateLimit=60, slot=self.mouse_clicked) + self.spectrum_window = SpectrumWindow( + use_open_gl=use_open_gl, roi_manager=self.spectrum_roi_manager + ) + self.mouse_click_signal = pg.SignalProxy( + self.sceneObj.sigMouseClicked, rateLimit=60, slot=self.mouse_clicked + ) def resolve_postime2xy(self, position: Any, time: Any) -> Tuple[Any, Any]: return postime2xy(vsp_view=self.vsp_view, position=position, time=time) @@ -172,12 +176,16 @@ def plotseis( self.getPlotItem().setXRange(min=0, max=x_max) for idx in range(num_traces): - self._plot_trace_fast(trace=traces[:, idx], time=t, shift=idx + 1, fill_black=fill_black) + self._plot_trace_fast( + trace=traces[:, idx], time=t, shift=idx + 1, fill_black=fill_black + ) self.pos_ax.showLabel() self.graph_updated_signal.emit() - def _plot_trace_fast(self, trace: np.ndarray, time: np.ndarray, shift: int, fill_black: Optional[str]) -> None: + def _plot_trace_fast( + self, trace: np.ndarray, time: np.ndarray, shift: int, fill_black: Optional[str] + ) -> None: connect = np.ones(len(time), dtype=np.int32) connect[-1] = 0 @@ -203,7 +211,9 @@ def _plot_trace_fast(self, trace: np.ndarray, time: np.ndarray, shift: int, fill sign = -1 if fill_black == "left" else 1 x, y = self.resolve_postime2xy(shift, time[0]) - w, h = self.resolve_postime2xy(sign * 1.1 * max(np.abs(shifted_trace)), time[-1]) + w, h = self.resolve_postime2xy( + sign * 1.1 * max(np.abs(shifted_trace)), time[-1] + ) rect = QPainterPath() rect.addRect(x, y, w, h) @@ -211,7 +221,9 @@ def _plot_trace_fast(self, trace: np.ndarray, time: np.ndarray, shift: int, fill patch = path.intersected(rect) item = pg.QtWidgets.QGraphicsPathItem(patch) - pen = QPen(QColor(255, 255, 255, 0), 1, Qt.SolidLine, Qt.FlatCap, Qt.MiterJoin) + pen = QPen( + QColor(255, 255, 255, 0), 1, Qt.SolidLine, Qt.FlatCap, Qt.MiterJoin + ) pen.setCosmetic(True) # 1 pixel width for any scale and resolution item.setPen(pen) item.setBrush(Qt.black) @@ -229,7 +241,9 @@ def _replace_tick_labels(self, *args: Any, **kwargs: Any) -> List[str]: if v % 1 == 0: v = int(v) - 1 if 0 <= v < self.sgy.num_traces: - labels_from_headers.append(str(self.sgy.traces_headers[self.pos_ax_header].iloc[v])) + labels_from_headers.append( + str(self.sgy.traces_headers[self.pos_ax_header].iloc[v]) + ) else: labels_from_headers.append("") else: @@ -275,7 +289,9 @@ def plot_processing_region( # Vertical lines line_time = np.array([0, region_start_time]) - for idx in np.arange(traces_per_gather + 0.5, num_traces - 1, traces_per_gather): + for idx in np.arange( + traces_per_gather + 0.5, num_traces - 1, traces_per_gather + ): line_pos = np.array([idx, idx]) line_x, line_y = self.resolve_postime2xy(line_pos, line_time) line_path = pg.arrayToQPath(line_x, line_y, np.ones(2, dtype=np.int32)) @@ -286,7 +302,9 @@ def plot_processing_region( # Transparent polygon on ignored part poly_pos = np.array([-2, num_traces + 2, num_traces + 2, -2]) - poly_time = np.array([region_start_time, region_start_time, sgy_end_time, sgy_end_time]) + poly_time = np.array( + [region_start_time, region_start_time, sgy_end_time, sgy_end_time] + ) poly_x, poly_y = self.resolve_postime2xy(poly_pos, poly_time) poly_path = pg.arrayToQPath(poly_x, poly_y, np.ones(4, dtype=np.int32)) poly_item = pg.QtWidgets.QGraphicsPathItem(poly_path) @@ -296,7 +314,9 @@ def plot_processing_region( self.addItem(poly_item) def _get_picks_as_item(self, picks: Picks) -> pg.PlotCurveItem: - x, y = self.resolve_postime2xy(np.arange(self.sgy.num_traces) + 1, np.array(picks.picks_in_ms)) + x, y = self.resolve_postime2xy( + np.arange(self.sgy.num_traces) + 1, np.array(picks.picks_in_ms) + ) line = pg.PlotCurveItem() line.setData(x, y) @@ -423,7 +443,7 @@ def export( fill_black: Optional[str] = DEFAULTS.fill_black, time_window: Optional[Tuple[float, float]] = None, traces_window: Optional[Tuple[float, float]] = None, - picks: Optional[Picks] = None, + picks_list: Optional[Sequence[Picks]] = None, task: Optional[Task] = None, show_processing_region: bool = True, contour_color: TColor = DEFAULTS.region_contour_color, @@ -443,15 +463,15 @@ def export( if args: raise need_kwargs_exception - if picks is not None and task is not None: - raise ValueError("'picks' and 'task' are mutually exclusive. Use only one of them or none") - if width is None: if traces_window is None: num_traces = sgy.num_traces width = int(width_per_trace * num_traces) + headers_total_pixels else: - width = int(width_per_trace * (traces_window[1] - traces_window[0])) + headers_total_pixels + width = ( + int(width_per_trace * (traces_window[1] - traces_window[0])) + + headers_total_pixels + ) self.avoid_memory_bomb(height, width) @@ -467,10 +487,12 @@ def export( refresh_view=True, ) - if task: + if task and task.picks is not None: self.plot_picks(task.picks) - elif picks is not None: - self.plot_picks(picks) + if picks_list: + assert all(len(picks) == sgy.num_traces for picks in picks_list) + for picks in picks_list: + self.plot_picks(picks) if task is not None and show_processing_region: self.plot_processing_region( @@ -537,7 +559,7 @@ def export_image( fill_black: Optional[str] = DEFAULTS.fill_black, time_window: Optional[Tuple[float, float]] = None, traces_window: Optional[Tuple[float, float]] = None, - picks: Optional[Picks] = None, + picks_list: Optional[Sequence[Picks]] = None, show_processing_region: bool = True, contour_color: TColor = DEFAULTS.region_contour_color, poly_color: TColor = DEFAULTS.region_poly_color, @@ -586,7 +608,7 @@ def export_image( fill_black=fill_black, time_window=time_window, traces_window=traces_window, - picks=picks, + picks_list=picks_list, task=task, show_processing_region=show_processing_region, contour_color=contour_color, diff --git a/first_breaks/desktop/main_gui.py b/first_breaks/desktop/main_gui.py index d46498c..6dffe0b 100644 --- a/first_breaks/desktop/main_gui.py +++ b/first_breaks/desktop/main_gui.py @@ -19,7 +19,13 @@ QWidget, ) -from first_breaks.const import DEMO_SGY_PATH, HIGH_DPI, MODEL_ONNX_HASH, MODEL_ONNX_PATH +from first_breaks.const import ( + DEMO_SGY_PATH, + HIGH_DPI, + MODEL_ONNX_HASH, + MODEL_ONNX_PATH, + MODEL_ONNX_HASHES, +) from first_breaks.data_models.independent import ExceptionOptional from first_breaks.desktop.graph import GraphWidget from first_breaks.desktop.last_folder_manager import last_folder_manager @@ -32,6 +38,7 @@ ) from first_breaks.desktop.utils import MessageBox, set_geometry from first_breaks.picking.task import Task +from first_breaks.picking.refiner import MinimalPhaseRefiner from first_breaks.sgy.reader import SGY from first_breaks.utils.utils import calc_hash, download_demo_sgy, download_model_onnx @@ -48,11 +55,12 @@ class FileState: file_changed = 2 @classmethod - def get_file_state(cls, fname: Union[str, Path], fhash: str) -> int: + def get_file_state(cls, fname: Union[str, Path], fhashes: str) -> int: if not Path(fname).is_file(): return cls.file_not_exists else: - return cls.valid_file if calc_hash(fname) == fhash else cls.file_changed + print(calc_hash(fname), fhashes) + return cls.valid_file if calc_hash(fname) in fhashes else cls.file_changed class ReadyToProcess: @@ -80,7 +88,9 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore else: self.main_folder = Path(__file__).parent - set_geometry(self, width_rel=0.6, height_rel=0.6, fix_size=False, centralize=True) + set_geometry( + self, width_rel=0.6, height_rel=0.6, fix_size=False, centralize=True + ) self.setWindowTitle("First breaks picking") self.threadpool = QThreadPool() @@ -100,23 +110,33 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore icon_get_filename = self.style().standardIcon(QStyle.SP_DirIcon) # icon_get_filename = QIcon(str(self.main_folder / "icons" / "sgy.png")) - self.button_get_filename = QAction(icon_get_filename, self.TOOLDBAR_OPEN_SGY, self) + self.button_get_filename = QAction( + icon_get_filename, self.TOOLDBAR_OPEN_SGY, self + ) self.button_get_filename.triggered.connect(self.get_filename) self.button_get_filename.setEnabled(True) self.toolbar.addAction(self.button_get_filename) self.toolbar.addSeparator() - icon_visual_settings = self.style().standardIcon(QStyle.SP_FileDialogContentsView) - self.button_settings_processing = QAction(icon_visual_settings, self.TOOLBAR_SETTINGS_AND_PROCESSINGS, self) - self.button_settings_processing.triggered.connect(self.show_settings_processing_window) + icon_visual_settings = self.style().standardIcon( + QStyle.SP_FileDialogContentsView + ) + self.button_settings_processing = QAction( + icon_visual_settings, self.TOOLBAR_SETTINGS_AND_PROCESSINGS, self + ) + self.button_settings_processing.triggered.connect( + self.show_settings_processing_window + ) self.button_settings_processing.setEnabled(False) self.toolbar.addAction(self.button_settings_processing) self.need_processing_region = True icon_processing_show = self.style().standardIcon(QStyle.SP_FileDialogListView) # icon_export = QIcon(str(self.main_folder / "icons" / "export.png")) - self.button_processing_show = QAction(icon_processing_show, self.TOOLBAR_SHOW_GRID, self) + self.button_processing_show = QAction( + icon_processing_show, self.TOOLBAR_SHOW_GRID, self + ) self.button_processing_show.triggered.connect(self.processing_region_changed) self.button_processing_show.setChecked(self.need_processing_region) self.button_processing_show.setEnabled(True) @@ -129,7 +149,9 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore icon_picks_manager = self.style().standardIcon(QStyle.SP_FileDialogDetailedView) # icon_export = QIcon(str(self.main_folder / "icons" / "export.png")) - self.button_picks_manager = QAction(icon_picks_manager, self.TOOLBAR_PICKS_MANAGER, self) + self.button_picks_manager = QAction( + icon_picks_manager, self.TOOLBAR_PICKS_MANAGER, self + ) self.button_picks_manager.triggered.connect(self.show_picks_manager) self.button_picks_manager.setEnabled(False) self.toolbar.addAction(self.button_picks_manager) @@ -165,8 +187,12 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore **self.plotseis_settings.model_dump(), ) self.settings_processing_widget.hide() - self.settings_processing_widget.export_plotseis_settings_signal.connect(self.update_plotseis_settings) - self.settings_processing_widget.export_picking_settings_signal.connect(self.pick_fb) + self.settings_processing_widget.export_plotseis_settings_signal.connect( + self.update_plotseis_settings + ) + self.settings_processing_widget.export_picking_settings_signal.connect( + self.pick_fb + ) # nn manager self.nn_manager = NNManager( @@ -176,15 +202,21 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore interrupt_on=self.settings_processing_widget.interrupt_signal, ) self.nn_manager.picking_finished_signal.connect(self.on_picking_finished) - self.nn_manager.picking_not_started_error_signal.connect(self.on_picking_not_started_error) + self.nn_manager.picking_not_started_error_signal.connect( + self.on_picking_not_started_error + ) # picks manager self.picks_manager = PicksManager() self.picks_manager.picks_updated_signal.connect(self.update_plot) self.picks_manager.hide() - self.graph.picks_manual_edited_signal.connect(self.picks_manager.update_picks_from_external) - self.graph.about_to_change_nn_picks_signal.connect(self.picks_manager.duplicate_active_created_by_nn_picks) + self.graph.picks_manual_edited_signal.connect( + self.picks_manager.update_picks_from_external + ) + self.graph.about_to_change_nn_picks_signal.connect( + self.picks_manager.duplicate_active_created_by_nn_picks + ) self.is_toggled_picks_from_file = False # placeholders @@ -195,7 +227,7 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore self.settings: Optional[Dict[str, Any]] = None self.last_folder: Optional[Union[str, Path]] = None self.picks_from_file_in_ms: Optional[Tuple[Union[int, float], ...]] = None - self.picker_hash = MODEL_ONNX_HASH + self.picker_hashes = MODEL_ONNX_HASHES if show: self.show() @@ -221,6 +253,11 @@ def on_picking_finished(self, result: Task) -> None: if result.success: self.picks_manager.add_nn_picks(result.picks) + refined_picks = result.picks.create_duplicate() + refiner = MinimalPhaseRefiner() + refined_picks = refiner.refine(self.sgy, refined_picks) + self.picks_manager.add_picks(refined_picks, "Refined picks") + self.update_plot(refresh_view=False) self.run_processing_region() else: @@ -255,7 +292,10 @@ def run_processing_region(self) -> None: def show_processing_region(self) -> None: for picks in self.picks_manager.picks_mapping.values(): if picks.created_by_nn and picks.active: - tps, max_time = picks.picking_parameters.traces_per_gather, picks.picking_parameters.maximum_time + tps, max_time = ( + picks.picking_parameters.traces_per_gather, + picks.picking_parameters.maximum_time, + ) self.graph.plot_processing_region(tps, max_time) break @@ -268,7 +308,9 @@ def update_plotseis_settings(self, new_settings: PlotseisSettings) -> None: self.update_plot(False) def update_plot(self, refresh_view: bool = False) -> None: - self.graph.plotseis(self.sgy, refresh_view=refresh_view, **self.plotseis_settings.model_dump()) + self.graph.plotseis( + self.sgy, refresh_view=refresh_view, **self.plotseis_settings.model_dump() + ) self.show_processing_region() self.graph.remove_picks() @@ -290,11 +332,17 @@ def load_nn(self, filename: Optional[Union[str, Path]] = None) -> None: if not filename: options = QFileDialog.Options() filename, _ = QFileDialog.getOpenFileName( - self, "Select file with NN weights", directory=last_folder_manager.get_last_folder(), options=options + self, + "Select file with NN weights", + directory=last_folder_manager.get_last_folder(), + options=options, ) if filename: - if FileState.get_file_state(filename, self.picker_hash) == FileState.valid_file: + if ( + FileState.get_file_state(filename, self.picker_hashes) + == FileState.valid_file + ): self.nn_manager.init_net(weights=filename) self.button_load_nn.setEnabled(False) self.ready_to_process.model_loaded = True @@ -350,7 +398,9 @@ def get_filename(self, filename: Optional[Union[str, Path]] = None) -> None: last_folder_manager.set_last_folder(filename) except Exception as e: - window_err = MessageBox(self, title=e.__class__.__name__, message=str(e)) + window_err = MessageBox( + self, title=e.__class__.__name__, message=str(e) + ) window_err.exec_() def show_picks_manager(self) -> None: @@ -377,10 +427,13 @@ def run_app() -> None: def fetch_data_and_run_app() -> None: + from first_breaks.const import PROJECT_ROOT + download_model_onnx(MODEL_ONNX_PATH) download_demo_sgy(DEMO_SGY_PATH) app, window = create_app() - window.load_nn(MODEL_ONNX_PATH) + # window.load_nn(MODEL_ONNX_PATH) + window.load_nn(PROJECT_ROOT / "fb_heatmap_afc03594f49b88ea61b5cf6ba8245be4.onnx") window.get_filename(DEMO_SGY_PATH) app.exec_() diff --git a/first_breaks/desktop/picks_manager_widget.py b/first_breaks/desktop/picks_manager_widget.py index 7a9f9dc..1cc2014 100644 --- a/first_breaks/desktop/picks_manager_widget.py +++ b/first_breaks/desktop/picks_manager_widget.py @@ -57,7 +57,9 @@ def __init__(self) -> None: layout.addLayout(input_layout) - self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self) + self.button_box = QDialogButtonBox( + QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self + ) self.button_box.accepted.connect(self.accept) self.button_box.rejected.connect(self.reject) @@ -91,7 +93,9 @@ def __init__(self) -> None: layout.addWidget(self.combo_box) - self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self) + self.button_box = QDialogButtonBox( + QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self + ) self.button_box.accepted.connect(self.accept) self.button_box.rejected.connect(self.reject) @@ -119,7 +123,9 @@ def __init__(self, text: str = "", color: QColor = QColor(255, 255, 255)) -> Non self.color_display = QLabel(self) # Using QLabel to display color self.color_display.setFixedSize(20, 20) # Fixed size for color display - self.color_display.setStyleSheet(f"background-color: {color.name()}; border: 1px solid black;") + self.color_display.setStyleSheet( + f"background-color: {color.name()}; border: 1px solid black;" + ) self.currentColor = color # Store the current color layout = QHBoxLayout() @@ -147,7 +153,9 @@ def on_radiobutton_clicked(self, checked: bool) -> None: def edit_color(self, event: QEvent) -> None: new_color = QColorDialog.getColor(self.currentColor, self) if new_color.isValid(): - self.color_display.setStyleSheet(f"background-color: {new_color.name()}; border: 1px solid black;") + self.color_display.setStyleSheet( + f"background-color: {new_color.name()}; border: 1px solid black;" + ) self.currentColor = new_color self.color_changed_signal.emit(new_color) # Emit signal with the new color @@ -167,7 +175,9 @@ def __init__( txt_exporter_kwargs = txt_exporter_kwargs or {} json_exporter_kwargs = json_exporter_kwargs or {} - set_geometry(self, width_rel=0.3, height_rel=0.3, centralize=True, fix_size=False) + set_geometry( + self, width_rel=0.3, height_rel=0.3, centralize=True, fix_size=False + ) self.picks_item_widget = picks_item_widget self.picks_mapping = picks_mapping @@ -206,7 +216,9 @@ def __init__( self.tab_all.addTab(self.tab_export, "Export") layout.addWidget(self.tab_all) - self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self) + self.button_box = QDialogButtonBox( + QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self + ) self.button_box.accepted.connect(self.accept) self.button_box.rejected.connect(self.reject) @@ -290,7 +302,9 @@ def __init__(self) -> None: self.list_widget = QListWidget(self) self.list_widget.itemDoubleClicked.connect(self.open_properties) - self.list_widget.itemSelectionChanged.connect(self.update_properties_button_state) + self.list_widget.itemSelectionChanged.connect( + self.update_properties_button_state + ) # Enable multi-selection mode self.list_widget.setSelectionMode(QListWidget.ExtendedSelection) @@ -303,8 +317,12 @@ def __init__(self) -> None: self.remove_button = QPushButton("-", self) self.remove_button.clicked.connect(self.remove_items) # renamed for clarity # self.properties_button = QPushButton("\u2699", self) # Unicode character for gear - self.properties_button = QPushButton("\U0001F4BE", self) # "\U0001F4BE" or "\U0001F5AB" - self.properties_button.setFont(self.font()) # to increase the size of the button a bit + self.properties_button = QPushButton( + "\U0001F4BE", self + ) # "\U0001F4BE" or "\U0001F5AB" + self.properties_button.setFont( + self.font() + ) # to increase the size of the button a bit self.properties_button.clicked.connect(self.open_properties) button_layout.addWidget(self.add_button) button_layout.addWidget(self.remove_button) @@ -401,23 +419,34 @@ def add_constant_values_pick(self) -> Optional[PicksItemWidget]: created_by_nn=False, ) self.items_counter.constant_values += 1 - return self.add_picks(picks, f"Manual {self.items_counter.constant_values}") + return self.add_picks( + picks, f"Manual {self.items_counter.constant_values}" + ) return None else: return None - def add_duplicate_pick(self, selected_picks_item: Optional[QListWidgetItem] = None) -> PicksItemWidget: - selected_picks_item = selected_picks_item or self.list_widget.itemWidget(self.list_widget.selectedItems()[0]) + def add_duplicate_pick( + self, selected_picks_item: Optional[QListWidgetItem] = None + ) -> PicksItemWidget: + selected_picks_item = selected_picks_item or self.list_widget.itemWidget( + self.list_widget.selectedItems()[0] + ) picks = self.picks_mapping[selected_picks_item] duplicated_picks = picks.create_duplicate(keep_color=False) self.items_counter.duplicated += 1 - return self.add_picks(duplicated_picks, f"Duplicated from '{selected_picks_item.get_name()}'") + return self.add_picks( + duplicated_picks, f"Duplicated from '{selected_picks_item.get_name()}'" + ) def add_aggregate_pick(self) -> Optional[PicksItemWidget]: self.items_counter.aggregated += 1 - selected_picks_items = [self.list_widget.itemWidget(item) for item in self.list_widget.selectedItems()] + selected_picks_items = [ + self.list_widget.itemWidget(item) + for item in self.list_widget.selectedItems() + ] dialog = AggregationDialog() self._aggregation_widget = dialog @@ -445,7 +474,9 @@ def add_aggregate_pick(self) -> Optional[PicksItemWidget]: return None def add_from_headers_pick(self) -> Optional[PicksItemWidget]: - dialog = QDialogByteEncodeUnit(byte_position=1, first_byte=FIRST_BYTE, encoding="I", picks_unit="mcs") + dialog = QDialogByteEncodeUnit( + byte_position=1, first_byte=FIRST_BYTE, encoding="I", picks_unit="mcs" + ) self._load_from_headers_widget = dialog if dialog.exec_() == QDialog.Accepted: @@ -481,7 +512,9 @@ def add_picks(self, picks: Picks, name: str) -> PicksItemWidget: item.setSizeHint(picks_item_widget.sizeHint()) picks_item_widget.color_changed_signal.connect( - lambda color, widget=picks_item_widget: self.update_picks_color(widget, color) + lambda color, widget=picks_item_widget: self.update_picks_color( + widget, color + ) ) picks_item_widget.checkbox.clicked.connect(self.picks_updated_signal) picks_item_widget.checkbox.setChecked(True) @@ -551,7 +584,9 @@ def update_properties_button_state(self) -> None: selected_items = self.list_widget.selectedItems() self.properties_button.setEnabled(len(selected_items) == 1) - def update_picks_color(self, picks_item_widget: PicksItemWidget, color: QColor) -> None: + def update_picks_color( + self, picks_item_widget: PicksItemWidget, color: QColor + ) -> None: pick = self.picks_mapping.get(picks_item_widget) if pick: pick.color = (color.red(), color.green(), color.blue()) diff --git a/first_breaks/picking/picker_onnx.py b/first_breaks/picking/picker_onnx.py index ad13ff8..2edf59e 100644 --- a/first_breaks/picking/picker_onnx.py +++ b/first_breaks/picking/picker_onnx.py @@ -23,7 +23,10 @@ class IteratorOfTask: def __init__(self, task: Task): self.task = task - self.idx2gather_ids = {idx: gather_ids for idx, gather_ids in enumerate(self.task.get_gathers_ids())} + self.idx2gather_ids = { + idx: gather_ids + for idx, gather_ids in enumerate(self.task.get_gathers_ids()) + } if self.task.normalize == "gather" and len(self.idx2gather_ids) > 1: raise AssertionError( "'gather' normalization can't be used for picking when number of gathers > 1. " @@ -36,7 +39,10 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: gather_ids = self.idx2gather_ids[idx] amplitudes = np.array( - [-1 if idx in self.task.traces_to_inverse else 1 for idx in range(len(gather_ids))], + [ + -1 if idx in self.task.traces_to_inverse else 1 + for idx in range(len(gather_ids)) + ], dtype=np.float32, ) gather = self.task.sgy.read_traces_by_ids(gather_ids) @@ -56,7 +62,9 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: return {self.gather_key: gather, self.gather_ids_key: np.array(gather_ids)} - def get_batch_generator(self, batch_size: int = 1) -> Generator[Dict[str, np.ndarray], None, None]: + def get_batch_generator( + self, batch_size: int = 1 + ) -> Generator[Dict[str, np.ndarray], None, None]: for ids in chunk_iterable(range(len(self)), batch_size): gather_batch = [] gather_ids_batch = [] @@ -124,28 +132,37 @@ def change_settings( # type: ignore return self - def pick_batch_of_gathers(self, gather: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + def pick_batch_of_gathers( + self, gather: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: assert gather.ndim == 4 assert all(dim > 0 for dim in gather.shape) - outputs = self.model.run(None, {"input": gather}) - return outputs[0], outputs[1] + outputs = self.model.run(["picks", "confs", "heatmap"], {"input": gather}) + return outputs[0], outputs[1], outputs[2] def process_task(self, task: Task) -> Task: task_picks_in_sample = np.zeros(task.sgy.num_traces) task_confidence = np.zeros(task.sgy.num_traces) + task_heatmap = np.zeros((task.sgy.num_samples, task.sgy.num_traces)) task_iterator = IteratorOfTask(task) counter_step_finished = 0 self.callback_processing_started(len(task_iterator)) - for idx, batch in enumerate(task_iterator.get_batch_generator(batch_size=self.batch_size)): + for idx, batch in enumerate( + task_iterator.get_batch_generator(batch_size=self.batch_size) + ): self.interrupt_if_need() data = batch["gather"].astype(np.float32) - picks, confidence = self.pick_batch_of_gathers(data) + picks, confidence, heatmap = self.pick_batch_of_gathers(data) indices = batch["gather_ids"] - task_picks_in_sample[indices.flatten()] = picks.flatten() - task_confidence[indices.flatten()] = confidence.flatten() + + task_picks_in_sample[indices.flatten()] = picks.flatten() # (B, W) -> (BxW) + task_confidence[indices.flatten()] = confidence.flatten() # (B, W) -> (BxW) + heatmap = heatmap.transpose(1, 0, 2) # (B, H, W) -> (H, B, W) + heatmap = heatmap.reshape(heatmap.shape[0], -1) # (H, B, W) -> (H, BxW) + task_heatmap[: len(heatmap), indices.flatten()] = heatmap counter_step_finished += len(data) self.callback_step_finished(counter_step_finished) @@ -154,6 +171,7 @@ def process_task(self, task: Task) -> Task: picks = Picks( values=task_picks_in_sample.astype(int).tolist(), + heatmap=task_heatmap, unit="sample", dt_mcs=task.sgy.dt_mcs, confidence=task_confidence.tolist(), diff --git a/first_breaks/picking/picks.py b/first_breaks/picking/picks.py index b157b1e..222b92f 100644 --- a/first_breaks/picking/picks.py +++ b/first_breaks/picking/picks.py @@ -43,6 +43,7 @@ class Picks(DefaultModel): dt_mcs: Optional[float] = None confidence: Optional[Union[np.ndarray, List[Union[int, float]]]] = None + heatmap: Optional[np.ndarray] = None created_by_nn: Optional[bool] = None created_manually: Optional[bool] = None modified_manually: Optional[bool] = None @@ -68,7 +69,10 @@ def _sync_units_converter_with_dt_mcs(self) -> "Picks": self.model_config["validate_assignment"] = False if self.dt_mcs is not None: - if self._units_converter is None or self._units_converter.sgy_mcs != self.dt_mcs: + if ( + self._units_converter is None + or self._units_converter.sgy_mcs != self.dt_mcs + ): self._units_converter = UnitsConverter(sgy_mcs=self.dt_mcs) else: self._units_converter = None @@ -152,10 +156,12 @@ def from_samples(self, values: TValues) -> None: def create_duplicate(self, keep_color: bool = False) -> "Picks": values = self.values.copy() confidence = self.confidence.copy() if self.confidence is not None else None + heatmap = self.heatmap.copy() if self.heatmap is not None else None return Picks( values=values, confidence=confidence, + heatmap=heatmap, dt_mcs=self.dt_mcs, unit=self.unit, created_manually=True, diff --git a/first_breaks/picking/refiner.py b/first_breaks/picking/refiner.py new file mode 100644 index 0000000..f4acbe3 --- /dev/null +++ b/first_breaks/picking/refiner.py @@ -0,0 +1,324 @@ +from pprint import pprint +from typing import Tuple + +import numpy as np + +from first_breaks.const import PROJECT_ROOT +from first_breaks.picking.picks import Picks +from first_breaks.sgy.reader import SGY +from first_breaks.utils.debug import Performance +from first_breaks.utils.filtering import apply_savgol_filter + + +class Refiner: + def refine(self, sgy: SGY, picks: Picks): + raise NotImplementedError + + +def find_extrema_mask(data: np.ndarray, neighbor_range: int = 1) -> np.ndarray: + assert data.ndim == 2 + + maxima_mask = np.ones(data.shape, dtype=bool) + minima_mask = np.ones(data.shape, dtype=bool) + ids = np.arange(len(data)) + + for shift in range(1, neighbor_range + 1): + shifted_data_left = data.take(ids + shift, mode="clip", axis=0) + shifted_data_right = data.take(ids - shift, mode="clip", axis=0) + + maxima_mask &= (data > shifted_data_left) & (data > shifted_data_right) + minima_mask &= (data < shifted_data_left) & (data < shifted_data_right) + + extrema_mask = maxima_mask | minima_mask + + return extrema_mask + + +def calc_intersection( + data: np.ndarray, data_derivative: np.ndarray, tangent_points: np.ndarray +) -> np.ndarray: + slope = data_derivative[tangent_points] + intercept = data[tangent_points] - slope * tangent_points + return -intercept / slope + + +def calc_intersection_vectorized( + data: np.ndarray, data_derivative: np.ndarray, extrema_mask: np.ndarray +): + assert all(arr.ndim == 2 for arr in [data, data_derivative, extrema_mask]) + assert extrema_mask.dtype == np.bool_ + + extrema_indices = np.where(extrema_mask) + row_indices, col_indices = extrema_indices + + sorted_indices = np.lexsort((row_indices, col_indices)) + sorted_row_indices = row_indices[sorted_indices] + sorted_col_indices = col_indices[sorted_indices] + + slope = data_derivative[sorted_row_indices, sorted_col_indices] + intercept = ( + data[sorted_row_indices, sorted_col_indices] - slope * sorted_row_indices + ) + intersection = -intercept / slope + + to_keep = ( + (intersection >= 0) + & (intersection < (len(data) - 1)) + & (intersection != np.inf) + ) + + intersection = intersection[to_keep] + sorted_col_indices = sorted_col_indices[to_keep] + + unique_cols, start_indices = np.unique(sorted_col_indices, return_index=True) + intersections = dict(zip(unique_cols, np.split(intersection, start_indices[1:]))) + + return intersections + + +def get_band_mask( + data: np.ndarray, band_ids: np.ndarray, width_before: int, width_after: int +) -> Tuple[np.ndarray, np.ndarray]: + num_rows, num_cols = data.shape + row_indices = np.arange(-width_before, width_after + 1).reshape(-1, 1) + band_ids + row_indices_clipped = np.clip(row_indices, 0, num_rows - 1) + return row_indices_clipped, np.arange(num_cols) + + +def refine_picks( + raw_picks: np.ndarray, + probability_heatmap: np.ndarray, + traces2intersections, + minimum_probability_to_refine: float = 0.9, +) -> np.ndarray: + refined_picks = raw_picks.copy() + + for trace, intersections in traces2intersections.items(): + intersections_int = intersections.astype(int) + prob = probability_heatmap[intersections_int, trace] + # print( + # trace, + # intersections_int, + # prob, + # [raw_picks[trace], probability_heatmap[raw_picks[trace], trace]], + # ) + best_candidate = np.argmax(prob) + if prob[best_candidate] > minimum_probability_to_refine: + refined_picks[trace] = intersections_int[best_candidate] + + return refined_picks + + +class MinimalPhaseRefiner(Refiner): + def __init__( + self, + analyse_window_before: int = 5, + analyse_window_after: int = 15, + smooth_window: int = 11, + smooth_polyorder: int = 3, + extrema_window: int = 3, + min_probability_to_refine: float = 0.9, + ): + self.analyse_window_before = analyse_window_before + self.analyse_window_after = analyse_window_after + self.smooth_window = smooth_window + self.smooth_polyorder = smooth_polyorder + self.extrema_window = extrema_window + self.min_probability_to_refine = min_probability_to_refine + + def refine(self, sgy: SGY, picks: Picks) -> Picks: + assert picks.heatmap is not None + # intersection point doesn't depend on scale of data, so we don't process raw data + raw = sgy.read() + picks_in_samples = picks.picks_in_samples.copy() + + filtered = apply_savgol_filter( + data=raw, + polyorder=self.smooth_polyorder, + window_length=self.smooth_window, + deriv=0, + ) + first_derivateive = apply_savgol_filter( + data=raw, + polyorder=self.smooth_polyorder, + window_length=self.smooth_window, + deriv=1, + ) + + band_mask = get_band_mask( + data=raw, + band_ids=picks_in_samples, + width_before=self.analyse_window_before, + width_after=self.analyse_window_after, + ) + + extrema = find_extrema_mask( + data=first_derivateive[band_mask], neighbor_range=self.extrema_window + ) + + tr2intersections = calc_intersection_vectorized( + data=filtered[band_mask], + data_derivative=-first_derivateive[band_mask], + extrema_mask=extrema, + ) + + # previous intersections obtained based on band data. We need to map band intersections to initail coordintates + tr2intersections = { + tr: band_mask[0][inter.round().astype(int), tr] + for tr, inter in tr2intersections.items() + } + + refined_picks = refine_picks( + raw_picks=picks_in_samples, + probability_heatmap=picks.heatmap, + traces2intersections=tr2intersections, + minimum_probability_to_refine=self.min_probability_to_refine, + ) + + picks.from_samples(refined_picks) + + return picks + + +if __name__ == "__main__": + sgy = SGY(PROJECT_ROOT / "with_picks.sgy") + heatmap = np.load(PROJECT_ROOT / "heatmap.npy") + src_picks = Picks( + values=sgy.read_custom_trace_header(236, "i"), + unit="mcs", + dt_mcs=sgy.dt_mcs, + heatmap=heatmap, + ) + + new_picks = src_picks.create_duplicate() + + refiner = MinimalPhaseRefiner() + with Performance(): + new_picks = refiner.refine(sgy, new_picks) + + print(src_picks.picks_in_samples) + print(new_picks.picks_in_samples) + + # num_tr = 20 + # num_samples = 20 + # window_smooth = 11 + # order = 3 + # window_extrema = 3 + # window_analyse_before = 5 + # window_analyse_after = 5 + # min_probability_to_refine = 0.9 + # + # + # raw = np.random.uniform(size=(num_samples, num_tr)) + # picks = np.random.randint(0, num_samples, size=num_tr).astype(int) + # heatmap = np.random.randint(1, 3, size=raw.shape) + # + # with Performance(): + # filtered = apply_savgol_filter( + # data=raw, polyorder=order, window_length=window_smooth, deriv=0 + # ) + # first_derivateive = apply_savgol_filter( + # data=raw, polyorder=order, window_length=window_smooth, deriv=1 + # ) + # + # band_mask = get_band_mask( + # data=raw, + # band_ids=picks, + # width_before=window_analyse_before, + # width_after=window_analyse_after, + # ) + # + # extrema = find_extrema_mask( + # data=first_derivateive[band_mask], neighbor_range=window_extrema + # ) + # + # tr2intersections = calc_intersection_vectorized( + # data=filtered[band_mask], + # data_derivative=-first_derivateive[band_mask], + # extrema_mask=extrema, + # ) + # # pprint(band_mask[0]) + # # pprint(picks) + # # pprint(tr2intersections) + # # band_start = band_mask[0][0, :] + # # pprint(band_start) + # # tr2intersections = { + # # tr: inter + band_start[tr] for tr, inter in tr2intersections.items() + # # } + # + # # pprint(band_mask[0]) + # # pprint(tr2intersections) + # + # # previous intersections obtained based on band data. We need to map band intersections to initail coordintates + # tr2intersections = { + # tr: band_mask[0][inter.round().astype(int), tr] + # for tr, inter in tr2intersections.items() + # } + # + # refined_picks = refine_picks( + # raw_picks=picks, + # probability_heatmap=heatmap, + # traces2intersections=tr2intersections, + # minimum_probability_to_refine=min_probability_to_refine, + # ) + + # d = np.random.uniform(size=(20, 20)) + # der = np.random.uniform(size=(20, 20)) + # + # + # picks = np.array([5] * 20) + # + # # picks[3] = 1 + # + # d[picks - 3 : picks + 3, :] + # + # + # d[:, 1:10] = 100 + # # tang = [] + # + # # aa = np.random.randint(0, 2, size=(3, 3)).astype(bool) + # # print(aa) + # # print(aa.nonzero()) + # + # with Performance(): + # extrema = find_extrema_mask(d) + # + # + # # print(res) + # # print(np.where(res)) + # + # with Performance(): + # tr2intersection = {} + # + # res = np.where(extrema) + # + # for i in np.unique(res[1]): + # extrema_tr = res[0][res[1] == i] + # tr2intersection[i] = calc_intersection(d[:, i], der[:, i], extrema_tr) + # + # + # print(extrema.shape) + # + # with Performance(): + # v = calc_intersection_vectorized(d, der, extrema) + # + # + # # print(len(v), len(tr2intersection)) + # # + # # print(v[0]) + # # print(tr2intersection[0]) + # + # + # assert all(np.allclose(tr2intersection[i], v[i]) for i in tr2intersection.keys()) + # + # print(tr2intersection[5]) + + # d = np.arange(10)[:, None] + # d = np.tile(d, (1, 5)) + # picks = np.array([1, 9, 3, 4, 5]) + # + # band_mask = get_band_mask(d, picks, 3, 2) + # + # print(d) + # print(band_mask) + # print(d[band_mask]) diff --git a/first_breaks/utils/filtering.py b/first_breaks/utils/filtering.py new file mode 100644 index 0000000..cd22c50 --- /dev/null +++ b/first_breaks/utils/filtering.py @@ -0,0 +1,40 @@ +import numpy as np + + +def savgol_coeffs(window_length: int, polyorder: int, deriv: int = 0) -> np.ndarray: + assert window_length % 2 == 1, "The value must be odd" + assert polyorder % 2 == 1, "The value must be odd" + assert polyorder < window_length + + half_window = (window_length - 1) // 2 + a = np.zeros((window_length, polyorder + 1)) + for i in range(window_length): + for j in range(polyorder + 1): + a[i, j] = (i - half_window) ** j + atr_a = np.dot(a.T, a) + atr = np.linalg.pinv(atr_a) + b = np.dot(atr, a.T) + coeffs = b[deriv] + return coeffs + + +def apply_savgol_filter( + data: np.ndarray, window_length: int, polyorder: int, deriv: int = 0 +) -> np.ndarray: + assert 1 <= data.ndim <= 2 + coeffs = savgol_coeffs(window_length, polyorder, deriv) + + half_window = (window_length - 1) // 2 + pad_mode = "reflect" + + if data.ndim == 1: + padding = (half_window, half_window) + else: + padding = ((half_window, half_window), (0, 0)) + + padded_data = np.pad(data, padding, mode=pad_mode) + + filtered_data = np.apply_along_axis( + lambda m: np.convolve(m, coeffs, mode="valid"), axis=0, arr=padded_data + ) + return filtered_data diff --git a/first_breaks/utils/utils.py b/first_breaks/utils/utils.py index 1f1cb60..d00f402 100644 --- a/first_breaks/utils/utils.py +++ b/first_breaks/utils/utils.py @@ -10,6 +10,7 @@ import numpy as np import pandas as pd import requests +from tqdm.auto import tqdm from first_breaks.const import ( DEMO_SGY_HASH, @@ -34,7 +35,9 @@ def chunk_iterable(it: Iterable[Any], size: int) -> List[Tuple[Any, ...]]: return list(iter(lambda: tuple(islice(it, size)), ())) -def get_io(source: Union[Path, str, bytes], mode: str = "r") -> Union[io.BytesIO, io.FileIO]: +def get_io( + source: Union[Path, str, bytes], mode: str = "r" +) -> Union[io.BytesIO, io.FileIO]: if isinstance(source, (Path, str)): source = Path(source).resolve() if "r" in mode: @@ -57,17 +60,51 @@ def calc_hash(source: Union[Path, str, bytes, io.BytesIO, io.FileIO]) -> str: return hash_md5.hexdigest() -def download_by_url(url: str, fname: Optional[Union[str, Path]], timeout: float = TIMEOUT) -> Optional[bytes]: - response = requests.get(url, timeout=timeout) - if response.status_code != 200: - response.raise_for_status() - return None - else: +# def download_by_url( +# url: str, fname: Optional[Union[str, Path]], timeout: float = TIMEOUT +# ) -> Optional[bytes]: +# response = requests.get(url, timeout=timeout, stream=True) +# if response.status_code != 200: +# response.raise_for_status() +# return None +# else: +# if fname: +# Path(fname).parent.mkdir(exist_ok=True, parents=True) +# with open(fname, "wb+") as f: +# f.write(response.content) +# return response.content + + +def download_by_url( + url: str, fname: Optional[Union[str, Path]], timeout: float = TIMEOUT +) -> bytes: + response = requests.get(url, stream=True, timeout=timeout) + response.raise_for_status() + total_size = int(response.headers.get("content-length", 0)) + block_size = 1024 + buffer = io.BytesIO() + + with tqdm( + desc=f"Downloading {url}", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in response.iter_content(block_size): + buffer.write(data) + bar.update(len(data)) + + buffer.seek(0) + content = buffer.getvalue() + if fname: Path(fname).parent.mkdir(exist_ok=True, parents=True) with open(fname, "wb+") as f: - f.write(response.content) - return response.content + f.write(content) + bar.set_description(f"File {url} saved to '{Path(fname).resolve()}'") + + return content def download_and_validate_file( @@ -77,23 +114,31 @@ def download_and_validate_file( download_by_url(url=url, fname=fname, timeout=timeout) md5_last = calc_hash(fname) if md5_last != md5: - raise InvalidHash(f"Hash for file {Path(fname).resolve()} in invalid. Got {md5_last}, expected {md5}") + raise InvalidHash( + f"Hash for file {Path(fname).resolve()} in invalid. Got {md5_last}, expected {md5}" + ) return fname def download_demo_sgy( - fname: Union[str, Path] = DEMO_SGY_PATH, url: str = DEMO_SGY_URL, md5: str = DEMO_SGY_HASH + fname: Union[str, Path] = DEMO_SGY_PATH, + url: str = DEMO_SGY_URL, + md5: str = DEMO_SGY_HASH, ) -> Union[str, Path]: return download_and_validate_file(fname=fname, url=url, md5=md5) def download_model_onnx( - fname: Union[str, Path] = MODEL_ONNX_PATH, url: str = MODEL_ONNX_URL, md5: str = MODEL_ONNX_HASH + fname: Union[str, Path] = MODEL_ONNX_PATH, + url: str = MODEL_ONNX_URL, + md5: str = MODEL_ONNX_HASH, ) -> Union[str, Path]: return download_and_validate_file(fname=fname, url=url, md5=md5) -def multiply_iterable_by(sample: TTimeType, multiplier: float, cast_to: Optional[Any] = None) -> TTimeType: +def multiply_iterable_by( + sample: TTimeType, multiplier: float, cast_to: Optional[Any] = None +) -> TTimeType: if isinstance(sample, (int, float, str)): result = sample * multiplier # type: ignore return cast_to(result) if cast_to is not None else result @@ -116,9 +161,15 @@ def __init__( sgy_ms: Optional[Union[int, float]] = None, ): if args: - raise ValueError("Specify explicitly either `sgy_mcs`or `sgy_ms` as keyword argument") - if (sgy_mcs is None and sgy_ms is None) or (sgy_mcs is not None and sgy_ms is not None): - raise RuntimeError("One and only one of `sgy_mcs` or `sgy_ms` must be specified") + raise ValueError( + "Specify explicitly either `sgy_mcs`or `sgy_ms` as keyword argument" + ) + if (sgy_mcs is None and sgy_ms is None) or ( + sgy_mcs is not None and sgy_ms is not None + ): + raise RuntimeError( + "One and only one of `sgy_mcs` or `sgy_ms` must be specified" + ) elif sgy_mcs is not None: self.sgy_mcs = sgy_mcs self.sgy_ms = self.mcs2ms(sgy_mcs) # type: ignore @@ -150,7 +201,11 @@ def index2mcs(self, sample: TTimeType, cast_to: Any = int) -> TTimeType: def remove_unused_kwargs(kwargs: Dict[str, Any], constructor: Any) -> Dict[str, Any]: - return {k: v for k, v in kwargs.items() if k in inspect.signature(constructor).parameters} + return { + k: v + for k, v in kwargs.items() + if k in inspect.signature(constructor).parameters + } def _color_generator() -> Generator[List[int], None, None]: diff --git a/first_breaks/utils/visualizations.py b/first_breaks/utils/visualizations.py index 60afcc4..3fa6c9a 100644 --- a/first_breaks/utils/visualizations.py +++ b/first_breaks/utils/visualizations.py @@ -5,6 +5,8 @@ import numpy as np from matplotlib.patches import Polygon +from first_breaks.picking.utils import preprocess_gather + def plotseis( data: np.ndarray, @@ -26,24 +28,26 @@ def plotseis( num_time, num_trace = np.shape(data) - if normalizing == "indiv": - norm_factor = np.mean(np.abs(data), axis=0) - norm_factor[np.abs(norm_factor) < 1e-9 * np.max(np.abs(norm_factor))] = 1 - elif normalizing == "entire": - norm_factor = np.tile(np.mean(np.abs(data)), (1, num_trace)) - elif np.size(normalizing) == 1 and normalizing is not None: - norm_factor = np.tile(normalizing, (1, num_trace)) - elif np.size(normalizing) == num_trace: - norm_factor = np.reshape(normalizing, (1, num_trace)) - elif normalizing is None: - norm_factor = np.ones(data.shape[1]) - else: - raise ValueError('Wrong value of "normalizing"') - - data = data / norm_factor * ampl - - mask_overflow = np.abs(data) > clip - data[mask_overflow] = np.sign(data[mask_overflow]) * clip + # if normalizing == "indiv": + # norm_factor = np.mean(np.abs(data), axis=0) + # norm_factor[np.abs(norm_factor) < 1e-9 * np.max(np.abs(norm_factor))] = 1 + # elif normalizing == "entire": + # norm_factor = np.tile(np.mean(np.abs(data)), (1, num_trace)) + # elif np.size(normalizing) == 1 and normalizing is not None: + # norm_factor = np.tile(normalizing, (1, num_trace)) + # elif np.size(normalizing) == num_trace: + # norm_factor = np.reshape(normalizing, (1, num_trace)) + # elif normalizing is None: + # norm_factor = np.ones(data.shape[1]) + # else: + # raise ValueError('Wrong value of "normalizing"') + # + # data = data / norm_factor * ampl + + # mask_overflow = np.abs(data) > clip + # data[mask_overflow] = np.sign(data[mask_overflow]) * clip + + data = preprocess_gather(data=data, gain=ampl, clip=clip, normalize=normalizing) data_time = np.tile((np.arange(num_time) + 1)[:, np.newaxis], (1, num_trace)) * dt @@ -95,7 +99,9 @@ def plotseis( tail = np.array((k_trace + 1, num_time * dt))[np.newaxis, :] patch_data = np.vstack((head, patch_data, tail)) - polygon = Polygon(patch_data, closed=True, facecolor="black", edgecolor=None) + polygon = Polygon( + patch_data, closed=True, facecolor="black", edgecolor=None + ) ax.add_patch(polygon) if picking is not None: From ea00f8bb675c30e07a03601dc99d47a1b73ce1df Mon Sep 17 00:00:00 2001 From: DaloroAT Date: Mon, 24 Jun 2024 23:08:00 +0200 Subject: [PATCH 2/9] bench --- dev_refi.py | 247 ------------------- first_breaks/benchmark.py | 246 +++++++++++++----- first_breaks/const.py | 13 +- first_breaks/desktop/graph.py | 47 +--- first_breaks/desktop/main_gui.py | 65 ++--- first_breaks/desktop/picks_manager_widget.py | 69 ++---- first_breaks/picking/picker_onnx.py | 22 +- first_breaks/picking/picks.py | 5 +- first_breaks/picking/refiner.py | 27 +- first_breaks/utils/filtering.py | 11 +- first_breaks/utils/utils.py | 36 +-- first_breaks/utils/visualizations.py | 4 +- 12 files changed, 262 insertions(+), 530 deletions(-) delete mode 100644 dev_refi.py diff --git a/dev_refi.py b/dev_refi.py deleted file mode 100644 index 7d74676..0000000 --- a/dev_refi.py +++ /dev/null @@ -1,247 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt - -from first_breaks.exports.export_picks import export_to_sgy -from first_breaks.picking.picker_onnx import PickerONNX -from first_breaks.picking.picks import Picks -from first_breaks.picking.task import Task -from first_breaks.picking.utils import preprocess_gather -from first_breaks.sgy.reader import SGY -from first_breaks.utils.utils import download_demo_sgy, generate_color -from first_breaks.utils.visualizations import plotseis - - -def savgol_coeffs(window_length, polyorder, deriv=0): - half_window = (window_length - 1) // 2 - A = np.zeros((window_length, polyorder + 1)) - for i in range(window_length): - for j in range(polyorder + 1): - A[i, j] = (i - half_window) ** j - ATA = np.dot(A.T, A) - AT = np.linalg.pinv(ATA) - B = np.dot(AT, A.T) - coeffs = B[deriv] - return coeffs - - -def apply_savgol_filter(data, window_length, polyorder, deriv=0): - coeffs = savgol_coeffs(window_length, polyorder, deriv) - - print(coeffs) - half_window = (window_length - 1) // 2 - pad_mode = "reflect" - - padded_data = np.pad(data, ((half_window, half_window), (0, 0)), mode=pad_mode) - - filtered_data = np.apply_along_axis( - lambda m: np.convolve(m, coeffs, mode="valid"), axis=0, arr=padded_data - ) - return filtered_data - - -# # Example seismogram (replace with your data) -# seismogram = np.array([ -# # Your seismic data here -# ]) -# -# # Parameters for Savitzky-Golay filter -# window_length = 11 # Choose an appropriate window length (must be odd) -# polyorder = 3 # Polynomial order -# -# # Apply Savitzky-Golay filter to smooth the data along the N dimension -# smoothed_seismogram = apply_savgol_filter(seismogram, window_length, polyorder) -# -# # Compute the first derivative using Savitzky-Golay filter -# first_derivative_seismogram = apply_savgol_filter(seismogram, window_length, polyorder, deriv=1) -# -# # Example initial picks from neural network (replace with your actual picks) -# initial_picks = np.array([/* your initial picks here */]) -# -# # Refine the picks using the tangent point and intersection method -# refined_picks = [] -# for trace in range(seismogram.shape[0]): -# trace_picks = [] -# for pick in initial_picks[trace]: -# tangent_point, intersection = find_tangent_and_intersection(seismogram[trace], first_derivative_seismogram[trace]) -# trace_picks.append(intersection) -# refined_picks.append(trace_picks) -# -# # Convert refined_picks to a 2D array -# refined_picks = np.array(refined_picks) -# -# # Visualize results for a specific trace -# trace_idx = 0 # Change as needed -# plt.figure(figsize=(12, 6)) -# plt.plot(seismogram[trace_idx], label='Original Seismic Trace') -# plt.plot(smoothed_seismogram[trace_idx], label='Smoothed Seismic Trace', linestyle='dashed') -# plt.plot(first_derivative_seismogram[trace_idx], label='First Derivative', linestyle='dotted') -# plt.scatter(initial_picks[trace_idx], seismogram[trace_idx][initial_picks[trace_idx]], color='red', label='Initial Picks') -# plt.scatter(refined_picks[trace_idx], seismogram[trace_idx][refined_picks[trace_idx].astype(int)], color='green', label='Refined Picks', marker='x') -# plt.legend() -# plt.xlabel('Sample Index') -# plt.ylabel('Amplitude') -# plt.title('Seismic Trace and Picks Refinement') -# plt.show() - - -# sgy = SGY(download_demo_sgy()) -# picker = PickerONNX(model_path="fb_heatmap_afc03594f49b88ea61b5cf6ba8245be4.onnx", batch_size=3) -# task = Task(source=sgy, traces_per_gather=12, maximum_time=100) -# task = picker.process_task(task) -# -# plt.imshow(task.picks.heatmap) -# plt.show() -# print(task.picks.picks_in_mcs) -# sgy = SGY("with_picks.sgy") -# picks = Picks( -# values=sgy.read_custom_trace_header(236, "i"), unit="mcs", dt_mcs=sgy.dt_mcs -# ) -# print(picks.picks_in_mcs) -# print(task.picks.heatmap.max(), task.picks.heatmap.min()) -# -# np.save("heatmap.npy", task.picks.heatmap) -# -# -# assert False -# export_to_sgy(sgy, filename="with_picks.sgy", picks=task.picks) - - -def find_candidates(first_derivative, window=None, neighbor_range=1): - """Find all extrema of the first derivative within the window using extended neighborhood checks.""" - if window is None: - start = 0 - end = len(first_derivative) - else: - start, end = window - - segment = first_derivative[start:end, ...] - - # Initialize masks for maxima and minima - maxima_mask = np.ones(segment.shape, dtype=bool) - minima_mask = np.ones(segment.shape, dtype=bool) - ids = np.arange(len(segment)) - - for shift in range(1, neighbor_range + 1): - shifted_segment_left = segment.take(ids + shift, mode="clip", axis=0) - shifted_segment_right = segment.take(ids - shift, mode="clip", axis=0) - - maxima_mask &= (segment > shifted_segment_left) & ( - segment > shifted_segment_right - ) - minima_mask &= (segment < shifted_segment_left) & ( - segment < shifted_segment_right - ) - - extrema_mask = maxima_mask | minima_mask - - extrema = np.where(extrema_mask)[0] + start - return extrema - - -def calculate_intersections(smoothed_trace, first_derivative, cand_points): - """Calculate the intersection of the tangent line at the tangent point with the time axis.""" - slope = first_derivative[cand_points] - intercept = smoothed_trace[cand_points] - slope * cand_points - return -intercept / slope - - # if slope != 0: - # intersection = -intercept / slope - # else: - # intersection = cand_points - # - # return intersection - - # slope = first_derivative[cand_points] - # intercept = smoothed_trace[cand_points] - slope * cand_points - # return intercept - - -sgy = SGY("with_picks.sgy") -picks = Picks( - values=sgy.read_custom_trace_header(236, "i"), unit="mcs", dt_mcs=sgy.dt_mcs -) - - -# plotseis(sgy.read(max_sample=500), picking=picks.picks_in_samples, normalizing="trace") - -data = -preprocess_gather(sgy.read(), clip=3, gain=1, normalize="trace") -heatmap = np.load("heatmap.npy") - -# trace_idx = 40 -# start_ms = 25 -# end_ms = 42 - -# trace_idx = 34 -# start_ms = 25 -# end_ms = 42 - -# trace_idx = 10 -# start_ms = 40 -# end_ms = 80 - -trace_idx = 93 -start_ms = 40 -end_ms = 60 - -start_sample = sgy.units_converter.ms2index(start_ms) -end_sample = sgy.units_converter.ms2index(end_ms) -# start_sample = 60 -# end_sample = 170 -sub = data[start_sample:end_sample, trace_idx] -# -plt.plot(sub, color="b") -plt.plot( - [ - picks.picks_in_samples[trace_idx] - start_sample, - picks.picks_in_samples[trace_idx] - start_sample, - ], - [min(sub), max(sub)], -) -# plt.show() - -window = 11 -order = 3 - -smoothed = apply_savgol_filter( - sub[:, None], window_length=window, polyorder=order, deriv=0 -) -plt.plot(smoothed[:, 0], color="k") - -first_deriv = apply_savgol_filter( - sub[:, None], window_length=window, polyorder=order, deriv=1 -) -plt.plot((2 * first_deriv[:, 0]), color="r") -cand = find_candidates(first_deriv[:, 0], neighbor_range=3) - -refined = calculate_intersection(smoothed[:, 0], -first_deriv[:, 0], cand) - -mask_refined = (refined < len(sub)) & (refined > 0) -refined = refined[mask_refined] -cand = cand[mask_refined] -print(cand.astype(int)) -print(refined.astype(int)) -print( - (100 * heatmap[start_sample + refined.astype(int), trace_idx]).astype(int), - "confidence", -) -print(picks.picks_in_samples[trace_idx] - start_sample) - -for i in range(len(cand)): - color_p = [c / 255 for c in generate_color()] - plt.scatter([cand[i]], [0], color=color_p) - plt.scatter([refined[i]], [0], marker="*", color=color_p) - -# second_deriv = apply_savgol_filter( -# sub[:, None], window_length=window, polyorder=order, deriv=2 -# ) - -plt.plot(heatmap[start_sample:end_sample, trace_idx], linestyle="--") - -# plt.plot( -# 3 * heatmap[start_sample:end_sample, trace_idx] * first_deriv[:, 0], -# linestyle="dashdot", -# ) - -# plt.plot(np.abs(4 * second_deriv[:, 0]), color="g") -plt.grid() -plt.show() diff --git a/first_breaks/benchmark.py b/first_breaks/benchmark.py index 3bc9135..a0abc53 100644 --- a/first_breaks/benchmark.py +++ b/first_breaks/benchmark.py @@ -1,98 +1,212 @@ +import json +from itertools import product +from os import system from pathlib import Path -from typing import Union, Literal +from typing import List, Optional, Union + +import numpy as np from first_breaks.const import FIRST_BYTE from first_breaks.desktop.graph import export_image from first_breaks.picking.picker_onnx import PickerONNX from first_breaks.picking.picks import Picks +from first_breaks.picking.refiner import MinimalPhaseRefiner from first_breaks.picking.task import Task from first_breaks.sgy.reader import SGY -from first_breaks.utils.utils import ( - download_demo_sgy, - download_by_url, - download_and_validate_file, - calc_hash, -) -import numpy as np +from first_breaks.utils.filtering import apply_savgol_filter +from first_breaks.utils.utils import as_list, calc_hash, download_and_validate_file + + +def download_model_with_heatmap(destination: Union[str, Path]) -> None: + model_hash = "afc03594f49b88ea61b5cf6ba8245be4" + model_url = "https://oml.daloroserver.com/download/seis/fb_heatmap_afc03594f49b88ea61b5cf6ba8245be4.onnx" + download_and_validate_file(url=model_url, md5=model_hash, fname=destination) + + +def plot_picks_on_small_section_chunk(sgy: SGY, manual_picks: Picks, predicted_picks: Optional[Picks] = None) -> None: + limit_for_validation = 10 + val_image = "chunk.png" + val_sgy = SGY(source=sgy.read_traces_by_ids(list(range(limit_for_validation))), dt_mcs=sgy.dt_mcs) + val_manual_picks = Picks( + values=manual_picks.picks_in_mcs[:limit_for_validation], + unit="mcs", + dt_mcs=val_sgy.dt_mcs, + color=(255, 0, 0), + ) + picks = [val_manual_picks] + + if predicted_picks: + val_predicted_picks = Picks( + values=predicted_picks.picks_in_mcs[:limit_for_validation], + unit="mcs", + dt_mcs=val_sgy.dt_mcs, + color=(0, 0, 255), + ) + picks.append(val_predicted_picks) + + export_image( + source=val_sgy, + image_filename=val_image, + picks_list=picks, + height=1000, + width=1000, + ) + system(val_image) + -fname = download_demo_sgy() -gain = 1 -traces_per_gather = 12 -maximum_time = 100 - - -# def predict_picks( -# sgy: SGY, -# picker: PickerONNX, -# gain: float, -# traces_per_gather: int, -# maximum_time: float, -# ): -# task = Task( -# source=sgy, -# traces_per_gather=traces_per_gather, -# maximum_time=maximum_time, -# gain=gain, -# ) -# task = picker.process_task(task) -# picks = task.get_result() - - -def make_report( - filename: Union[str, Path], - sgy: SGY, +def calc_snr10(traces: np.ndarray, picks: Picks, smooth: bool = False, symmetric: bool = True) -> List[float]: + if smooth: + traces = apply_savgol_filter(traces, polyorder=3, window_length=11, deriv=0) + + snr = np.ones(traces.shape[1]) + + for idx, pick in enumerate(picks.picks_in_samples): + if pick > 0: + noise = traces[:pick, idx] + if symmetric: + signal_and_noise = traces[pick : pick + len(noise), idx] + else: + signal_and_noise = traces[pick:, idx] + + p_noise = np.mean(np.square(noise)) + p_signal_and_noise = np.mean(np.square(signal_and_noise)) + snr[idx] = (p_signal_and_noise - p_noise) / p_noise + + snr10 = np.log10(snr) + snr10[np.isnan(snr10)] = -1000 + snr10[np.isinf(snr10)] = -2000 + snr10 = snr10.tolist() + + return snr10 + + +def benchmark( + sgy_filename: Union[str, Path], + model_filename: Union[str, Path], + report_filename: Union[str, Path], + gain_list: List[float], + maximum_time_list: List[float], + traces_per_gather_list: List[int], saved_picks_byte_position: int, - predicted_picks: Picks, - gain: float, - traces_per_gather: int, - maximum_time: float, ): + sgy_filename = Path(sgy_filename).resolve() + assert sgy_filename.exists(), f"File {sgy_filename} not found" + sgy = SGY(source=sgy_filename) + print(f"SGY: {sgy_filename}; shape={sgy.shape}, dt_mcs={sgy.dt_mcs}") + assert 1 <= saved_picks_byte_position <= 237 saved_picks = Picks( values=sgy.read_custom_trace_header(saved_picks_byte_position - FIRST_BYTE, "i"), unit="mcs", dt_mcs=sgy.dt_mcs, ) - to_export = {} - - # ANONYMOUS DATA BLOCK - # difference between manual picks and predicted picks is anonymous and expose nothing - difference = (np.array(saved_picks.picks_in_mcs) - np.array(predicted_picks.picks_in_mcs)).astype(int).tolist() - to_export["difference"] = difference + plot_picks_on_small_section_chunk(sgy=sgy, manual_picks=saved_picks) + + download_model_with_heatmap(model_filename) + + report_filename = Path(report_filename) + report_filename.parent.mkdir(exist_ok=True, parents=True) + + picker = PickerONNX(model_path=model_filename, show_progressbar=True) + + to_export = {"confidence": [], "difference": [], "model_hash": picker.model_hash} + + total = len(gain_list) * len(maximum_time_list) * len(traces_per_gather_list) + for idx, (gain, maximum_time, tps) in enumerate(product(gain_list, maximum_time_list, traces_per_gather_list)): + task = Task( + source=sgy, + traces_per_gather=tps, + maximum_time=maximum_time, + gain=gain, + ) + print(f"Task {idx}/{total} started (gain={gain}, max_time={maximum_time}, tps={tps}) ...", flush=True) + task = picker.process_task(task) + predicted_picks = task.get_result() + + confidence = as_list(predicted_picks.confidence) + # difference between manual picks and predicted picks is anonymous and expose nothing, but allows me to compare + # performance with different parameters + difference_raw = ( + (np.array(saved_picks.picks_in_mcs) - np.array(predicted_picks.picks_in_mcs)).astype(int).tolist() + ) + + refined_picks = predicted_picks.create_duplicate() + refiner = MinimalPhaseRefiner() + refined_picks = refiner.refine(sgy=sgy, picks=refined_picks) + + difference_refined = ( + (np.array(saved_picks.picks_in_mcs) - np.array(refined_picks.picks_in_mcs)).astype(int).tolist() + ) + + to_export["confidence"].append( + {"gain": gain, "maximum_time": maximum_time, "traces_per_gather": tps, "values": confidence} + ) + to_export["difference"].append( + { + "gain": gain, + "maximum_time": maximum_time, + "traces_per_gather": tps, + "refined": False, + "values": difference_raw, + } + ) + to_export["difference"].append( + { + "gain": gain, + "maximum_time": maximum_time, + "traces_per_gather": tps, + "refined": True, + "values": difference_refined, + } + ) + + # FILE LEVEL STATS # hash of traces allows me to inderstand if reports were created based on same data or different # without direct access to file: if 2 reports have same `traces_hash` it means that they were calculated based on # same traces, if not - files were different. # So I can understand how different parameters affect specific file analysing `difference` metric for several # files belongs to same `traces_hash` - traces_hash = calc_hash(sgy.read().tobytes(order="C")) + traces = sgy.read() + traces_hash = calc_hash(traces.tobytes(order="C")) to_export["traces_hash"] = traces_hash - channel_hashes = sgy.traces_headers["CHAN"].apply(lambda x: calc_hash(x)[:10]) - source_hashes = sgy.traces_headers["CHAN"].apply(lambda x: calc_hash(x)[:10]) + # I would like to have anonymized base headers to better understand the number of seismic traces for each shot, + # and the number of shots. I want to try to automate the selection of parameter `traces_per_gather` based on this. + # I'm not interested in exact values of these headers, but rather in their distribution, so hashed values + # are sufficient. + for header in ["CHAN", "SOURCE", "FFID"]: + to_export[header] = sgy.traces_headers[header].apply(lambda x: calc_hash(str(x).encode())[:10]).tolist() - # basic shape, not anonymsed - num_samples, num_traces = sgy.shape + to_export["shape"] = sgy.shape + to_export["dt_mcs"] = sgy.dt_mcs + # I want to analyse how picking parameters and result correlate with SNR + to_export["SNR10"] = [] + for smooth, symmetric in product((True, False), (True, False)): + snr10 = calc_snr10(traces, saved_picks, smooth=smooth, symmetric=symmetric) + to_export["SNR10"].append({"smooth": smooth, "symmetric": symmetric, "values": snr10}) -def download_model_with_heatmap(destination: Union[str, Path]): - model_hash = "afc03594f49b88ea61b5cf6ba8245be4" - model_url = "https://oml.daloroserver.com/download/seis/fb_heatmap_afc03594f49b88ea61b5cf6ba8245be4.onnx" - download_and_validate_file(url=model_url, md5=model_hash, fname=destination) + with open(report_filename, "w") as f: + json.dump(to_export, f) if __name__ == "__main__": - destination = "model_with_heatmap.onnx" - download_model_with_heatmap(destination) - - # fname = Path(fname) - # assert fname.exists(), f"File {fname.resolve()} not found" - # - # sgy = SGY(fname) - # - # print( - # f"SGY: fname='{fname.resolve()}', num_traces={sgy.num_traces}, num_samples={sgy.num_samples}, dt_mcs={sgy.dt_mcs}" - # ) - # - # picker = PickerONNX() + sgy_filename_ = "my_data.sgy" + model_filename_ = "fb_heatmap_afc03594f49b88ea61b5cf6ba8245be4.onnx" + report_filename_ = "report.json" + gain_list_ = [0.1, 0.5, 1] + maximum_time_list_ = [100, 200] + traces_per_gather_list_ = [12] + saved_picks_byte_position_ = 237 + + benchmark( + sgy_filename=sgy_filename_, + model_filename=model_filename_, + report_filename=report_filename_, + gain_list=gain_list_, + maximum_time_list=maximum_time_list_, + traces_per_gather_list=traces_per_gather_list_, + saved_picks_byte_position=saved_picks_byte_position_, + ) diff --git a/first_breaks/const.py b/first_breaks/const.py index 9a571b4..be204c5 100644 --- a/first_breaks/const.py +++ b/first_breaks/const.py @@ -8,10 +8,7 @@ def get_cache_folder() -> Path: if is_linux(): - return ( - Path(environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) - / "first_breaks_picking" - ) + return Path(environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) / "first_breaks_picking" elif is_macos(): return Path.home() / "Library" / "Caches" / "first_breaks_picking" elif is_windows(): @@ -31,7 +28,13 @@ def get_cache_folder() -> Path: MODEL_ONNX_PATH = CACHE_FOLDER / "fb.onnx" MODEL_ONNX_URL = "https://oml.daloroserver.com/download/seis/fb.onnx" MODEL_ONNX_HASH = "7e39e017b01325180e36885eccaeb17a" -MODEL_ONNX_HASHES = [MODEL_ONNX_HASH, "afc03594f49b88ea61b5cf6ba8245be4"] +MODEL_ONNX_HASHES = [ + # MODEL_ONNX_HASH, + "afc03594f49b88ea61b5cf6ba8245be4", + "3930eff8e70b4b29ab8d6def43706918", + "cd5492eae6ed543e9c5206bc18ff8b68", + "86ddd2a20f02201f4b1363abbabf7106", +] TIMEOUT = 60 diff --git a/first_breaks/desktop/graph.py b/first_breaks/desktop/graph.py index 80be383..d55d9a7 100644 --- a/first_breaks/desktop/graph.py +++ b/first_breaks/desktop/graph.py @@ -2,7 +2,7 @@ import os import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, Sequence +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import pyqtgraph as pg @@ -60,12 +60,8 @@ def __init__(self, use_open_gl: bool = True, *args: Any, **kwargs: Any): self.setup_axes() self.spectrum_roi_manager = RoiManager(viewbox=self.getViewBox()) - self.spectrum_window = SpectrumWindow( - use_open_gl=use_open_gl, roi_manager=self.spectrum_roi_manager - ) - self.mouse_click_signal = pg.SignalProxy( - self.sceneObj.sigMouseClicked, rateLimit=60, slot=self.mouse_clicked - ) + self.spectrum_window = SpectrumWindow(use_open_gl=use_open_gl, roi_manager=self.spectrum_roi_manager) + self.mouse_click_signal = pg.SignalProxy(self.sceneObj.sigMouseClicked, rateLimit=60, slot=self.mouse_clicked) def resolve_postime2xy(self, position: Any, time: Any) -> Tuple[Any, Any]: return postime2xy(vsp_view=self.vsp_view, position=position, time=time) @@ -176,16 +172,12 @@ def plotseis( self.getPlotItem().setXRange(min=0, max=x_max) for idx in range(num_traces): - self._plot_trace_fast( - trace=traces[:, idx], time=t, shift=idx + 1, fill_black=fill_black - ) + self._plot_trace_fast(trace=traces[:, idx], time=t, shift=idx + 1, fill_black=fill_black) self.pos_ax.showLabel() self.graph_updated_signal.emit() - def _plot_trace_fast( - self, trace: np.ndarray, time: np.ndarray, shift: int, fill_black: Optional[str] - ) -> None: + def _plot_trace_fast(self, trace: np.ndarray, time: np.ndarray, shift: int, fill_black: Optional[str]) -> None: connect = np.ones(len(time), dtype=np.int32) connect[-1] = 0 @@ -211,9 +203,7 @@ def _plot_trace_fast( sign = -1 if fill_black == "left" else 1 x, y = self.resolve_postime2xy(shift, time[0]) - w, h = self.resolve_postime2xy( - sign * 1.1 * max(np.abs(shifted_trace)), time[-1] - ) + w, h = self.resolve_postime2xy(sign * 1.1 * max(np.abs(shifted_trace)), time[-1]) rect = QPainterPath() rect.addRect(x, y, w, h) @@ -221,9 +211,7 @@ def _plot_trace_fast( patch = path.intersected(rect) item = pg.QtWidgets.QGraphicsPathItem(patch) - pen = QPen( - QColor(255, 255, 255, 0), 1, Qt.SolidLine, Qt.FlatCap, Qt.MiterJoin - ) + pen = QPen(QColor(255, 255, 255, 0), 1, Qt.SolidLine, Qt.FlatCap, Qt.MiterJoin) pen.setCosmetic(True) # 1 pixel width for any scale and resolution item.setPen(pen) item.setBrush(Qt.black) @@ -241,9 +229,7 @@ def _replace_tick_labels(self, *args: Any, **kwargs: Any) -> List[str]: if v % 1 == 0: v = int(v) - 1 if 0 <= v < self.sgy.num_traces: - labels_from_headers.append( - str(self.sgy.traces_headers[self.pos_ax_header].iloc[v]) - ) + labels_from_headers.append(str(self.sgy.traces_headers[self.pos_ax_header].iloc[v])) else: labels_from_headers.append("") else: @@ -289,9 +275,7 @@ def plot_processing_region( # Vertical lines line_time = np.array([0, region_start_time]) - for idx in np.arange( - traces_per_gather + 0.5, num_traces - 1, traces_per_gather - ): + for idx in np.arange(traces_per_gather + 0.5, num_traces - 1, traces_per_gather): line_pos = np.array([idx, idx]) line_x, line_y = self.resolve_postime2xy(line_pos, line_time) line_path = pg.arrayToQPath(line_x, line_y, np.ones(2, dtype=np.int32)) @@ -302,9 +286,7 @@ def plot_processing_region( # Transparent polygon on ignored part poly_pos = np.array([-2, num_traces + 2, num_traces + 2, -2]) - poly_time = np.array( - [region_start_time, region_start_time, sgy_end_time, sgy_end_time] - ) + poly_time = np.array([region_start_time, region_start_time, sgy_end_time, sgy_end_time]) poly_x, poly_y = self.resolve_postime2xy(poly_pos, poly_time) poly_path = pg.arrayToQPath(poly_x, poly_y, np.ones(4, dtype=np.int32)) poly_item = pg.QtWidgets.QGraphicsPathItem(poly_path) @@ -314,9 +296,7 @@ def plot_processing_region( self.addItem(poly_item) def _get_picks_as_item(self, picks: Picks) -> pg.PlotCurveItem: - x, y = self.resolve_postime2xy( - np.arange(self.sgy.num_traces) + 1, np.array(picks.picks_in_ms) - ) + x, y = self.resolve_postime2xy(np.arange(self.sgy.num_traces) + 1, np.array(picks.picks_in_ms)) line = pg.PlotCurveItem() line.setData(x, y) @@ -468,10 +448,7 @@ def export( num_traces = sgy.num_traces width = int(width_per_trace * num_traces) + headers_total_pixels else: - width = ( - int(width_per_trace * (traces_window[1] - traces_window[0])) - + headers_total_pixels - ) + width = int(width_per_trace * (traces_window[1] - traces_window[0])) + headers_total_pixels self.avoid_memory_bomb(height, width) diff --git a/first_breaks/desktop/main_gui.py b/first_breaks/desktop/main_gui.py index 6dffe0b..d05bbc4 100644 --- a/first_breaks/desktop/main_gui.py +++ b/first_breaks/desktop/main_gui.py @@ -23,8 +23,8 @@ DEMO_SGY_PATH, HIGH_DPI, MODEL_ONNX_HASH, - MODEL_ONNX_PATH, MODEL_ONNX_HASHES, + MODEL_ONNX_PATH, ) from first_breaks.data_models.independent import ExceptionOptional from first_breaks.desktop.graph import GraphWidget @@ -37,8 +37,8 @@ SettingsProcessingWidget, ) from first_breaks.desktop.utils import MessageBox, set_geometry -from first_breaks.picking.task import Task from first_breaks.picking.refiner import MinimalPhaseRefiner +from first_breaks.picking.task import Task from first_breaks.sgy.reader import SGY from first_breaks.utils.utils import calc_hash, download_demo_sgy, download_model_onnx @@ -88,9 +88,7 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore else: self.main_folder = Path(__file__).parent - set_geometry( - self, width_rel=0.6, height_rel=0.6, fix_size=False, centralize=True - ) + set_geometry(self, width_rel=0.6, height_rel=0.6, fix_size=False, centralize=True) self.setWindowTitle("First breaks picking") self.threadpool = QThreadPool() @@ -110,33 +108,23 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore icon_get_filename = self.style().standardIcon(QStyle.SP_DirIcon) # icon_get_filename = QIcon(str(self.main_folder / "icons" / "sgy.png")) - self.button_get_filename = QAction( - icon_get_filename, self.TOOLDBAR_OPEN_SGY, self - ) + self.button_get_filename = QAction(icon_get_filename, self.TOOLDBAR_OPEN_SGY, self) self.button_get_filename.triggered.connect(self.get_filename) self.button_get_filename.setEnabled(True) self.toolbar.addAction(self.button_get_filename) self.toolbar.addSeparator() - icon_visual_settings = self.style().standardIcon( - QStyle.SP_FileDialogContentsView - ) - self.button_settings_processing = QAction( - icon_visual_settings, self.TOOLBAR_SETTINGS_AND_PROCESSINGS, self - ) - self.button_settings_processing.triggered.connect( - self.show_settings_processing_window - ) + icon_visual_settings = self.style().standardIcon(QStyle.SP_FileDialogContentsView) + self.button_settings_processing = QAction(icon_visual_settings, self.TOOLBAR_SETTINGS_AND_PROCESSINGS, self) + self.button_settings_processing.triggered.connect(self.show_settings_processing_window) self.button_settings_processing.setEnabled(False) self.toolbar.addAction(self.button_settings_processing) self.need_processing_region = True icon_processing_show = self.style().standardIcon(QStyle.SP_FileDialogListView) # icon_export = QIcon(str(self.main_folder / "icons" / "export.png")) - self.button_processing_show = QAction( - icon_processing_show, self.TOOLBAR_SHOW_GRID, self - ) + self.button_processing_show = QAction(icon_processing_show, self.TOOLBAR_SHOW_GRID, self) self.button_processing_show.triggered.connect(self.processing_region_changed) self.button_processing_show.setChecked(self.need_processing_region) self.button_processing_show.setEnabled(True) @@ -149,9 +137,7 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore icon_picks_manager = self.style().standardIcon(QStyle.SP_FileDialogDetailedView) # icon_export = QIcon(str(self.main_folder / "icons" / "export.png")) - self.button_picks_manager = QAction( - icon_picks_manager, self.TOOLBAR_PICKS_MANAGER, self - ) + self.button_picks_manager = QAction(icon_picks_manager, self.TOOLBAR_PICKS_MANAGER, self) self.button_picks_manager.triggered.connect(self.show_picks_manager) self.button_picks_manager.setEnabled(False) self.toolbar.addAction(self.button_picks_manager) @@ -187,12 +173,8 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore **self.plotseis_settings.model_dump(), ) self.settings_processing_widget.hide() - self.settings_processing_widget.export_plotseis_settings_signal.connect( - self.update_plotseis_settings - ) - self.settings_processing_widget.export_picking_settings_signal.connect( - self.pick_fb - ) + self.settings_processing_widget.export_plotseis_settings_signal.connect(self.update_plotseis_settings) + self.settings_processing_widget.export_picking_settings_signal.connect(self.pick_fb) # nn manager self.nn_manager = NNManager( @@ -202,21 +184,15 @@ def __init__(self, use_open_gl: bool = True, show: bool = True): # type: ignore interrupt_on=self.settings_processing_widget.interrupt_signal, ) self.nn_manager.picking_finished_signal.connect(self.on_picking_finished) - self.nn_manager.picking_not_started_error_signal.connect( - self.on_picking_not_started_error - ) + self.nn_manager.picking_not_started_error_signal.connect(self.on_picking_not_started_error) # picks manager self.picks_manager = PicksManager() self.picks_manager.picks_updated_signal.connect(self.update_plot) self.picks_manager.hide() - self.graph.picks_manual_edited_signal.connect( - self.picks_manager.update_picks_from_external - ) - self.graph.about_to_change_nn_picks_signal.connect( - self.picks_manager.duplicate_active_created_by_nn_picks - ) + self.graph.picks_manual_edited_signal.connect(self.picks_manager.update_picks_from_external) + self.graph.about_to_change_nn_picks_signal.connect(self.picks_manager.duplicate_active_created_by_nn_picks) self.is_toggled_picks_from_file = False # placeholders @@ -308,9 +284,7 @@ def update_plotseis_settings(self, new_settings: PlotseisSettings) -> None: self.update_plot(False) def update_plot(self, refresh_view: bool = False) -> None: - self.graph.plotseis( - self.sgy, refresh_view=refresh_view, **self.plotseis_settings.model_dump() - ) + self.graph.plotseis(self.sgy, refresh_view=refresh_view, **self.plotseis_settings.model_dump()) self.show_processing_region() self.graph.remove_picks() @@ -339,10 +313,7 @@ def load_nn(self, filename: Optional[Union[str, Path]] = None) -> None: ) if filename: - if ( - FileState.get_file_state(filename, self.picker_hashes) - == FileState.valid_file - ): + if FileState.get_file_state(filename, self.picker_hashes) == FileState.valid_file: self.nn_manager.init_net(weights=filename) self.button_load_nn.setEnabled(False) self.ready_to_process.model_loaded = True @@ -398,9 +369,7 @@ def get_filename(self, filename: Optional[Union[str, Path]] = None) -> None: last_folder_manager.set_last_folder(filename) except Exception as e: - window_err = MessageBox( - self, title=e.__class__.__name__, message=str(e) - ) + window_err = MessageBox(self, title=e.__class__.__name__, message=str(e)) window_err.exec_() def show_picks_manager(self) -> None: diff --git a/first_breaks/desktop/picks_manager_widget.py b/first_breaks/desktop/picks_manager_widget.py index 1cc2014..7a9f9dc 100644 --- a/first_breaks/desktop/picks_manager_widget.py +++ b/first_breaks/desktop/picks_manager_widget.py @@ -57,9 +57,7 @@ def __init__(self) -> None: layout.addLayout(input_layout) - self.button_box = QDialogButtonBox( - QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self - ) + self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self) self.button_box.accepted.connect(self.accept) self.button_box.rejected.connect(self.reject) @@ -93,9 +91,7 @@ def __init__(self) -> None: layout.addWidget(self.combo_box) - self.button_box = QDialogButtonBox( - QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self - ) + self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self) self.button_box.accepted.connect(self.accept) self.button_box.rejected.connect(self.reject) @@ -123,9 +119,7 @@ def __init__(self, text: str = "", color: QColor = QColor(255, 255, 255)) -> Non self.color_display = QLabel(self) # Using QLabel to display color self.color_display.setFixedSize(20, 20) # Fixed size for color display - self.color_display.setStyleSheet( - f"background-color: {color.name()}; border: 1px solid black;" - ) + self.color_display.setStyleSheet(f"background-color: {color.name()}; border: 1px solid black;") self.currentColor = color # Store the current color layout = QHBoxLayout() @@ -153,9 +147,7 @@ def on_radiobutton_clicked(self, checked: bool) -> None: def edit_color(self, event: QEvent) -> None: new_color = QColorDialog.getColor(self.currentColor, self) if new_color.isValid(): - self.color_display.setStyleSheet( - f"background-color: {new_color.name()}; border: 1px solid black;" - ) + self.color_display.setStyleSheet(f"background-color: {new_color.name()}; border: 1px solid black;") self.currentColor = new_color self.color_changed_signal.emit(new_color) # Emit signal with the new color @@ -175,9 +167,7 @@ def __init__( txt_exporter_kwargs = txt_exporter_kwargs or {} json_exporter_kwargs = json_exporter_kwargs or {} - set_geometry( - self, width_rel=0.3, height_rel=0.3, centralize=True, fix_size=False - ) + set_geometry(self, width_rel=0.3, height_rel=0.3, centralize=True, fix_size=False) self.picks_item_widget = picks_item_widget self.picks_mapping = picks_mapping @@ -216,9 +206,7 @@ def __init__( self.tab_all.addTab(self.tab_export, "Export") layout.addWidget(self.tab_all) - self.button_box = QDialogButtonBox( - QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self - ) + self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self) self.button_box.accepted.connect(self.accept) self.button_box.rejected.connect(self.reject) @@ -302,9 +290,7 @@ def __init__(self) -> None: self.list_widget = QListWidget(self) self.list_widget.itemDoubleClicked.connect(self.open_properties) - self.list_widget.itemSelectionChanged.connect( - self.update_properties_button_state - ) + self.list_widget.itemSelectionChanged.connect(self.update_properties_button_state) # Enable multi-selection mode self.list_widget.setSelectionMode(QListWidget.ExtendedSelection) @@ -317,12 +303,8 @@ def __init__(self) -> None: self.remove_button = QPushButton("-", self) self.remove_button.clicked.connect(self.remove_items) # renamed for clarity # self.properties_button = QPushButton("\u2699", self) # Unicode character for gear - self.properties_button = QPushButton( - "\U0001F4BE", self - ) # "\U0001F4BE" or "\U0001F5AB" - self.properties_button.setFont( - self.font() - ) # to increase the size of the button a bit + self.properties_button = QPushButton("\U0001F4BE", self) # "\U0001F4BE" or "\U0001F5AB" + self.properties_button.setFont(self.font()) # to increase the size of the button a bit self.properties_button.clicked.connect(self.open_properties) button_layout.addWidget(self.add_button) button_layout.addWidget(self.remove_button) @@ -419,34 +401,23 @@ def add_constant_values_pick(self) -> Optional[PicksItemWidget]: created_by_nn=False, ) self.items_counter.constant_values += 1 - return self.add_picks( - picks, f"Manual {self.items_counter.constant_values}" - ) + return self.add_picks(picks, f"Manual {self.items_counter.constant_values}") return None else: return None - def add_duplicate_pick( - self, selected_picks_item: Optional[QListWidgetItem] = None - ) -> PicksItemWidget: - selected_picks_item = selected_picks_item or self.list_widget.itemWidget( - self.list_widget.selectedItems()[0] - ) + def add_duplicate_pick(self, selected_picks_item: Optional[QListWidgetItem] = None) -> PicksItemWidget: + selected_picks_item = selected_picks_item or self.list_widget.itemWidget(self.list_widget.selectedItems()[0]) picks = self.picks_mapping[selected_picks_item] duplicated_picks = picks.create_duplicate(keep_color=False) self.items_counter.duplicated += 1 - return self.add_picks( - duplicated_picks, f"Duplicated from '{selected_picks_item.get_name()}'" - ) + return self.add_picks(duplicated_picks, f"Duplicated from '{selected_picks_item.get_name()}'") def add_aggregate_pick(self) -> Optional[PicksItemWidget]: self.items_counter.aggregated += 1 - selected_picks_items = [ - self.list_widget.itemWidget(item) - for item in self.list_widget.selectedItems() - ] + selected_picks_items = [self.list_widget.itemWidget(item) for item in self.list_widget.selectedItems()] dialog = AggregationDialog() self._aggregation_widget = dialog @@ -474,9 +445,7 @@ def add_aggregate_pick(self) -> Optional[PicksItemWidget]: return None def add_from_headers_pick(self) -> Optional[PicksItemWidget]: - dialog = QDialogByteEncodeUnit( - byte_position=1, first_byte=FIRST_BYTE, encoding="I", picks_unit="mcs" - ) + dialog = QDialogByteEncodeUnit(byte_position=1, first_byte=FIRST_BYTE, encoding="I", picks_unit="mcs") self._load_from_headers_widget = dialog if dialog.exec_() == QDialog.Accepted: @@ -512,9 +481,7 @@ def add_picks(self, picks: Picks, name: str) -> PicksItemWidget: item.setSizeHint(picks_item_widget.sizeHint()) picks_item_widget.color_changed_signal.connect( - lambda color, widget=picks_item_widget: self.update_picks_color( - widget, color - ) + lambda color, widget=picks_item_widget: self.update_picks_color(widget, color) ) picks_item_widget.checkbox.clicked.connect(self.picks_updated_signal) picks_item_widget.checkbox.setChecked(True) @@ -584,9 +551,7 @@ def update_properties_button_state(self) -> None: selected_items = self.list_widget.selectedItems() self.properties_button.setEnabled(len(selected_items) == 1) - def update_picks_color( - self, picks_item_widget: PicksItemWidget, color: QColor - ) -> None: + def update_picks_color(self, picks_item_widget: PicksItemWidget, color: QColor) -> None: pick = self.picks_mapping.get(picks_item_widget) if pick: pick.color = (color.red(), color.green(), color.blue()) diff --git a/first_breaks/picking/picker_onnx.py b/first_breaks/picking/picker_onnx.py index 2edf59e..b07b086 100644 --- a/first_breaks/picking/picker_onnx.py +++ b/first_breaks/picking/picker_onnx.py @@ -23,10 +23,7 @@ class IteratorOfTask: def __init__(self, task: Task): self.task = task - self.idx2gather_ids = { - idx: gather_ids - for idx, gather_ids in enumerate(self.task.get_gathers_ids()) - } + self.idx2gather_ids = {idx: gather_ids for idx, gather_ids in enumerate(self.task.get_gathers_ids())} if self.task.normalize == "gather" and len(self.idx2gather_ids) > 1: raise AssertionError( "'gather' normalization can't be used for picking when number of gathers > 1. " @@ -39,10 +36,7 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: gather_ids = self.idx2gather_ids[idx] amplitudes = np.array( - [ - -1 if idx in self.task.traces_to_inverse else 1 - for idx in range(len(gather_ids)) - ], + [-1 if idx in self.task.traces_to_inverse else 1 for idx in range(len(gather_ids))], dtype=np.float32, ) gather = self.task.sgy.read_traces_by_ids(gather_ids) @@ -62,9 +56,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: return {self.gather_key: gather, self.gather_ids_key: np.array(gather_ids)} - def get_batch_generator( - self, batch_size: int = 1 - ) -> Generator[Dict[str, np.ndarray], None, None]: + def get_batch_generator(self, batch_size: int = 1) -> Generator[Dict[str, np.ndarray], None, None]: for ids in chunk_iterable(range(len(self)), batch_size): gather_batch = [] gather_ids_batch = [] @@ -132,9 +124,7 @@ def change_settings( # type: ignore return self - def pick_batch_of_gathers( - self, gather: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + def pick_batch_of_gathers(self, gather: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: assert gather.ndim == 4 assert all(dim > 0 for dim in gather.shape) outputs = self.model.run(["picks", "confs", "heatmap"], {"input": gather}) @@ -149,9 +139,7 @@ def process_task(self, task: Task) -> Task: counter_step_finished = 0 self.callback_processing_started(len(task_iterator)) - for idx, batch in enumerate( - task_iterator.get_batch_generator(batch_size=self.batch_size) - ): + for idx, batch in enumerate(task_iterator.get_batch_generator(batch_size=self.batch_size)): self.interrupt_if_need() data = batch["gather"].astype(np.float32) picks, confidence, heatmap = self.pick_batch_of_gathers(data) diff --git a/first_breaks/picking/picks.py b/first_breaks/picking/picks.py index 222b92f..cb64ab7 100644 --- a/first_breaks/picking/picks.py +++ b/first_breaks/picking/picks.py @@ -69,10 +69,7 @@ def _sync_units_converter_with_dt_mcs(self) -> "Picks": self.model_config["validate_assignment"] = False if self.dt_mcs is not None: - if ( - self._units_converter is None - or self._units_converter.sgy_mcs != self.dt_mcs - ): + if self._units_converter is None or self._units_converter.sgy_mcs != self.dt_mcs: self._units_converter = UnitsConverter(sgy_mcs=self.dt_mcs) else: self._units_converter = None diff --git a/first_breaks/picking/refiner.py b/first_breaks/picking/refiner.py index f4acbe3..d307363 100644 --- a/first_breaks/picking/refiner.py +++ b/first_breaks/picking/refiner.py @@ -34,17 +34,13 @@ def find_extrema_mask(data: np.ndarray, neighbor_range: int = 1) -> np.ndarray: return extrema_mask -def calc_intersection( - data: np.ndarray, data_derivative: np.ndarray, tangent_points: np.ndarray -) -> np.ndarray: +def calc_intersection(data: np.ndarray, data_derivative: np.ndarray, tangent_points: np.ndarray) -> np.ndarray: slope = data_derivative[tangent_points] intercept = data[tangent_points] - slope * tangent_points return -intercept / slope -def calc_intersection_vectorized( - data: np.ndarray, data_derivative: np.ndarray, extrema_mask: np.ndarray -): +def calc_intersection_vectorized(data: np.ndarray, data_derivative: np.ndarray, extrema_mask: np.ndarray): assert all(arr.ndim == 2 for arr in [data, data_derivative, extrema_mask]) assert extrema_mask.dtype == np.bool_ @@ -56,16 +52,10 @@ def calc_intersection_vectorized( sorted_col_indices = col_indices[sorted_indices] slope = data_derivative[sorted_row_indices, sorted_col_indices] - intercept = ( - data[sorted_row_indices, sorted_col_indices] - slope * sorted_row_indices - ) + intercept = data[sorted_row_indices, sorted_col_indices] - slope * sorted_row_indices intersection = -intercept / slope - to_keep = ( - (intersection >= 0) - & (intersection < (len(data) - 1)) - & (intersection != np.inf) - ) + to_keep = (intersection >= 0) & (intersection < (len(data) - 1)) & (intersection != np.inf) intersection = intersection[to_keep] sorted_col_indices = sorted_col_indices[to_keep] @@ -152,9 +142,7 @@ def refine(self, sgy: SGY, picks: Picks) -> Picks: width_after=self.analyse_window_after, ) - extrema = find_extrema_mask( - data=first_derivateive[band_mask], neighbor_range=self.extrema_window - ) + extrema = find_extrema_mask(data=first_derivateive[band_mask], neighbor_range=self.extrema_window) tr2intersections = calc_intersection_vectorized( data=filtered[band_mask], @@ -163,10 +151,7 @@ def refine(self, sgy: SGY, picks: Picks) -> Picks: ) # previous intersections obtained based on band data. We need to map band intersections to initail coordintates - tr2intersections = { - tr: band_mask[0][inter.round().astype(int), tr] - for tr, inter in tr2intersections.items() - } + tr2intersections = {tr: band_mask[0][inter.round().astype(int), tr] for tr, inter in tr2intersections.items()} refined_picks = refine_picks( raw_picks=picks_in_samples, diff --git a/first_breaks/utils/filtering.py b/first_breaks/utils/filtering.py index cd22c50..588481a 100644 --- a/first_breaks/utils/filtering.py +++ b/first_breaks/utils/filtering.py @@ -18,9 +18,10 @@ def savgol_coeffs(window_length: int, polyorder: int, deriv: int = 0) -> np.ndar return coeffs -def apply_savgol_filter( - data: np.ndarray, window_length: int, polyorder: int, deriv: int = 0 -) -> np.ndarray: +def apply_savgol_filter(data: np.ndarray, window_length: int, polyorder: int, deriv: int = 0) -> np.ndarray: + """ + https://en.wikipedia.org/wiki/Savitzky%E2%80%93Golay_filter + """ assert 1 <= data.ndim <= 2 coeffs = savgol_coeffs(window_length, polyorder, deriv) @@ -34,7 +35,5 @@ def apply_savgol_filter( padded_data = np.pad(data, padding, mode=pad_mode) - filtered_data = np.apply_along_axis( - lambda m: np.convolve(m, coeffs, mode="valid"), axis=0, arr=padded_data - ) + filtered_data = np.apply_along_axis(lambda m: np.convolve(m, coeffs, mode="valid"), axis=0, arr=padded_data) return filtered_data diff --git a/first_breaks/utils/utils.py b/first_breaks/utils/utils.py index d00f402..ac0d389 100644 --- a/first_breaks/utils/utils.py +++ b/first_breaks/utils/utils.py @@ -35,9 +35,7 @@ def chunk_iterable(it: Iterable[Any], size: int) -> List[Tuple[Any, ...]]: return list(iter(lambda: tuple(islice(it, size)), ())) -def get_io( - source: Union[Path, str, bytes], mode: str = "r" -) -> Union[io.BytesIO, io.FileIO]: +def get_io(source: Union[Path, str, bytes], mode: str = "r") -> Union[io.BytesIO, io.FileIO]: if isinstance(source, (Path, str)): source = Path(source).resolve() if "r" in mode: @@ -55,8 +53,10 @@ def calc_hash(source: Union[Path, str, bytes, io.BytesIO, io.FileIO]) -> str: hash_md5 = hashlib.md5() if not isinstance(source, (io.BytesIO, io.FileIO)): source = get_io(source, mode="rb") + source.seek(0) for chunk in iter(lambda: source.read(4096), b""): # type: ignore hash_md5.update(chunk) + source.close() return hash_md5.hexdigest() @@ -75,9 +75,7 @@ def calc_hash(source: Union[Path, str, bytes, io.BytesIO, io.FileIO]) -> str: # return response.content -def download_by_url( - url: str, fname: Optional[Union[str, Path]], timeout: float = TIMEOUT -) -> bytes: +def download_by_url(url: str, fname: Optional[Union[str, Path]], timeout: float = TIMEOUT) -> bytes: response = requests.get(url, stream=True, timeout=timeout) response.raise_for_status() total_size = int(response.headers.get("content-length", 0)) @@ -114,9 +112,7 @@ def download_and_validate_file( download_by_url(url=url, fname=fname, timeout=timeout) md5_last = calc_hash(fname) if md5_last != md5: - raise InvalidHash( - f"Hash for file {Path(fname).resolve()} in invalid. Got {md5_last}, expected {md5}" - ) + raise InvalidHash(f"Hash for file {Path(fname).resolve()} in invalid. Got {md5_last}, expected {md5}") return fname @@ -136,9 +132,7 @@ def download_model_onnx( return download_and_validate_file(fname=fname, url=url, md5=md5) -def multiply_iterable_by( - sample: TTimeType, multiplier: float, cast_to: Optional[Any] = None -) -> TTimeType: +def multiply_iterable_by(sample: TTimeType, multiplier: float, cast_to: Optional[Any] = None) -> TTimeType: if isinstance(sample, (int, float, str)): result = sample * multiplier # type: ignore return cast_to(result) if cast_to is not None else result @@ -161,15 +155,9 @@ def __init__( sgy_ms: Optional[Union[int, float]] = None, ): if args: - raise ValueError( - "Specify explicitly either `sgy_mcs`or `sgy_ms` as keyword argument" - ) - if (sgy_mcs is None and sgy_ms is None) or ( - sgy_mcs is not None and sgy_ms is not None - ): - raise RuntimeError( - "One and only one of `sgy_mcs` or `sgy_ms` must be specified" - ) + raise ValueError("Specify explicitly either `sgy_mcs`or `sgy_ms` as keyword argument") + if (sgy_mcs is None and sgy_ms is None) or (sgy_mcs is not None and sgy_ms is not None): + raise RuntimeError("One and only one of `sgy_mcs` or `sgy_ms` must be specified") elif sgy_mcs is not None: self.sgy_mcs = sgy_mcs self.sgy_ms = self.mcs2ms(sgy_mcs) # type: ignore @@ -201,11 +189,7 @@ def index2mcs(self, sample: TTimeType, cast_to: Any = int) -> TTimeType: def remove_unused_kwargs(kwargs: Dict[str, Any], constructor: Any) -> Dict[str, Any]: - return { - k: v - for k, v in kwargs.items() - if k in inspect.signature(constructor).parameters - } + return {k: v for k, v in kwargs.items() if k in inspect.signature(constructor).parameters} def _color_generator() -> Generator[List[int], None, None]: diff --git a/first_breaks/utils/visualizations.py b/first_breaks/utils/visualizations.py index 3fa6c9a..bf6574d 100644 --- a/first_breaks/utils/visualizations.py +++ b/first_breaks/utils/visualizations.py @@ -99,9 +99,7 @@ def plotseis( tail = np.array((k_trace + 1, num_time * dt))[np.newaxis, :] patch_data = np.vstack((head, patch_data, tail)) - polygon = Polygon( - patch_data, closed=True, facecolor="black", edgecolor=None - ) + polygon = Polygon(patch_data, closed=True, facecolor="black", edgecolor=None) ax.add_patch(polygon) if picking is not None: From d8cf37e3f7c66a03e55e1a21b058c5b0c315eecf Mon Sep 17 00:00:00 2001 From: DaloroAT Date: Sun, 5 Oct 2025 18:29:46 +0200 Subject: [PATCH 3/9] refiner and small changes --- .gitignore | 1 + README.md | 2 +- first_breaks/const.py | 7 +- first_breaks/data_models/independent.py | 2 +- first_breaks/desktop/main_gui.py | 12 +- first_breaks/picking/picker_onnx.py | 42 ++++-- first_breaks/picking/picks.py | 6 +- first_breaks/picking/refiner.py | 150 ---------------------- first_breaks/utils/utils.py | 30 ++--- tests/test_common/test_readme_examples.py | 2 +- 10 files changed, 51 insertions(+), 203 deletions(-) diff --git a/.gitignore b/.gitignore index 03e120d..6018d04 100644 --- a/.gitignore +++ b/.gitignore @@ -143,3 +143,4 @@ dmypy.json fb.onnx local_dev +/debug/ diff --git a/README.md b/README.md index bcc9442..b54d53b 100644 --- a/README.md +++ b/README.md @@ -476,7 +476,7 @@ picks_ms = np.random.uniform(low=0, size=sgy.ntr) picks = Picks(values=picks_ms, unit="ms", dt_mcs=sgy.dt_mcs, color=(0, 100, 100)) export_image(sgy, image_filename, - picks=picks) + picks_list=[picks]) ``` [code-block-end]:plot-sgy-custom-picks diff --git a/first_breaks/const.py b/first_breaks/const.py index be204c5..cc507f4 100644 --- a/first_breaks/const.py +++ b/first_breaks/const.py @@ -29,11 +29,8 @@ def get_cache_folder() -> Path: MODEL_ONNX_URL = "https://oml.daloroserver.com/download/seis/fb.onnx" MODEL_ONNX_HASH = "7e39e017b01325180e36885eccaeb17a" MODEL_ONNX_HASHES = [ - # MODEL_ONNX_HASH, - "afc03594f49b88ea61b5cf6ba8245be4", - "3930eff8e70b4b29ab8d6def43706918", - "cd5492eae6ed543e9c5206bc18ff8b68", - "86ddd2a20f02201f4b1363abbabf7106", + MODEL_ONNX_HASH, + "afc03594f49b88ea61b5cf6ba8245be4", # model with heatmap ] TIMEOUT = 60 diff --git a/first_breaks/data_models/independent.py b/first_breaks/data_models/independent.py index f164a14..163652e 100644 --- a/first_breaks/data_models/independent.py +++ b/first_breaks/data_models/independent.py @@ -5,7 +5,7 @@ import numpy as np from pydantic import UUID4, BaseModel, Field, field_validator -TColor = Union[Tuple[int, int, int, int], Tuple[int, int, int]] +TColor = Union[Tuple[int, int, int], Tuple[int, int, int, int], Tuple[int, ...]] TNormalize = Union[Literal["trace", "gather"], float, int, np.ndarray, None] diff --git a/first_breaks/desktop/main_gui.py b/first_breaks/desktop/main_gui.py index d05bbc4..13a8d28 100644 --- a/first_breaks/desktop/main_gui.py +++ b/first_breaks/desktop/main_gui.py @@ -37,7 +37,6 @@ SettingsProcessingWidget, ) from first_breaks.desktop.utils import MessageBox, set_geometry -from first_breaks.picking.refiner import MinimalPhaseRefiner from first_breaks.picking.task import Task from first_breaks.sgy.reader import SGY from first_breaks.utils.utils import calc_hash, download_demo_sgy, download_model_onnx @@ -59,7 +58,6 @@ def get_file_state(cls, fname: Union[str, Path], fhashes: str) -> int: if not Path(fname).is_file(): return cls.file_not_exists else: - print(calc_hash(fname), fhashes) return cls.valid_file if calc_hash(fname) in fhashes else cls.file_changed @@ -229,11 +227,6 @@ def on_picking_finished(self, result: Task) -> None: if result.success: self.picks_manager.add_nn_picks(result.picks) - refined_picks = result.picks.create_duplicate() - refiner = MinimalPhaseRefiner() - refined_picks = refiner.refine(self.sgy, refined_picks) - self.picks_manager.add_picks(refined_picks, "Refined picks") - self.update_plot(refresh_view=False) self.run_processing_region() else: @@ -396,13 +389,10 @@ def run_app() -> None: def fetch_data_and_run_app() -> None: - from first_breaks.const import PROJECT_ROOT - download_model_onnx(MODEL_ONNX_PATH) download_demo_sgy(DEMO_SGY_PATH) app, window = create_app() - # window.load_nn(MODEL_ONNX_PATH) - window.load_nn(PROJECT_ROOT / "fb_heatmap_afc03594f49b88ea61b5cf6ba8245be4.onnx") + window.load_nn(MODEL_ONNX_PATH) window.get_filename(DEMO_SGY_PATH) app.exec_() diff --git a/first_breaks/picking/picker_onnx.py b/first_breaks/picking/picker_onnx.py index b07b086..e42bdc7 100644 --- a/first_breaks/picking/picker_onnx.py +++ b/first_breaks/picking/picker_onnx.py @@ -75,6 +75,10 @@ def get_batch_generator(self, batch_size: int = 1) -> Generator[Dict[str, np.nda class PickerONNX(IPicker): + OUTPUT_PICKS_KEY = "picks" + OUTPUT_CONFS_KEY = "confs" + OUTPUT_HEATMAP_KEY = "heatmap" + def __init__( self, model_path: Optional[Union[str, Path]] = None, @@ -96,6 +100,15 @@ def __init__( self.model: Optional[ort.InferenceSession] = None self.init_model() + self._available_outputs = sorted(o.name for o in self.model.get_outputs()) + self._input_name = self.model.get_inputs()[0].name + + for mandatory_key in [self.OUTPUT_PICKS_KEY, self.OUTPUT_CONFS_KEY]: + assert mandatory_key in self._available_outputs, f"`mandatory_key` not found in model outputs `{self._available_outputs}`" + + def is_heatmap_available(self) -> bool: + return self.OUTPUT_HEATMAP_KEY in self._available_outputs + def init_model(self) -> None: sess_opt = ort.SessionOptions() sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL @@ -124,16 +137,20 @@ def change_settings( # type: ignore return self - def pick_batch_of_gathers(self, gather: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + def pick_batch_of_gathers(self, gather: np.ndarray) -> Dict[str, np.ndarray]: assert gather.ndim == 4 assert all(dim > 0 for dim in gather.shape) - outputs = self.model.run(["picks", "confs", "heatmap"], {"input": gather}) - return outputs[0], outputs[1], outputs[2] + outputs_list = self.model.run(self._available_outputs, {self._input_name: gather}) + return dict(zip(self._available_outputs, outputs_list)) def process_task(self, task: Task) -> Task: task_picks_in_sample = np.zeros(task.sgy.num_traces) task_confidence = np.zeros(task.sgy.num_traces) - task_heatmap = np.zeros((task.sgy.num_samples, task.sgy.num_traces)) + + if self.is_heatmap_available(): + task_heatmap = np.zeros((task.sgy.num_samples, task.sgy.num_traces)) + else: + task_heatmap = None task_iterator = IteratorOfTask(task) counter_step_finished = 0 @@ -142,15 +159,22 @@ def process_task(self, task: Task) -> Task: for idx, batch in enumerate(task_iterator.get_batch_generator(batch_size=self.batch_size)): self.interrupt_if_need() data = batch["gather"].astype(np.float32) - picks, confidence, heatmap = self.pick_batch_of_gathers(data) + results = self.pick_batch_of_gathers(data) + + picks = results[self.OUTPUT_PICKS_KEY] + confidence = results[self.OUTPUT_CONFS_KEY] + indices = batch["gather_ids"] task_picks_in_sample[indices.flatten()] = picks.flatten() # (B, W) -> (BxW) task_confidence[indices.flatten()] = confidence.flatten() # (B, W) -> (BxW) - heatmap = heatmap.transpose(1, 0, 2) # (B, H, W) -> (H, B, W) - heatmap = heatmap.reshape(heatmap.shape[0], -1) # (H, B, W) -> (H, BxW) - task_heatmap[: len(heatmap), indices.flatten()] = heatmap + + if self.is_heatmap_available(): + heatmap = results[self.OUTPUT_HEATMAP_KEY] + heatmap = heatmap.transpose(1, 0, 2) # (B, H, W) -> (H, B, W) + heatmap = heatmap.reshape(heatmap.shape[0], -1) # (H, B, W) -> (H, BxW) + task_heatmap[: len(heatmap), indices.flatten()] = heatmap counter_step_finished += len(data) self.callback_step_finished(counter_step_finished) @@ -166,7 +190,7 @@ def process_task(self, task: Task) -> Task: modified_manually=False, created_manually=False, created_by_nn=True, - picks_color=generate_color(), + color=generate_color(), picking_parameters=task.picking_parameters, ) diff --git a/first_breaks/picking/picks.py b/first_breaks/picking/picks.py index cb64ab7..71a7c44 100644 --- a/first_breaks/picking/picks.py +++ b/first_breaks/picking/picks.py @@ -1,5 +1,5 @@ import uuid -from typing import List, Literal, Optional, Union +from typing import List, Literal, Optional, Union, Annotated import numpy as np from pydantic import UUID4, Field, model_validator @@ -49,7 +49,7 @@ class Picks(DefaultModel): modified_manually: Optional[bool] = None picking_parameters: Optional[PickingParameters] = None color: TColor = Field(default_factory=generate_color, description="Color for picks") # type: ignore - width: float = Field(DEFAULT_PICKS_WIDTH, description="Width of pick line") + width: Annotated[float, Field(description="Width of pick line")] = DEFAULT_PICKS_WIDTH active: Optional[bool] = None @@ -165,5 +165,5 @@ def create_duplicate(self, keep_color: bool = False) -> "Picks": created_by_nn=self.created_by_nn, modified_manually=True, picking_parameters=self.picking_parameters, - picks_color=self.color if keep_color else generate_color(), + color=self.color if keep_color else generate_color(), ) diff --git a/first_breaks/picking/refiner.py b/first_breaks/picking/refiner.py index d307363..0cbc7df 100644 --- a/first_breaks/picking/refiner.py +++ b/first_breaks/picking/refiner.py @@ -86,12 +86,6 @@ def refine_picks( for trace, intersections in traces2intersections.items(): intersections_int = intersections.astype(int) prob = probability_heatmap[intersections_int, trace] - # print( - # trace, - # intersections_int, - # prob, - # [raw_picks[trace], probability_heatmap[raw_picks[trace], trace]], - # ) best_candidate = np.argmax(prob) if prob[best_candidate] > minimum_probability_to_refine: refined_picks[trace] = intersections_int[best_candidate] @@ -163,147 +157,3 @@ def refine(self, sgy: SGY, picks: Picks) -> Picks: picks.from_samples(refined_picks) return picks - - -if __name__ == "__main__": - sgy = SGY(PROJECT_ROOT / "with_picks.sgy") - heatmap = np.load(PROJECT_ROOT / "heatmap.npy") - src_picks = Picks( - values=sgy.read_custom_trace_header(236, "i"), - unit="mcs", - dt_mcs=sgy.dt_mcs, - heatmap=heatmap, - ) - - new_picks = src_picks.create_duplicate() - - refiner = MinimalPhaseRefiner() - with Performance(): - new_picks = refiner.refine(sgy, new_picks) - - print(src_picks.picks_in_samples) - print(new_picks.picks_in_samples) - - # num_tr = 20 - # num_samples = 20 - # window_smooth = 11 - # order = 3 - # window_extrema = 3 - # window_analyse_before = 5 - # window_analyse_after = 5 - # min_probability_to_refine = 0.9 - # - # - # raw = np.random.uniform(size=(num_samples, num_tr)) - # picks = np.random.randint(0, num_samples, size=num_tr).astype(int) - # heatmap = np.random.randint(1, 3, size=raw.shape) - # - # with Performance(): - # filtered = apply_savgol_filter( - # data=raw, polyorder=order, window_length=window_smooth, deriv=0 - # ) - # first_derivateive = apply_savgol_filter( - # data=raw, polyorder=order, window_length=window_smooth, deriv=1 - # ) - # - # band_mask = get_band_mask( - # data=raw, - # band_ids=picks, - # width_before=window_analyse_before, - # width_after=window_analyse_after, - # ) - # - # extrema = find_extrema_mask( - # data=first_derivateive[band_mask], neighbor_range=window_extrema - # ) - # - # tr2intersections = calc_intersection_vectorized( - # data=filtered[band_mask], - # data_derivative=-first_derivateive[band_mask], - # extrema_mask=extrema, - # ) - # # pprint(band_mask[0]) - # # pprint(picks) - # # pprint(tr2intersections) - # # band_start = band_mask[0][0, :] - # # pprint(band_start) - # # tr2intersections = { - # # tr: inter + band_start[tr] for tr, inter in tr2intersections.items() - # # } - # - # # pprint(band_mask[0]) - # # pprint(tr2intersections) - # - # # previous intersections obtained based on band data. We need to map band intersections to initail coordintates - # tr2intersections = { - # tr: band_mask[0][inter.round().astype(int), tr] - # for tr, inter in tr2intersections.items() - # } - # - # refined_picks = refine_picks( - # raw_picks=picks, - # probability_heatmap=heatmap, - # traces2intersections=tr2intersections, - # minimum_probability_to_refine=min_probability_to_refine, - # ) - - # d = np.random.uniform(size=(20, 20)) - # der = np.random.uniform(size=(20, 20)) - # - # - # picks = np.array([5] * 20) - # - # # picks[3] = 1 - # - # d[picks - 3 : picks + 3, :] - # - # - # d[:, 1:10] = 100 - # # tang = [] - # - # # aa = np.random.randint(0, 2, size=(3, 3)).astype(bool) - # # print(aa) - # # print(aa.nonzero()) - # - # with Performance(): - # extrema = find_extrema_mask(d) - # - # - # # print(res) - # # print(np.where(res)) - # - # with Performance(): - # tr2intersection = {} - # - # res = np.where(extrema) - # - # for i in np.unique(res[1]): - # extrema_tr = res[0][res[1] == i] - # tr2intersection[i] = calc_intersection(d[:, i], der[:, i], extrema_tr) - # - # - # print(extrema.shape) - # - # with Performance(): - # v = calc_intersection_vectorized(d, der, extrema) - # - # - # # print(len(v), len(tr2intersection)) - # # - # # print(v[0]) - # # print(tr2intersection[0]) - # - # - # assert all(np.allclose(tr2intersection[i], v[i]) for i in tr2intersection.keys()) - # - # print(tr2intersection[5]) - - # d = np.arange(10)[:, None] - # d = np.tile(d, (1, 5)) - # picks = np.array([1, 9, 3, 4, 5]) - # - # band_mask = get_band_mask(d, picks, 3, 2) - # - # print(d) - # print(band_mask) - # print(d[band_mask]) diff --git a/first_breaks/utils/utils.py b/first_breaks/utils/utils.py index ac0d389..3ce284d 100644 --- a/first_breaks/utils/utils.py +++ b/first_breaks/utils/utils.py @@ -35,8 +35,10 @@ def chunk_iterable(it: Iterable[Any], size: int) -> List[Tuple[Any, ...]]: return list(iter(lambda: tuple(islice(it, size)), ())) -def get_io(source: Union[Path, str, bytes], mode: str = "r") -> Union[io.BytesIO, io.FileIO]: - if isinstance(source, (Path, str)): +def get_io(source: Union[Path, str, bytes, io.BytesIO, io.FileIO], mode: str = "r") -> Union[io.BytesIO, io.FileIO]: + if isinstance(source, (io.BytesIO, io.FileIO)): + return source + elif isinstance(source, (Path, str)): source = Path(source).resolve() if "r" in mode: if not source.exists(): @@ -51,8 +53,7 @@ def get_io(source: Union[Path, str, bytes], mode: str = "r") -> Union[io.BytesIO def calc_hash(source: Union[Path, str, bytes, io.BytesIO, io.FileIO]) -> str: hash_md5 = hashlib.md5() - if not isinstance(source, (io.BytesIO, io.FileIO)): - source = get_io(source, mode="rb") + source = get_io(source, mode="rb") source.seek(0) for chunk in iter(lambda: source.read(4096), b""): # type: ignore hash_md5.update(chunk) @@ -60,21 +61,6 @@ def calc_hash(source: Union[Path, str, bytes, io.BytesIO, io.FileIO]) -> str: return hash_md5.hexdigest() -# def download_by_url( -# url: str, fname: Optional[Union[str, Path]], timeout: float = TIMEOUT -# ) -> Optional[bytes]: -# response = requests.get(url, timeout=timeout, stream=True) -# if response.status_code != 200: -# response.raise_for_status() -# return None -# else: -# if fname: -# Path(fname).parent.mkdir(exist_ok=True, parents=True) -# with open(fname, "wb+") as f: -# f.write(response.content) -# return response.content - - def download_by_url(url: str, fname: Optional[Union[str, Path]], timeout: float = TIMEOUT) -> bytes: response = requests.get(url, stream=True, timeout=timeout) response.raise_for_status() @@ -192,19 +178,19 @@ def remove_unused_kwargs(kwargs: Dict[str, Any], constructor: Any) -> Dict[str, return {k: v for k, v in kwargs.items() if k in inspect.signature(constructor).parameters} -def _color_generator() -> Generator[List[int], None, None]: +def _color_generator() -> Generator[Tuple[int, ...], None, None]: golden_ratio = 0.618033988749895 hue = random.random() # start from a random position while True: hue += golden_ratio hue %= 1 - yield [int(255 * v) for v in colorsys.hsv_to_rgb(hue, 0.5, 0.95)] + yield tuple(int(255 * v) for v in colorsys.hsv_to_rgb(hue, 0.5, 0.95)) cgen = _color_generator() -def generate_color() -> List[int]: +def generate_color() -> Tuple[int, ...]: return next(cgen) diff --git a/tests/test_common/test_readme_examples.py b/tests/test_common/test_readme_examples.py index f4c89df..140b3ed 100644 --- a/tests/test_common/test_readme_examples.py +++ b/tests/test_common/test_readme_examples.py @@ -49,7 +49,7 @@ def test_code_blocks_in_readme(block_name: str, demo_sgy: Path, logs_dir_for_tes code = find_code_block(PROJECT_ROOT / "README.md", start_indicator, end_indicator) assert code - tmp_fname = "tmp.py" + tmp_fname = f"{block_name}.py" with open(tmp_fname, "w") as f: f.write(code) From 96aa52600ca13545a0503bcd246e78116a2954c0 Mon Sep 17 00:00:00 2001 From: DaloroAT Date: Sun, 5 Oct 2025 18:30:26 +0200 Subject: [PATCH 4/9] no bench --- first_breaks/benchmark.py | 212 -------------------------------------- 1 file changed, 212 deletions(-) delete mode 100644 first_breaks/benchmark.py diff --git a/first_breaks/benchmark.py b/first_breaks/benchmark.py deleted file mode 100644 index a0abc53..0000000 --- a/first_breaks/benchmark.py +++ /dev/null @@ -1,212 +0,0 @@ -import json -from itertools import product -from os import system -from pathlib import Path -from typing import List, Optional, Union - -import numpy as np - -from first_breaks.const import FIRST_BYTE -from first_breaks.desktop.graph import export_image -from first_breaks.picking.picker_onnx import PickerONNX -from first_breaks.picking.picks import Picks -from first_breaks.picking.refiner import MinimalPhaseRefiner -from first_breaks.picking.task import Task -from first_breaks.sgy.reader import SGY -from first_breaks.utils.filtering import apply_savgol_filter -from first_breaks.utils.utils import as_list, calc_hash, download_and_validate_file - - -def download_model_with_heatmap(destination: Union[str, Path]) -> None: - model_hash = "afc03594f49b88ea61b5cf6ba8245be4" - model_url = "https://oml.daloroserver.com/download/seis/fb_heatmap_afc03594f49b88ea61b5cf6ba8245be4.onnx" - download_and_validate_file(url=model_url, md5=model_hash, fname=destination) - - -def plot_picks_on_small_section_chunk(sgy: SGY, manual_picks: Picks, predicted_picks: Optional[Picks] = None) -> None: - limit_for_validation = 10 - val_image = "chunk.png" - val_sgy = SGY(source=sgy.read_traces_by_ids(list(range(limit_for_validation))), dt_mcs=sgy.dt_mcs) - val_manual_picks = Picks( - values=manual_picks.picks_in_mcs[:limit_for_validation], - unit="mcs", - dt_mcs=val_sgy.dt_mcs, - color=(255, 0, 0), - ) - picks = [val_manual_picks] - - if predicted_picks: - val_predicted_picks = Picks( - values=predicted_picks.picks_in_mcs[:limit_for_validation], - unit="mcs", - dt_mcs=val_sgy.dt_mcs, - color=(0, 0, 255), - ) - picks.append(val_predicted_picks) - - export_image( - source=val_sgy, - image_filename=val_image, - picks_list=picks, - height=1000, - width=1000, - ) - system(val_image) - - -def calc_snr10(traces: np.ndarray, picks: Picks, smooth: bool = False, symmetric: bool = True) -> List[float]: - if smooth: - traces = apply_savgol_filter(traces, polyorder=3, window_length=11, deriv=0) - - snr = np.ones(traces.shape[1]) - - for idx, pick in enumerate(picks.picks_in_samples): - if pick > 0: - noise = traces[:pick, idx] - if symmetric: - signal_and_noise = traces[pick : pick + len(noise), idx] - else: - signal_and_noise = traces[pick:, idx] - - p_noise = np.mean(np.square(noise)) - p_signal_and_noise = np.mean(np.square(signal_and_noise)) - snr[idx] = (p_signal_and_noise - p_noise) / p_noise - - snr10 = np.log10(snr) - snr10[np.isnan(snr10)] = -1000 - snr10[np.isinf(snr10)] = -2000 - snr10 = snr10.tolist() - - return snr10 - - -def benchmark( - sgy_filename: Union[str, Path], - model_filename: Union[str, Path], - report_filename: Union[str, Path], - gain_list: List[float], - maximum_time_list: List[float], - traces_per_gather_list: List[int], - saved_picks_byte_position: int, -): - sgy_filename = Path(sgy_filename).resolve() - assert sgy_filename.exists(), f"File {sgy_filename} not found" - sgy = SGY(source=sgy_filename) - print(f"SGY: {sgy_filename}; shape={sgy.shape}, dt_mcs={sgy.dt_mcs}") - - assert 1 <= saved_picks_byte_position <= 237 - saved_picks = Picks( - values=sgy.read_custom_trace_header(saved_picks_byte_position - FIRST_BYTE, "i"), - unit="mcs", - dt_mcs=sgy.dt_mcs, - ) - - plot_picks_on_small_section_chunk(sgy=sgy, manual_picks=saved_picks) - - download_model_with_heatmap(model_filename) - - report_filename = Path(report_filename) - report_filename.parent.mkdir(exist_ok=True, parents=True) - - picker = PickerONNX(model_path=model_filename, show_progressbar=True) - - to_export = {"confidence": [], "difference": [], "model_hash": picker.model_hash} - - total = len(gain_list) * len(maximum_time_list) * len(traces_per_gather_list) - for idx, (gain, maximum_time, tps) in enumerate(product(gain_list, maximum_time_list, traces_per_gather_list)): - task = Task( - source=sgy, - traces_per_gather=tps, - maximum_time=maximum_time, - gain=gain, - ) - print(f"Task {idx}/{total} started (gain={gain}, max_time={maximum_time}, tps={tps}) ...", flush=True) - task = picker.process_task(task) - predicted_picks = task.get_result() - - confidence = as_list(predicted_picks.confidence) - # difference between manual picks and predicted picks is anonymous and expose nothing, but allows me to compare - # performance with different parameters - difference_raw = ( - (np.array(saved_picks.picks_in_mcs) - np.array(predicted_picks.picks_in_mcs)).astype(int).tolist() - ) - - refined_picks = predicted_picks.create_duplicate() - refiner = MinimalPhaseRefiner() - refined_picks = refiner.refine(sgy=sgy, picks=refined_picks) - - difference_refined = ( - (np.array(saved_picks.picks_in_mcs) - np.array(refined_picks.picks_in_mcs)).astype(int).tolist() - ) - - to_export["confidence"].append( - {"gain": gain, "maximum_time": maximum_time, "traces_per_gather": tps, "values": confidence} - ) - to_export["difference"].append( - { - "gain": gain, - "maximum_time": maximum_time, - "traces_per_gather": tps, - "refined": False, - "values": difference_raw, - } - ) - to_export["difference"].append( - { - "gain": gain, - "maximum_time": maximum_time, - "traces_per_gather": tps, - "refined": True, - "values": difference_refined, - } - ) - - # FILE LEVEL STATS - - # hash of traces allows me to inderstand if reports were created based on same data or different - # without direct access to file: if 2 reports have same `traces_hash` it means that they were calculated based on - # same traces, if not - files were different. - # So I can understand how different parameters affect specific file analysing `difference` metric for several - # files belongs to same `traces_hash` - traces = sgy.read() - traces_hash = calc_hash(traces.tobytes(order="C")) - to_export["traces_hash"] = traces_hash - - # I would like to have anonymized base headers to better understand the number of seismic traces for each shot, - # and the number of shots. I want to try to automate the selection of parameter `traces_per_gather` based on this. - # I'm not interested in exact values of these headers, but rather in their distribution, so hashed values - # are sufficient. - for header in ["CHAN", "SOURCE", "FFID"]: - to_export[header] = sgy.traces_headers[header].apply(lambda x: calc_hash(str(x).encode())[:10]).tolist() - - to_export["shape"] = sgy.shape - to_export["dt_mcs"] = sgy.dt_mcs - - # I want to analyse how picking parameters and result correlate with SNR - to_export["SNR10"] = [] - for smooth, symmetric in product((True, False), (True, False)): - snr10 = calc_snr10(traces, saved_picks, smooth=smooth, symmetric=symmetric) - to_export["SNR10"].append({"smooth": smooth, "symmetric": symmetric, "values": snr10}) - - with open(report_filename, "w") as f: - json.dump(to_export, f) - - -if __name__ == "__main__": - sgy_filename_ = "my_data.sgy" - model_filename_ = "fb_heatmap_afc03594f49b88ea61b5cf6ba8245be4.onnx" - report_filename_ = "report.json" - gain_list_ = [0.1, 0.5, 1] - maximum_time_list_ = [100, 200] - traces_per_gather_list_ = [12] - saved_picks_byte_position_ = 237 - - benchmark( - sgy_filename=sgy_filename_, - model_filename=model_filename_, - report_filename=report_filename_, - gain_list=gain_list_, - maximum_time_list=maximum_time_list_, - traces_per_gather_list=traces_per_gather_list_, - saved_picks_byte_position=saved_picks_byte_position_, - ) From 05ac1c6e83da68a3367e76d85383b223e83546fe Mon Sep 17 00:00:00 2001 From: DaloroAT Date: Sun, 5 Oct 2025 18:33:38 +0200 Subject: [PATCH 5/9] rm old visualizations --- first_breaks/utils/visualizations.py | 126 --------------------------- 1 file changed, 126 deletions(-) delete mode 100644 first_breaks/utils/visualizations.py diff --git a/first_breaks/utils/visualizations.py b/first_breaks/utils/visualizations.py deleted file mode 100644 index bf6574d..0000000 --- a/first_breaks/utils/visualizations.py +++ /dev/null @@ -1,126 +0,0 @@ -from typing import Optional, Tuple, Union - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.patches import Polygon - -from first_breaks.picking.utils import preprocess_gather - - -def plotseis( - data: np.ndarray, - picking: Optional[np.ndarray] = None, - add_picking: Optional[np.ndarray] = None, - normalizing: Optional[Union[str, int]] = "indiv", - clip: float = 0.9, - ampl: float = 1.0, - patch: bool = True, - colorseis: bool = False, - wiggle: bool = True, - background: Optional[np.ndarray] = None, - colorbar: bool = False, - dt: float = 1.0, - show: bool = True, - dpi: int = 300, - figsize: Tuple[int, int] = (5, 5), -) -> matplotlib.figure.Figure: - - num_time, num_trace = np.shape(data) - - # if normalizing == "indiv": - # norm_factor = np.mean(np.abs(data), axis=0) - # norm_factor[np.abs(norm_factor) < 1e-9 * np.max(np.abs(norm_factor))] = 1 - # elif normalizing == "entire": - # norm_factor = np.tile(np.mean(np.abs(data)), (1, num_trace)) - # elif np.size(normalizing) == 1 and normalizing is not None: - # norm_factor = np.tile(normalizing, (1, num_trace)) - # elif np.size(normalizing) == num_trace: - # norm_factor = np.reshape(normalizing, (1, num_trace)) - # elif normalizing is None: - # norm_factor = np.ones(data.shape[1]) - # else: - # raise ValueError('Wrong value of "normalizing"') - # - # data = data / norm_factor * ampl - - # mask_overflow = np.abs(data) > clip - # data[mask_overflow] = np.sign(data[mask_overflow]) * clip - - data = preprocess_gather(data=data, gain=ampl, clip=clip, normalize=normalizing) - - data_time = np.tile((np.arange(num_time) + 1)[:, np.newaxis], (1, num_trace)) * dt - - fig, ax = plt.subplots(figsize=figsize) - fig.set_dpi(dpi) - - plt.xlim((0, num_trace + 1)) - plt.ylim((0, num_time * dt)) - ax.invert_yaxis() - ax.xaxis.tick_top() - - if wiggle: - data_to_wiggle = data + (np.arange(num_trace) + 1)[np.newaxis, :] - - ax.plot(data_to_wiggle, data_time, color=(0, 0, 0)) - - if colorseis: - if not (wiggle or patch): - ax.imshow( - data, - aspect="auto", - interpolation="bilinear", - alpha=1, - extent=(1, num_trace, (num_time - 0.5) * dt, -0.5 * dt), - cmap="gray", - ) - else: - ax.imshow( - data, - aspect="auto", - interpolation="bilinear", - alpha=1, - extent=(-0.5, num_trace + 2 - 0.5, (num_time - 0.5) * dt, -0.5 * dt), - cmap="gray", - ) - - if patch: - data_to_patch = data - data_to_patch[data_to_patch < 0] = 0 - - for k_trace in range(num_trace): - patch_data = ( - (data_to_patch[:, k_trace] + k_trace + 1)[:, np.newaxis], - data_time[:, k_trace][:, np.newaxis], - ) - patch_data = np.hstack(patch_data) - - head = np.array((k_trace + 1, 0))[np.newaxis, :] - tail = np.array((k_trace + 1, num_time * dt))[np.newaxis, :] - patch_data = np.vstack((head, patch_data, tail)) - - polygon = Polygon(patch_data, closed=True, facecolor="black", edgecolor=None) - ax.add_patch(polygon) - - if picking is not None: - picking = np.array(picking) - ax.plot(np.arange(num_trace) + 1, picking * dt, linewidth=1, color="blue") - - if add_picking is not None: - add_picking = np.array(add_picking) - ax.plot(np.arange(num_trace) + 1, add_picking * dt, linewidth=1, color="green") - - if background is not None: - bg = ax.imshow( - background, - aspect="auto", - extent=(0.5, num_trace + 1 - 0.5, (num_time - 0.5) * dt, -0.5 * dt), - cmap="YlOrRd", - ) - - if colorbar: - plt.colorbar(mappable=bg) - - if show: - plt.show() - return fig From 1e350a686351391988677ff5c7698f04959bbdb3 Mon Sep 17 00:00:00 2001 From: DaloroAT Date: Sun, 5 Oct 2025 18:43:44 +0200 Subject: [PATCH 6/9] pre-commits --- .pre-commit-config.yaml | 14 ++++++-------- first_breaks/_pytorch/picker_torch.py | 2 +- first_breaks/desktop/combobox_with_mapping.py | 2 +- first_breaks/desktop/main_gui.py | 5 ++--- first_breaks/desktop/picks_manager_widget.py | 2 +- first_breaks/desktop/radioset_widget.py | 2 +- first_breaks/picking/picker_onnx.py | 11 ++++++----- first_breaks/picking/picks.py | 2 +- first_breaks/picking/refiner.py | 17 +++++++++-------- first_breaks/utils/filtering.py | 2 +- pyproject.toml | 2 +- 11 files changed, 30 insertions(+), 31 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7379794..ef4e9d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,10 @@ default_language_version: - python: python3.9 + python: python3.12 files: \.py$|tests/.*\.py$ -default_stages: -- commit repos: # general hooks to verify or beautify code - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v6.0.0 hooks: - id: check-added-large-files args: [--maxkb=5000] @@ -23,7 +21,7 @@ repos: # autoformat code with black formatter - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 25.9.0 hooks: - id: black files: first_breaks|tests @@ -32,7 +30,7 @@ repos: # beautify and sort imports - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 6.1.0 hooks: - id: isort files: first_breaks|tests @@ -41,7 +39,7 @@ repos: # check code style - repo: https://github.com/pycqa/flake8 - rev: 3.8.4 + rev: 7.3.0 hooks: - id: flake8 files: first_breaks|tests @@ -50,7 +48,7 @@ repos: # static type checking - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.5.1 + rev: v1.18.2 hooks: - id: mypy files: first_breaks|tests diff --git a/first_breaks/_pytorch/picker_torch.py b/first_breaks/_pytorch/picker_torch.py index 504b5d8..c12993b 100644 --- a/first_breaks/_pytorch/picker_torch.py +++ b/first_breaks/_pytorch/picker_torch.py @@ -80,7 +80,7 @@ def change_settings( # type: ignore device: Optional[str] = None, segmentation_hw: Optional[Tuple[int, int]] = None, num_workers: Optional[int] = None, - batch_size: Optional[int] = None + batch_size: Optional[int] = None, ) -> None: if args: raise ValueError("Use named arguments instead of positional") diff --git a/first_breaks/desktop/combobox_with_mapping.py b/first_breaks/desktop/combobox_with_mapping.py index 542da90..81cf200 100644 --- a/first_breaks/desktop/combobox_with_mapping.py +++ b/first_breaks/desktop/combobox_with_mapping.py @@ -30,7 +30,7 @@ def __init__( current_label: Optional[str] = None, current_value: Optional[str] = None, *args: Any, - **kwargs: Any + **kwargs: Any, ): super().__init__(*args, **kwargs) diff --git a/first_breaks/desktop/main_gui.py b/first_breaks/desktop/main_gui.py index 13a8d28..a81497a 100644 --- a/first_breaks/desktop/main_gui.py +++ b/first_breaks/desktop/main_gui.py @@ -1,7 +1,7 @@ import sys import warnings from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from PyQt5.QtCore import QSize, Qt, QThreadPool, pyqtSignal from PyQt5.QtGui import QCloseEvent @@ -22,7 +22,6 @@ from first_breaks.const import ( DEMO_SGY_PATH, HIGH_DPI, - MODEL_ONNX_HASH, MODEL_ONNX_HASHES, MODEL_ONNX_PATH, ) @@ -54,7 +53,7 @@ class FileState: file_changed = 2 @classmethod - def get_file_state(cls, fname: Union[str, Path], fhashes: str) -> int: + def get_file_state(cls, fname: Union[str, Path], fhashes: List[str]) -> int: if not Path(fname).is_file(): return cls.file_not_exists else: diff --git a/first_breaks/desktop/picks_manager_widget.py b/first_breaks/desktop/picks_manager_widget.py index 7a9f9dc..1e002c3 100644 --- a/first_breaks/desktop/picks_manager_widget.py +++ b/first_breaks/desktop/picks_manager_widget.py @@ -303,7 +303,7 @@ def __init__(self) -> None: self.remove_button = QPushButton("-", self) self.remove_button.clicked.connect(self.remove_items) # renamed for clarity # self.properties_button = QPushButton("\u2699", self) # Unicode character for gear - self.properties_button = QPushButton("\U0001F4BE", self) # "\U0001F4BE" or "\U0001F5AB" + self.properties_button = QPushButton("\U0001f4be", self) # "\U0001F4BE" or "\U0001F5AB" self.properties_button.setFont(self.font()) # to increase the size of the button a bit self.properties_button.clicked.connect(self.open_properties) button_layout.addWidget(self.add_button) diff --git a/first_breaks/desktop/radioset_widget.py b/first_breaks/desktop/radioset_widget.py index aa21ff9..e73e029 100644 --- a/first_breaks/desktop/radioset_widget.py +++ b/first_breaks/desktop/radioset_widget.py @@ -30,7 +30,7 @@ def __init__( orientation: str = "horizontal", margins: Optional[int] = None, *args: Any, - **kwargs: Any + **kwargs: Any, ): super().__init__(*args, **kwargs) assert orientation in ["horizontal", "vertical"] diff --git a/first_breaks/picking/picker_onnx.py b/first_breaks/picking/picker_onnx.py index e42bdc7..e80a68e 100644 --- a/first_breaks/picking/picker_onnx.py +++ b/first_breaks/picking/picker_onnx.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, Generator, Optional, Tuple, Union +from typing import Any, Dict, Generator, Optional, Union import numpy as np import onnxruntime as ort @@ -100,11 +100,13 @@ def __init__( self.model: Optional[ort.InferenceSession] = None self.init_model() - self._available_outputs = sorted(o.name for o in self.model.get_outputs()) - self._input_name = self.model.get_inputs()[0].name + self._available_outputs = sorted(o.name for o in self.model.get_outputs()) # type: ignore + self._input_name = self.model.get_inputs()[0].name # type: ignore for mandatory_key in [self.OUTPUT_PICKS_KEY, self.OUTPUT_CONFS_KEY]: - assert mandatory_key in self._available_outputs, f"`mandatory_key` not found in model outputs `{self._available_outputs}`" + assert ( + mandatory_key in self._available_outputs + ), f"`mandatory_key` not found in model outputs `{self._available_outputs}`" def is_heatmap_available(self) -> bool: return self.OUTPUT_HEATMAP_KEY in self._available_outputs @@ -164,7 +166,6 @@ def process_task(self, task: Task) -> Task: picks = results[self.OUTPUT_PICKS_KEY] confidence = results[self.OUTPUT_CONFS_KEY] - indices = batch["gather_ids"] task_picks_in_sample[indices.flatten()] = picks.flatten() # (B, W) -> (BxW) diff --git a/first_breaks/picking/picks.py b/first_breaks/picking/picks.py index 71a7c44..a5ab6e5 100644 --- a/first_breaks/picking/picks.py +++ b/first_breaks/picking/picks.py @@ -1,5 +1,5 @@ import uuid -from typing import List, Literal, Optional, Union, Annotated +from typing import Annotated, List, Literal, Optional, Union import numpy as np from pydantic import UUID4, Field, model_validator diff --git a/first_breaks/picking/refiner.py b/first_breaks/picking/refiner.py index 0cbc7df..23d7ca2 100644 --- a/first_breaks/picking/refiner.py +++ b/first_breaks/picking/refiner.py @@ -1,17 +1,14 @@ -from pprint import pprint -from typing import Tuple +from typing import Dict, Tuple import numpy as np -from first_breaks.const import PROJECT_ROOT from first_breaks.picking.picks import Picks from first_breaks.sgy.reader import SGY -from first_breaks.utils.debug import Performance from first_breaks.utils.filtering import apply_savgol_filter class Refiner: - def refine(self, sgy: SGY, picks: Picks): + def refine(self, sgy: SGY, picks: Picks) -> Picks: raise NotImplementedError @@ -40,7 +37,9 @@ def calc_intersection(data: np.ndarray, data_derivative: np.ndarray, tangent_poi return -intercept / slope -def calc_intersection_vectorized(data: np.ndarray, data_derivative: np.ndarray, extrema_mask: np.ndarray): +def calc_intersection_vectorized( + data: np.ndarray, data_derivative: np.ndarray, extrema_mask: np.ndarray +) -> Dict[int, Tuple[np.ndarray, ...]]: assert all(arr.ndim == 2 for arr in [data, data_derivative, extrema_mask]) assert extrema_mask.dtype == np.bool_ @@ -75,7 +74,7 @@ def get_band_mask( return row_indices_clipped, np.arange(num_cols) -def refine_picks( +def refine_picks( # type: ignore raw_picks: np.ndarray, probability_heatmap: np.ndarray, traces2intersections, @@ -145,7 +144,9 @@ def refine(self, sgy: SGY, picks: Picks) -> Picks: ) # previous intersections obtained based on band data. We need to map band intersections to initail coordintates - tr2intersections = {tr: band_mask[0][inter.round().astype(int), tr] for tr, inter in tr2intersections.items()} + tr2intersections = { + tr: band_mask[0][inter.round().astype(int), tr] for tr, inter in tr2intersections.items() # type: ignore + } refined_picks = refine_picks( raw_picks=picks_in_samples, diff --git a/first_breaks/utils/filtering.py b/first_breaks/utils/filtering.py index 588481a..bc7b2b1 100644 --- a/first_breaks/utils/filtering.py +++ b/first_breaks/utils/filtering.py @@ -31,7 +31,7 @@ def apply_savgol_filter(data: np.ndarray, window_length: int, polyorder: int, de if data.ndim == 1: padding = (half_window, half_window) else: - padding = ((half_window, half_window), (0, 0)) + padding = ((half_window, half_window), (0, 0)) # type: ignore padded_data = np.pad(data, padding, mode=pad_mode) diff --git a/pyproject.toml b/pyproject.toml index 5abd745..f7d4787 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ classifiers = [ entry-points = {console_scripts = {first-breaks-picking = "first_breaks.cli:cli_commands"}} dependencies = [ "requests>=2.28.2", - "numpy>=1.24.2", + "numpy (>=1.24.2,<2.0.0)", "pandas>=2.0.0", "PyQt5>=5.15.9", "pyqtgraph>=0.13.3", From f0e5d065565b17259e4c145077812fc943fd4c8c Mon Sep 17 00:00:00 2001 From: DaloroAT Date: Sat, 11 Oct 2025 11:29:32 +0200 Subject: [PATCH 7/9] update CI --- .github/workflows/pre-commit-workflow.yaml | 13 ++++++------- .github/workflows/tests-workflow.yaml | 4 ++-- .github/workflows/tomls.yaml | 8 ++++---- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/.github/workflows/pre-commit-workflow.yaml b/.github/workflows/pre-commit-workflow.yaml index a484436..07d904d 100644 --- a/.github/workflows/pre-commit-workflow.yaml +++ b/.github/workflows/pre-commit-workflow.yaml @@ -12,18 +12,17 @@ on: jobs: pre_commit: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: - python-version: "3.9.0" + python-version: "3.12.0" - name: Install dependencies run: | - apt-get get update && apt-get install cmake - make install_precommit + pip install pre-commit - name: Pre-commit tests run: | - make run_precommit + pre-commit run -a diff --git a/.github/workflows/tests-workflow.yaml b/.github/workflows/tests-workflow.yaml index ca38326..fc24ebd 100644 --- a/.github/workflows/tests-workflow.yaml +++ b/.github/workflows/tests-workflow.yaml @@ -12,10 +12,10 @@ on: jobs: tests: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install package run: | make docker_build diff --git a/.github/workflows/tomls.yaml b/.github/workflows/tomls.yaml index ff1b372..2a3efa5 100644 --- a/.github/workflows/tomls.yaml +++ b/.github/workflows/tomls.yaml @@ -12,14 +12,14 @@ on: jobs: tests: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: - python-version: "3.8.0" + python-version: "3.12.0" - name: Install lib run: | pip install tomli From 70a5cf15e2fc820c2dd2e76061d4823c22f43528 Mon Sep 17 00:00:00 2001 From: DaloroAT Date: Sat, 11 Oct 2025 13:49:53 +0200 Subject: [PATCH 8/9] update CI --- .github/workflows/tests-workflow.yaml | 12 +++++++++--- Dockerfile | 27 --------------------------- Makefile | 19 +------------------ pyproject.toml | 9 ++++----- pyproject_gpu.toml | 11 +++++------ 5 files changed, 19 insertions(+), 59 deletions(-) delete mode 100644 Dockerfile diff --git a/.github/workflows/tests-workflow.yaml b/.github/workflows/tests-workflow.yaml index fc24ebd..ba42a8c 100644 --- a/.github/workflows/tests-workflow.yaml +++ b/.github/workflows/tests-workflow.yaml @@ -16,9 +16,15 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - - name: Install package + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12.0" + - name: Install dependencies run: | - make docker_build + sudo apt-get update + sudo apt-get install -y cmake + pip install -e . - name: Tests run: | - make docker_tests \ No newline at end of file + make run_tests \ No newline at end of file diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 7b31b1e..0000000 --- a/Dockerfile +++ /dev/null @@ -1,27 +0,0 @@ -FROM ubuntu:20.04 - -ENV DEBIAN_FRONTEND=noninteractive -ENV LIBGL_ALWAYS_INDIRECT=1 -ENV QT_QPA_PLATFORM=offscreen - -RUN apt-get -y update -RUN apt-get install -y git wget cmake -RUN apt-get install -y git libsm6 libxext6 libfontconfig1 libxrender1 libgl1-mesa-glx libglib2.0-0 libgtk2.0 qt5-qmake - -RUN apt install -y python3.8 python3.8-distutils python3-pip -RUN ln -s /usr/bin/python3.8 /usr/bin/python \ - && ln -sf /usr/bin/python3.8 /usr/bin/python3 - -ENV LD_LIBRARY_PATH "/usr/local/lib:$LD_LIBRARY_PATH" -ENV LC_ALL=C.UTF-8 -ENV LANG=C.UTF-8 - -RUN pip install --upgrade pip - -WORKDIR /first-breaks-picking -COPY pyproject.toml . -RUN pip install .[dependencies] - -COPY . /first-breaks-picking -RUN pip install -e . - diff --git a/Makefile b/Makefile index c983ec0..87a9dd3 100644 --- a/Makefile +++ b/Makefile @@ -1,28 +1,11 @@ -IMAGE_NAME ?= first-breaks-picking:latest - - -.PHONY: install_precommit -install_precommit: - python -m pip install --upgrade pre-commit==3.5.0 - .PHONY: run_precommit run_precommit: - pre-commit install && pre-commit run -a + pre-commit run -a .PHONY: run_tests run_tests: pytest -sv --disable-warnings tests - -.PHONY: docker_build -docker_build: - DOCKER_BUILDKIT=1 docker build -t $(IMAGE_NAME) . - -.PHONY: docker_tests -docker_tests: docker_build - docker run -t $(IMAGE_NAME) make run_tests - - .PHONY: build_wheel build_wheel: python -m pip install --upgrade pip diff --git a/pyproject.toml b/pyproject.toml index f7d4787..0b3959f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "Project is devoted to pick waves that are the first to be detected on a seismogram with neural network" readme = {file = "README.md", content-type = "text/markdown"} -requires-python = ">=3.8" +requires-python = ">=3.12" authors = [ {name = "Aleksei Tarasov", email = "aleksei.v.tarasov@gmail.com"}, {name = "Aleksei Tarasov"}, @@ -33,10 +33,9 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Image Recognition", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ] entry-points = {console_scripts = {first-breaks-picking = "first_breaks.cli:cli_commands"}} dependencies = [ diff --git a/pyproject_gpu.toml b/pyproject_gpu.toml index 93b3d8b..369f79d 100644 --- a/pyproject_gpu.toml +++ b/pyproject_gpu.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "Project is devoted to pick waves that are the first to be detected on a seismogram with neural network (CUDA accelerated)" readme = {file = "README.md", content-type = "text/markdown"} -requires-python = ">=3.8" +requires-python = ">=3.12" authors = [ {name = "Aleksei Tarasov", email = "aleksei.v.tarasov@gmail.com"}, {name = "Aleksei Tarasov"}, @@ -33,15 +33,14 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Image Recognition", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ] entry-points = {console_scripts = {first-breaks-picking = "first_breaks.cli:cli_commands"}} dependencies = [ "requests>=2.28.2", - "numpy>=1.24.2", + "numpy (>=1.24.2,<2.0.0)", "pandas>=2.0.0", "PyQt5>=5.15.9", "pyqtgraph>=0.13.3", From 1424b24271e65cc13d366d86533aa3a6f71c8fad Mon Sep 17 00:00:00 2001 From: DaloroAT Date: Sat, 11 Oct 2025 14:01:08 +0200 Subject: [PATCH 9/9] add qt deps --- .github/workflows/tests-workflow.yaml | 49 ++++++++++++++++----------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/.github/workflows/tests-workflow.yaml b/.github/workflows/tests-workflow.yaml index ba42a8c..ed3e3c2 100644 --- a/.github/workflows/tests-workflow.yaml +++ b/.github/workflows/tests-workflow.yaml @@ -1,30 +1,39 @@ name: Tests - on: pull_request: - branches: - - main + branches: [main] push: - branches: - - main - + branches: [main] jobs: tests: runs-on: ubuntu-24.04 + env: + QT_QPA_PLATFORM: offscreen + LIBGL_ALWAYS_INDIRECT: "1" steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12.0" - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y cmake - pip install -e . - - name: Tests - run: | - make run_tests \ No newline at end of file + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: System deps for PyQt5 (xcb/offscreen) + run: | + sudo apt-get update + sudo apt-get install -y \ + cmake \ + libgl1 libglib2.0-0 libsm6 libxext6 libxrender1 \ + libx11-6 libxcb1 libxkbcommon0 libxkbcommon-x11-0 \ + libfontconfig1 libfreetype6 + + - name: Install project + run: | + python -m pip install --upgrade pip + pip install -e . + + - name: Tests + run: make run_tests