From b4ae9c84925ba43dacc5e12478b94858847ea8b7 Mon Sep 17 00:00:00 2001 From: John Emilio Vidale Date: Sat, 30 May 2026 16:04:18 -0700 Subject: [PATCH 1/9] Refactor align_stack into explicit run_pipeline entrypoint --- align_stack.py | 2347 ++++++++++++++++++++++++------------------------ 1 file changed, 1178 insertions(+), 1169 deletions(-) diff --git a/align_stack.py b/align_stack.py index d5638af..dae11d4 100644 --- a/align_stack.py +++ b/align_stack.py @@ -998,1225 +998,1234 @@ def compute_alignment_products( "stack_vec": stack_vec, } -# ===================== Channel / component selection ===================== -# User-facing components: Z, R, T -channels, process_as_three_comp, sel_comp_list = get_component_selection( - all_channels, component -) - - -# ===================== Main loop ===================== -# Storage for three-component mode -if process_as_three_comp: - all_component_data = {} - horizontal_window_cache = {} - horizontal_raw_limits_cache = {} - -for idx, channel in enumerate(channels): - sel_comp = sel_comp_list[idx] - - print(f"Processing channel: {channel}") - - for eve_id in events: - print(f"==========Processing event {eve_id}===========") - - # ---- Read event info ---- - event_depth, eve_lat, eve_lon, origin = load_event_metadata(eve_id, info_root) - save_dir = make_event_output_dir(path_prefix, eve_id) - name2ll = load_station_lookup(info_root) - - # ---- Read waveforms for all stations ---- - st_window, raw_limits_by_station = read_waveforms_for_event( - eve_id=eve_id, - channel=channel, - process_as_three_comp_mode=process_as_three_comp, - horizontal_window_cache=horizontal_window_cache, - horizontal_raw_limits_cache=horizontal_raw_limits_cache, - name2ll=name2ll, - eve_lat=eve_lat, - eve_lon=eve_lon, - origin=origin, - ) - if st_window is None: - continue - - # Label to show on the figure - plot_comp = sel_comp - - # ---- Rotate horizontal components to R/T ---- - if sel_comp in ("R", "T"): - print("Rotating horizontal components (N/E) to R/T ...") - - # In this dataset: DP1 is treated as N-like; DP2 is treated as E-like - stN = st_window.select(channel="DP1") - stE = st_window.select(channel="DP2") - rotated_traces = [] - - _rotate_wall_start = time.perf_counter() - _rotate_cpu_start = time.process_time() - for trN in stN: - sid = str(trN.stats.station) - stE_match = stE.select(station=sid) - if len(stE_match) == 0: - continue - trE = stE_match[0] - - # Back-azimuth (station -> event), plus instrument orientation correction (11°) - slat, slon = name2ll[sid] - _, _, baz_geo = gps2dist_azimuth(eve_lat, eve_lon, slat, slon) - baz = baz_geo - 11.0 - - # Synchronize length and rotate - npts_rot = min(trN.stats.npts, trE.stats.npts) - n = trN.data[:npts_rot] - e = trE.data[:npts_rot] - r, t = rotate_ne_rt(n, e, baz) - - if sel_comp == "R": - trR = trN.copy() - trR.data = r - trR.stats.channel = trN.stats.channel[:-1] + "R" - rotated_traces.append(trR) - plot_comp = "R" - elif sel_comp == "T": - trT = trN.copy() - trT.data = t - trT.stats.channel = trN.stats.channel[:-1] + "T" - rotated_traces.append(trT) - plot_comp = "T" - - add_stage_timing("rotate_to_rt", _rotate_wall_start, _rotate_cpu_start) - - st_comp = Stream(traces=rotated_traces) - else: - st_comp = st_window.select(channel=channel) - - # ---- Auto-select reference station: closest to array center ---- - ref_station_id, ref_trace = select_reference_trace(st_comp, name2ll) - if ref_trace is None: - continue - print_reference_summary(ref_station_id, ref_trace, raw_limits_by_station) - - # ---- Theoretical travel times (reference station) ---- - p_traveltime, s_traveltime, p_arrival_time, s_arrival_time, phase_traveltime = compute_phase_travel_times( - model, event_depth, ref_trace, origin, align_phase - ) - if phase_traveltime is None: - print(" No valid phase for alignment. Skip array.") - continue - # Theoretical arrival time (reference station) used for reference in shift calculations/plots - t_ref = phase_traveltime - - num_traces = len(st_comp) - print(f" {num_traces} traces on {plot_comp}") - if num_traces == 0: - continue - - # ---- Preprocess traces (detrend/taper/filter) ---- - _pre_wall_start = time.perf_counter() - _pre_cpu_start = time.process_time() - for tr in st_comp: - tr.detrend(type="demean") - trace_len_sec = float(tr.stats.npts) / float(tr.stats.sampling_rate) - taper_pct = min(0.05, 5.0 / trace_len_sec) if trace_len_sec > 0 else 0.0 - tr.taper(max_percentage=taper_pct, type="cosine") - tr.filter( - "bandpass", - freqmin=min_freq, - freqmax=max_freq, - corners=4, - zerophase=True, +def run_pipeline() -> None: + global align_phase, move_limit_sec, start_time, end_time + save_dir = Path(path_prefix + "output") + # ===================== Channel / component selection ===================== + # User-facing components: Z, R, T + channels, process_as_three_comp, sel_comp_list = get_component_selection( + all_channels, component + ) + + + # ===================== Main loop ===================== + # Storage for three-component mode + if process_as_three_comp: + all_component_data = {} + horizontal_window_cache = {} + horizontal_raw_limits_cache = {} + + for idx, channel in enumerate(channels): + sel_comp = sel_comp_list[idx] + + print(f"Processing channel: {channel}") + + for eve_id in events: + print(f"==========Processing event {eve_id}===========") + + # ---- Read event info ---- + event_depth, eve_lat, eve_lon, origin = load_event_metadata(eve_id, info_root) + save_dir = make_event_output_dir(path_prefix, eve_id) + name2ll = load_station_lookup(info_root) + + # ---- Read waveforms for all stations ---- + st_window, raw_limits_by_station = read_waveforms_for_event( + eve_id=eve_id, + channel=channel, + process_as_three_comp_mode=process_as_three_comp, + horizontal_window_cache=horizontal_window_cache, + horizontal_raw_limits_cache=horizontal_raw_limits_cache, + name2ll=name2ll, + eve_lat=eve_lat, + eve_lon=eve_lon, + origin=origin, ) - add_stage_timing("preprocess_filter", _pre_wall_start, _pre_cpu_start) - - alignment = compute_alignment_products( - st_comp=st_comp, - ref_trace=ref_trace, - ref_station_id=ref_station_id, - name2ll=name2ll, - eve_lat=eve_lat, - eve_lon=eve_lon, - event_depth=event_depth, - align_phase_name=align_phase, - t_ref=t_ref, - ) - npts = alignment["npts"] - sample_rate = alignment["sample_rate"] - move_limit_samples = alignment["move_limit_samples"] - win_start = alignment["win_start"] - win_end = alignment["win_end"] - calc_shifts = alignment["calc_shifts"] - aligned_stack = alignment["aligned_stack"] - selected_aligned_stack = alignment["selected_aligned_stack"] - selected_ids = alignment["selected_ids"] - station_corr = alignment["station_corr"] - n_pass_window = alignment["n_pass_window"] - pass_window_ids = alignment["pass_window_ids"] - snippet_by_station = alignment["snippet_by_station"] - ref_window = alignment["ref_window"] - selected_rows = alignment["selected_rows"] - rejected_rows = alignment["rejected_rows"] - station_shifts = alignment["station_shifts"] - aligned_traces_by_station = alignment["aligned_traces_by_station"] - t_abs = alignment["t_abs"] - mask = alignment["mask"] - stack_vec = alignment["stack_vec"] - - # ---- Plot: superposition of Stage1/Stage2/Final stacks ---- - _plot_wall_start = time.perf_counter() - _plot_cpu_start = time.process_time() - if not all_channels: - plot_stage_stacks( + if st_window is None: + continue + + # Label to show on the figure + plot_comp = sel_comp + + # ---- Rotate horizontal components to R/T ---- + if sel_comp in ("R", "T"): + print("Rotating horizontal components (N/E) to R/T ...") + + # In this dataset: DP1 is treated as N-like; DP2 is treated as E-like + stN = st_window.select(channel="DP1") + stE = st_window.select(channel="DP2") + rotated_traces = [] + + _rotate_wall_start = time.perf_counter() + _rotate_cpu_start = time.process_time() + for trN in stN: + sid = str(trN.stats.station) + stE_match = stE.select(station=sid) + if len(stE_match) == 0: + continue + trE = stE_match[0] + + # Back-azimuth (station -> event), plus instrument orientation correction (11°) + slat, slon = name2ll[sid] + _, _, baz_geo = gps2dist_azimuth(eve_lat, eve_lon, slat, slon) + baz = baz_geo - 11.0 + + # Synchronize length and rotate + npts_rot = min(trN.stats.npts, trE.stats.npts) + n = trN.data[:npts_rot] + e = trE.data[:npts_rot] + r, t = rotate_ne_rt(n, e, baz) + + if sel_comp == "R": + trR = trN.copy() + trR.data = r + trR.stats.channel = trN.stats.channel[:-1] + "R" + rotated_traces.append(trR) + plot_comp = "R" + elif sel_comp == "T": + trT = trN.copy() + trT.data = t + trT.stats.channel = trN.stats.channel[:-1] + "T" + rotated_traces.append(trT) + plot_comp = "T" + + add_stage_timing("rotate_to_rt", _rotate_wall_start, _rotate_cpu_start) + + st_comp = Stream(traces=rotated_traces) + else: + st_comp = st_window.select(channel=channel) + + # ---- Auto-select reference station: closest to array center ---- + ref_station_id, ref_trace = select_reference_trace(st_comp, name2ll) + if ref_trace is None: + continue + print_reference_summary(ref_station_id, ref_trace, raw_limits_by_station) + + # ---- Theoretical travel times (reference station) ---- + p_traveltime, s_traveltime, p_arrival_time, s_arrival_time, phase_traveltime = compute_phase_travel_times( + model, event_depth, ref_trace, origin, align_phase + ) + if phase_traveltime is None: + print(" No valid phase for alignment. Skip array.") + continue + # Theoretical arrival time (reference station) used for reference in shift calculations/plots + t_ref = phase_traveltime + + num_traces = len(st_comp) + print(f" {num_traces} traces on {plot_comp}") + if num_traces == 0: + continue + + # ---- Preprocess traces (detrend/taper/filter) ---- + _pre_wall_start = time.perf_counter() + _pre_cpu_start = time.process_time() + for tr in st_comp: + tr.detrend(type="demean") + trace_len_sec = float(tr.stats.npts) / float(tr.stats.sampling_rate) + taper_pct = min(0.05, 5.0 / trace_len_sec) if trace_len_sec > 0 else 0.0 + tr.taper(max_percentage=taper_pct, type="cosine") + tr.filter( + "bandpass", + freqmin=min_freq, + freqmax=max_freq, + corners=4, + zerophase=True, + ) + add_stage_timing("preprocess_filter", _pre_wall_start, _pre_cpu_start) + + alignment = compute_alignment_products( + st_comp=st_comp, + ref_trace=ref_trace, + ref_station_id=ref_station_id, + name2ll=name2ll, + eve_lat=eve_lat, + eve_lon=eve_lon, + event_depth=event_depth, + align_phase_name=align_phase, + t_ref=t_ref, + ) + npts = alignment["npts"] + sample_rate = alignment["sample_rate"] + move_limit_samples = alignment["move_limit_samples"] + win_start = alignment["win_start"] + win_end = alignment["win_end"] + calc_shifts = alignment["calc_shifts"] + aligned_stack = alignment["aligned_stack"] + selected_aligned_stack = alignment["selected_aligned_stack"] + selected_ids = alignment["selected_ids"] + station_corr = alignment["station_corr"] + n_pass_window = alignment["n_pass_window"] + pass_window_ids = alignment["pass_window_ids"] + snippet_by_station = alignment["snippet_by_station"] + ref_window = alignment["ref_window"] + selected_rows = alignment["selected_rows"] + rejected_rows = alignment["rejected_rows"] + station_shifts = alignment["station_shifts"] + aligned_traces_by_station = alignment["aligned_traces_by_station"] + t_abs = alignment["t_abs"] + mask = alignment["mask"] + stack_vec = alignment["stack_vec"] + + # ---- Plot: superposition of Stage1/Stage2/Final stacks ---- + _plot_wall_start = time.perf_counter() + _plot_cpu_start = time.process_time() + if not all_channels: + plot_stage_stacks( + eve_id=eve_id, + plot_comp=plot_comp, + align_phase_name=align_phase, + t_abs=t_abs, + mask=mask, + aligned_stack=aligned_stack, + selected_aligned_stack=selected_aligned_stack, + stack_vec=stack_vec, + save_dir=save_dir, + ) + + # ---- Plot: record section (top) + stack (bottom) ---- + record_fig = plot_record_section_and_stack( + show_record=show_record_section_plot, eve_id=eve_id, plot_comp=plot_comp, align_phase_name=align_phase, - t_abs=t_abs, - mask=mask, - aligned_stack=aligned_stack, - selected_aligned_stack=selected_aligned_stack, - stack_vec=stack_vec, - save_dir=save_dir, - ) - - # ---- Plot: record section (top) + stack (bottom) ---- - record_fig = plot_record_section_and_stack( - show_record=show_record_section_plot, - eve_id=eve_id, - plot_comp=plot_comp, - align_phase_name=align_phase, - selected_rows=selected_rows, - rejected_rows=rejected_rows, - t_abs=t_abs, - mask=mask, - sample_rate=sample_rate, - t_ref=t_ref, - win_start=win_start, - win_end=win_end, - move_sec=move_limit_sec, - npts=npts, - n_pass_window=n_pass_window, - stack_vec=stack_vec, - save_dir=save_dir, - ) - - # Store data for three-component plotting or show individual plot - if process_as_three_comp: - comp_key = resolve_component_key(channel, sel_comp) - all_component_data[comp_key] = build_component_output_payload( - record_fig=record_fig, selected_rows=selected_rows, rejected_rows=rejected_rows, - stack_vec=stack_vec, t_abs=t_abs, mask=mask, sample_rate=sample_rate, + t_ref=t_ref, win_start=win_start, win_end=win_end, - move_limit_sec_value=move_limit_sec, - move_limit_samples=move_limit_samples, + move_sec=move_limit_sec, npts=npts, - start_t=start_time, - end_t=end_time, - eve_id=eve_id, - align_phase_name=align_phase, - origin=origin, - station_shifts=station_shifts, - station_corr=station_corr, - calc_shifts=calc_shifts, n_pass_window=n_pass_window, - pass_window_ids=pass_window_ids, - snippet_by_station=snippet_by_station, - ref_window=ref_window, - p_traveltime=p_traveltime, - s_traveltime=s_traveltime, - name2ll=name2ll, - selected_ids=selected_ids, - aligned_traces_by_station=aligned_traces_by_station, - t_ref=t_ref, + stack_vec=stack_vec, + save_dir=save_dir, ) - - if record_fig is not None: - plt.close(record_fig) # Close individual figure - else: - # Show individual plot in non-three-component mode - # Save figure (same location/pattern as the original script) - save_file = save_dir / f"{eve_id}_{plot_comp}_{align_phase}.png" - if record_fig is not None: - record_fig.savefig(save_file, dpi=300, bbox_inches="tight") - - # ===================== Log10 envelope plot (single trace) ===================== - if num_traces == 1: + + # Store data for three-component plotting or show individual plot + if process_as_three_comp: + comp_key = resolve_component_key(channel, sel_comp) + all_component_data[comp_key] = build_component_output_payload( + record_fig=record_fig, + selected_rows=selected_rows, + rejected_rows=rejected_rows, + stack_vec=stack_vec, + t_abs=t_abs, + mask=mask, + sample_rate=sample_rate, + win_start=win_start, + win_end=win_end, + move_limit_sec_value=move_limit_sec, + move_limit_samples=move_limit_samples, + npts=npts, + start_t=start_time, + end_t=end_time, + eve_id=eve_id, + align_phase_name=align_phase, + origin=origin, + station_shifts=station_shifts, + station_corr=station_corr, + calc_shifts=calc_shifts, + n_pass_window=n_pass_window, + pass_window_ids=pass_window_ids, + snippet_by_station=snippet_by_station, + ref_window=ref_window, + p_traveltime=p_traveltime, + s_traveltime=s_traveltime, + name2ll=name2ll, + selected_ids=selected_ids, + aligned_traces_by_station=aligned_traces_by_station, + t_ref=t_ref, + ) + + if record_fig is not None: + plt.close(record_fig) # Close individual figure + else: + # Show individual plot in non-three-component mode + # Save figure (same location/pattern as the original script) + save_file = save_dir / f"{eve_id}_{plot_comp}_{align_phase}.png" + if record_fig is not None: + record_fig.savefig(save_file, dpi=300, bbox_inches="tight") + + # ===================== Log10 envelope plot (single trace) ===================== + if num_traces == 1: + try: + env = np.abs(hilbert(stack_vec)) + std_sec = 1.0 + std_samples = max(1.0, float(sample_rate) * std_sec) + win_samples = max(3, int(round(6.0 * std_samples))) + gauss = gaussian(win_samples, std_samples) + gauss = gauss / np.sum(gauss) + env_smooth = np.convolve(env, gauss, mode='same') + log_env = np.log10(np.maximum(env_smooth, 1e-12)) + + fig_env, ax_env = plt.subplots(figsize=(12, 4.5)) + set_figure_title(fig_env, f"{eve_id} {plot_comp} log10 envelope") + ax_env.plot(t_abs[mask], log_env[mask], color='k', lw=1.5) + ax_env.set_xlim(start_time, end_time) + ax_env.set_xlabel('Time since origin (s)', fontsize=11) + ax_env.set_ylabel('log10 envelope', fontsize=11) + ax_env.set_title( + f'Event {eve_id} - log10 envelope ({plot_comp})', + fontsize=12, + fontweight='bold', + ) + ax_env.grid(alpha=0.2) + add_catalog_event_lines(ax_env, origin, catalog_local, start_time, end_time) + fig_env.subplots_adjust(bottom=0.28) + + if origin is not None: + try: + add_utc_time_axis(ax_env, origin) + except Exception as e: + print(f"[WARN] Failed to add UTC time axis (single envelope): {e}") + + env_file = save_dir / f"{eve_id}_{plot_comp}_log10_envelope_{align_phase}.png" + fig_env.savefig(env_file, dpi=300, bbox_inches='tight') + print(f"✓ Log10 envelope plot saved to: {env_file}") + except Exception as e: + print(f"[WARN] Failed to create log10 envelope plot (single trace): {e}") + + # No Z-only R–T screening reuse. + # ===================== Estimated vs calculated shift plot (single component) ===================== + common_sta = set(calc_shifts.keys()) & set(station_shifts.keys()) + if len(common_sta) == 0: + print("[WARN] No stations with both estimated and calculated shifts for comparison.") + else: + stations = sorted(common_sta, key=lambda s: int(s)) + est_shift = np.array([station_shifts[s]['lag_seconds'] for s in stations], dtype=float) + calc_shift = np.array([calc_shifts[s] for s in stations], dtype=float) + + fig_ec, ax_ec = plt.subplots(1, 1, figsize=(6.2, 5.2)) + set_figure_title(fig_ec, f"{eve_id} {plot_comp} est vs calc shifts") + ax_ec.scatter(calc_shift, est_shift, s=20, alpha=0.6) + + minv = float(min(np.min(calc_shift), np.min(est_shift))) + maxv = float(max(np.max(calc_shift), np.max(est_shift))) + ax_ec.plot([minv, maxv], [minv, maxv], 'r--', lw=1.2, alpha=0.7, label='1:1 line') + + ax_ec.set_xlabel('Calculated shift (s)') + ax_ec.set_ylabel('Estimated shift (s)') + ax_ec.set_title(f"Event {eve_id} {plot_comp}: Estimated vs Calculated shifts") + ax_ec.grid(alpha=0.3) + ax_ec.legend(loc='upper left', fontsize=9) + plt.tight_layout() + + estcalc_file = save_dir / f"{eve_id}_{plot_comp}_est_vs_calc_shift_{align_phase}.png" + fig_ec.savefig(estcalc_file, dpi=300, bbox_inches='tight') + print(f"✓ Estimated vs calculated shift plot saved to: {estcalc_file}") + + # ===================== Snippet comparison plot (pass vs fail) ===================== try: - env = np.abs(hilbert(stack_vec)) - std_sec = 1.0 - std_samples = max(1.0, float(sample_rate) * std_sec) - win_samples = max(3, int(round(6.0 * std_samples))) - gauss = gaussian(win_samples, std_samples) - gauss = gauss / np.sum(gauss) - env_smooth = np.convolve(env, gauss, mode='same') - log_env = np.log10(np.maximum(env_smooth, 1e-12)) - - fig_env, ax_env = plt.subplots(figsize=(12, 4.5)) - set_figure_title(fig_env, f"{eve_id} {plot_comp} log10 envelope") - ax_env.plot(t_abs[mask], log_env[mask], color='k', lw=1.5) - ax_env.set_xlim(start_time, end_time) - ax_env.set_xlabel('Time since origin (s)', fontsize=11) - ax_env.set_ylabel('log10 envelope', fontsize=11) - ax_env.set_title( - f'Event {eve_id} - log10 envelope ({plot_comp})', + t_win = start_time + (np.arange(win_start, win_end) / sample_rate) + ref_win = ref_window + pass_list = sorted(list(pass_window_ids), key=lambda s: int(s)) + fail_list = sorted( + [s for s in snippet_by_station.keys() if s not in pass_window_ids], + key=lambda s: int(s), + ) + + n_show = 10 + pass_show = pass_list[:n_show] + fail_show = fail_list[:n_show] + + fig_snip, (axp, axf) = plt.subplots(1, 2, figsize=(10, 3.8), sharey=True) + set_figure_title(fig_snip, f"{eve_id} {plot_comp} correlation snippets") + + for sid in pass_show: + axp.plot(t_win, snippet_by_station[sid], color='k', alpha=0.4, lw=1) + axp.plot(t_win, ref_win, color='C3', lw=2, label='Ref window') + axp.set_title(f"Pass r_win (N={len(pass_list)})") + axp.set_xlabel('Time since origin (s)') + axp.grid(alpha=0.3) + + for sid in fail_show: + axf.plot(t_win, snippet_by_station[sid], color='k', alpha=0.4, lw=1) + axf.plot(t_win, ref_win, color='C3', lw=2, label='Ref window') + axf.set_title(f"Fail r_win (N={len(fail_list)})") + axf.set_xlabel('Time since origin (s)') + axf.grid(alpha=0.3) + + axp.set_ylabel('Normalized amplitude') + axf.legend(loc='upper right', fontsize=8) + fig_snip.suptitle( + f"Event {eve_id} {plot_comp}: correlation-window snippets", fontsize=12, fontweight='bold', ) - ax_env.grid(alpha=0.2) - add_catalog_event_lines(ax_env, origin, catalog_local, start_time, end_time) - fig_env.subplots_adjust(bottom=0.28) - - if origin is not None: - try: - add_utc_time_axis(ax_env, origin) - except Exception as e: - print(f"[WARN] Failed to add UTC time axis (single envelope): {e}") - - env_file = save_dir / f"{eve_id}_{plot_comp}_log10_envelope_{align_phase}.png" - fig_env.savefig(env_file, dpi=300, bbox_inches='tight') - print(f"✓ Log10 envelope plot saved to: {env_file}") + plt.tight_layout() + + snip_file = save_dir / f"{eve_id}_{plot_comp}_snippet_compare_{align_phase}.png" + fig_snip.savefig(snip_file, dpi=300, bbox_inches='tight') + print(f"✓ Snippet comparison plot saved to: {snip_file}") except Exception as e: - print(f"[WARN] Failed to create log10 envelope plot (single trace): {e}") - - # No Z-only R–T screening reuse. - # ===================== Estimated vs calculated shift plot (single component) ===================== - common_sta = set(calc_shifts.keys()) & set(station_shifts.keys()) - if len(common_sta) == 0: - print("[WARN] No stations with both estimated and calculated shifts for comparison.") - else: - stations = sorted(common_sta, key=lambda s: int(s)) - est_shift = np.array([station_shifts[s]['lag_seconds'] for s in stations], dtype=float) - calc_shift = np.array([calc_shifts[s] for s in stations], dtype=float) - - fig_ec, ax_ec = plt.subplots(1, 1, figsize=(6.2, 5.2)) - set_figure_title(fig_ec, f"{eve_id} {plot_comp} est vs calc shifts") - ax_ec.scatter(calc_shift, est_shift, s=20, alpha=0.6) - - minv = float(min(np.min(calc_shift), np.min(est_shift))) - maxv = float(max(np.max(calc_shift), np.max(est_shift))) - ax_ec.plot([minv, maxv], [minv, maxv], 'r--', lw=1.2, alpha=0.7, label='1:1 line') - - ax_ec.set_xlabel('Calculated shift (s)') - ax_ec.set_ylabel('Estimated shift (s)') - ax_ec.set_title(f"Event {eve_id} {plot_comp}: Estimated vs Calculated shifts") - ax_ec.grid(alpha=0.3) - ax_ec.legend(loc='upper left', fontsize=9) - plt.tight_layout() - - estcalc_file = save_dir / f"{eve_id}_{plot_comp}_est_vs_calc_shift_{align_phase}.png" - fig_ec.savefig(estcalc_file, dpi=300, bbox_inches='tight') - print(f"✓ Estimated vs calculated shift plot saved to: {estcalc_file}") - - # ===================== Snippet comparison plot (pass vs fail) ===================== - try: - t_win = start_time + (np.arange(win_start, win_end) / sample_rate) - ref_win = ref_window - pass_list = sorted(list(pass_window_ids), key=lambda s: int(s)) - fail_list = sorted( - [s for s in snippet_by_station.keys() if s not in pass_window_ids], - key=lambda s: int(s), - ) - - n_show = 10 - pass_show = pass_list[:n_show] - fail_show = fail_list[:n_show] - - fig_snip, (axp, axf) = plt.subplots(1, 2, figsize=(10, 3.8), sharey=True) - set_figure_title(fig_snip, f"{eve_id} {plot_comp} correlation snippets") - - for sid in pass_show: - axp.plot(t_win, snippet_by_station[sid], color='k', alpha=0.4, lw=1) - axp.plot(t_win, ref_win, color='C3', lw=2, label='Ref window') - axp.set_title(f"Pass r_win (N={len(pass_list)})") - axp.set_xlabel('Time since origin (s)') - axp.grid(alpha=0.3) - - for sid in fail_show: - axf.plot(t_win, snippet_by_station[sid], color='k', alpha=0.4, lw=1) - axf.plot(t_win, ref_win, color='C3', lw=2, label='Ref window') - axf.set_title(f"Fail r_win (N={len(fail_list)})") - axf.set_xlabel('Time since origin (s)') - axf.grid(alpha=0.3) - - axp.set_ylabel('Normalized amplitude') - axf.legend(loc='upper right', fontsize=8) - fig_snip.suptitle( - f"Event {eve_id} {plot_comp}: correlation-window snippets", - fontsize=12, - fontweight='bold', - ) - plt.tight_layout() - - snip_file = save_dir / f"{eve_id}_{plot_comp}_snippet_compare_{align_phase}.png" - fig_snip.savefig(snip_file, dpi=300, bbox_inches='tight') - print(f"✓ Snippet comparison plot saved to: {snip_file}") - except Exception as e: - print(f"[WARN] Failed to create snippet comparison plot: {e}") - - # ===================== Individual seismograms (20 traces per subplot, 5 panels per figure) ===================== - if show_individual_seismograms: - try: - all_rows = selected_rows + rejected_rows - all_rows.sort(key=lambda t: int(t[1])) - - n_traces = len(all_rows) - if n_traces > 0: - n_per = 20 - panels_per_fig = 5 - n_panels = int(np.ceil(n_traces / n_per)) - n_figs = int(np.ceil(n_panels / panels_per_fig)) - - for fig_idx in range(n_figs): - panel_start = fig_idx * panels_per_fig - panel_end = min((fig_idx + 1) * panels_per_fig, n_panels) - panels_in_fig = panel_end - panel_start - - fig_ind, axes_ind = plt.subplots( - panels_in_fig, - 1, - figsize=(10, 2.2 * panels_in_fig), - sharex=True, - sharey=False, - ) - set_figure_title( - fig_ind, - f"{eve_id} {plot_comp} individual seismograms fig {fig_idx + 1}", - ) - if panels_in_fig == 1: - axes_ind = [axes_ind] - - for p in range(panels_in_fig): - axp = axes_ind[p] - global_panel = panel_start + p - start_idx = global_panel * n_per - end_idx = min((global_panel + 1) * n_per, n_traces) - subset = all_rows[start_idx:end_idx] - - # Thresholding windows - t_win_start = start_time + (win_start / sample_rate) - t_win_end = start_time + (win_end / sample_rate) - t_explore_start = max(start_time, t_win_start - move_limit_sec) - t_explore_end = min(start_time + (npts / sample_rate), t_win_end + move_limit_sec) - axp.axvline(x=t_win_start, color='y', lw=1.2, alpha=0.9) - axp.axvline(x=t_win_end, color='y', lw=1.2, alpha=0.9) - axp.axvline(x=t_explore_start, color='g', lw=1.2, alpha=0.9) - axp.axvline(x=t_explore_end, color='g', lw=1.2, alpha=0.9) - - for idx_in_subset, (_, station_id, y) in enumerate(subset): - i = (len(subset) - 1) - idx_in_subset - passed_win = station_id in pass_window_ids - trace_color = 'k' if passed_win else 'red' + print(f"[WARN] Failed to create snippet comparison plot: {e}") + + # ===================== Individual seismograms (20 traces per subplot, 5 panels per figure) ===================== + if show_individual_seismograms: + try: + all_rows = selected_rows + rejected_rows + all_rows.sort(key=lambda t: int(t[1])) + + n_traces = len(all_rows) + if n_traces > 0: + n_per = 20 + panels_per_fig = 5 + n_panels = int(np.ceil(n_traces / n_per)) + n_figs = int(np.ceil(n_panels / panels_per_fig)) + + for fig_idx in range(n_figs): + panel_start = fig_idx * panels_per_fig + panel_end = min((fig_idx + 1) * panels_per_fig, n_panels) + panels_in_fig = panel_end - panel_start + + fig_ind, axes_ind = plt.subplots( + panels_in_fig, + 1, + figsize=(10, 2.2 * panels_in_fig), + sharex=True, + sharey=False, + ) + set_figure_title( + fig_ind, + f"{eve_id} {plot_comp} individual seismograms fig {fig_idx + 1}", + ) + if panels_in_fig == 1: + axes_ind = [axes_ind] + + for p in range(panels_in_fig): + axp = axes_ind[p] + global_panel = panel_start + p + start_idx = global_panel * n_per + end_idx = min((global_panel + 1) * n_per, n_traces) + subset = all_rows[start_idx:end_idx] + + # Thresholding windows + t_win_start = start_time + (win_start / sample_rate) + t_win_end = start_time + (win_end / sample_rate) + t_explore_start = max(start_time, t_win_start - move_limit_sec) + t_explore_end = min(start_time + (npts / sample_rate), t_win_end + move_limit_sec) + axp.axvline(x=t_win_start, color='y', lw=1.2, alpha=0.9) + axp.axvline(x=t_win_end, color='y', lw=1.2, alpha=0.9) + axp.axvline(x=t_explore_start, color='g', lw=1.2, alpha=0.9) + axp.axvline(x=t_explore_end, color='g', lw=1.2, alpha=0.9) + + for idx_in_subset, (_, station_id, y) in enumerate(subset): + i = (len(subset) - 1) - idx_in_subset + passed_win = station_id in pass_window_ids + trace_color = 'k' if passed_win else 'red' + axp.plot( + t_abs[mask], + y[mask] + i, + color=trace_color, + lw=0.7, + ) + axp.text( + t_abs[mask][0], + i, + station_id, + fontsize=6, + va='center', + ) + + # Reference stack above traces + ref_offset = len(subset) + 1 axp.plot( t_abs[mask], - y[mask] + i, - color=trace_color, - lw=0.7, - ) - axp.text( - t_abs[mask][0], - i, - station_id, - fontsize=6, - va='center', + stack_vec[mask] + ref_offset, + color='C3', + lw=1.2, ) - - # Reference stack above traces - ref_offset = len(subset) + 1 - axp.plot( - t_abs[mask], - stack_vec[mask] + ref_offset, - color='C3', - lw=1.2, + + axp.set_ylim(-1, len(subset) + 2) + axp.grid(alpha=0.2) + axp.set_ylabel('Trace index') + + axes_ind[-1].set_xlabel('Time since origin (s)') + fig_ind.suptitle( + f"Event {eve_id} {plot_comp}: individual seismograms " + f"(20 per panel, fig {fig_idx + 1}/{n_figs})", + fontsize=12, + fontweight='bold', ) - - axp.set_ylim(-1, len(subset) + 2) - axp.grid(alpha=0.2) - axp.set_ylabel('Trace index') - - axes_ind[-1].set_xlabel('Time since origin (s)') - fig_ind.suptitle( - f"Event {eve_id} {plot_comp}: individual seismograms " - f"(20 per panel, fig {fig_idx + 1}/{n_figs})", - fontsize=12, - fontweight='bold', - ) - plt.tight_layout() - - ind_file = save_dir / ( - f"{eve_id}_{plot_comp}_individual_seismograms_{align_phase}_fig{fig_idx + 1}.png" - ) - fig_ind.savefig(ind_file, dpi=300, bbox_inches='tight') - print(f"✓ Individual seismograms plot saved to: {ind_file}") + plt.tight_layout() + + ind_file = save_dir / ( + f"{eve_id}_{plot_comp}_individual_seismograms_{align_phase}_fig{fig_idx + 1}.png" + ) + fig_ind.savefig(ind_file, dpi=300, bbox_inches='tight') + print(f"✓ Individual seismograms plot saved to: {ind_file}") + except Exception as e: + print(f"[WARN] Failed to create individual seismograms plot: {e}") + + # ===================== Station maps: pass each threshold and both ===================== + try: + tr_map = aligned_traces_by_station + all_stations = sorted(tr_map.keys(), key=lambda s: int(s)) + pass_win = set(pass_window_ids) + + fig_map, axm = plt.subplots(1, 1, figsize=(6.5, 5.5)) + set_figure_title(fig_map, f"{eve_id} {plot_comp} station pass map") + pass_lats = [name2ll[s][0] for s in all_stations if s in pass_win] + pass_lons = [name2ll[s][1] for s in all_stations if s in pass_win] + fail_lats = [name2ll[s][0] for s in all_stations if s not in pass_win] + fail_lons = [name2ll[s][1] for s in all_stations if s not in pass_win] + + if len(fail_lons) > 0: + axm.scatter(fail_lons, fail_lats, s=18, c='0.7', label='Fail') + if len(pass_lons) > 0: + axm.scatter(pass_lons, pass_lats, s=22, c='C3', label='Pass') + + axm.set_title('Pass r_win', fontsize=11, fontweight='bold') + axm.grid(alpha=0.3) + axm.set_xlabel('Longitude') + axm.set_ylabel('Latitude') + axm.legend(loc='upper right', fontsize=9) + + fig_map.suptitle( + f"Event {eve_id} {plot_comp}: stations passing thresholds", + fontsize=13, + fontweight='bold', + ) + plt.tight_layout() + + map_file = save_dir / f"{eve_id}_{plot_comp}_station_pass_map_{align_phase}.png" + fig_map.savefig(map_file, dpi=300, bbox_inches='tight') + print(f"✓ Station pass/fail map saved to: {map_file}") except Exception as e: - print(f"[WARN] Failed to create individual seismograms plot: {e}") - - # ===================== Station maps: pass each threshold and both ===================== - try: - tr_map = aligned_traces_by_station - all_stations = sorted(tr_map.keys(), key=lambda s: int(s)) - pass_win = set(pass_window_ids) - - fig_map, axm = plt.subplots(1, 1, figsize=(6.5, 5.5)) - set_figure_title(fig_map, f"{eve_id} {plot_comp} station pass map") - pass_lats = [name2ll[s][0] for s in all_stations if s in pass_win] - pass_lons = [name2ll[s][1] for s in all_stations if s in pass_win] - fail_lats = [name2ll[s][0] for s in all_stations if s not in pass_win] - fail_lons = [name2ll[s][1] for s in all_stations if s not in pass_win] - - if len(fail_lons) > 0: - axm.scatter(fail_lons, fail_lats, s=18, c='0.7', label='Fail') - if len(pass_lons) > 0: - axm.scatter(pass_lons, pass_lats, s=22, c='C3', label='Pass') - - axm.set_title('Pass r_win', fontsize=11, fontweight='bold') - axm.grid(alpha=0.3) - axm.set_xlabel('Longitude') - axm.set_ylabel('Latitude') - axm.legend(loc='upper right', fontsize=9) - - fig_map.suptitle( - f"Event {eve_id} {plot_comp}: stations passing thresholds", - fontsize=13, - fontweight='bold', - ) - plt.tight_layout() - - map_file = save_dir / f"{eve_id}_{plot_comp}_station_pass_map_{align_phase}.png" - fig_map.savefig(map_file, dpi=300, bbox_inches='tight') - print(f"✓ Station pass/fail map saved to: {map_file}") - except Exception as e: - print(f"[WARN] Failed to create station pass/fail maps: {e}") - - add_stage_timing("plot_and_save", _plot_wall_start, _plot_cpu_start) - - # Show figures for single-component mode - report_timing_once() - plt.show() - - - -# ===================== Three-component combined plotting ===================== -if process_as_three_comp and len(all_component_data) == 3: - _plot3_wall_start = time.perf_counter() - _plot3_cpu_start = time.process_time() - print(f"\\n{'='*70}") - print(f"Creating combined three-component plot...") - print(f"{'='*70}\\n") - - fig = None - gs = None - if show_record_section_plot: - fig = plt.figure(figsize=(18, 9)) - set_figure_title(fig, f"{eve_id} {align_phase} 3-comp record section") - gs = fig.add_gridspec(2, 3, height_ratios=[3, 1], hspace=0.3, wspace=0.25) - - comp_order = ['DPZ', 'R', 'T'] - comp_titles = ['Vertical (Z)', 'Radial (R)', 'Transverse (T)'] - - # Get common parameters - first_data = all_component_data[comp_order[0]] - eve_id = first_data['eve_id'] - align_phase = first_data['align_phase'] - start_time = first_data['start_time'] - end_time = first_data['end_time'] - - # Pre-compute stations with zero R–T shift difference (for optional stacking) - zero_rt_diff_stations = None - stack_by_comp = {} - t_abs = first_data['t_abs'] - mask = first_data['mask'] - sample_rate_env = first_data['sample_rate'] - origin_env = first_data.get('origin') - - for idx, comp_name in enumerate(comp_order): - if comp_name not in all_component_data: - print(f"Warning: {comp_name} data not found") - continue - - data = all_component_data[comp_name] - all_rows = data['all_rows'] - stack_vec = data['stack_vec'] - t_abs = data['t_abs'] - mask = data['mask'] - sample_rate = data['sample_rate'] - win_start = data['win_start'] - win_end = data['win_end'] - move_limit_sec = data['move_limit_sec'] - npts = data['npts'] - t_ref = data.get('t_ref') - + print(f"[WARN] Failed to create station pass/fail maps: {e}") + + add_stage_timing("plot_and_save", _plot_wall_start, _plot_cpu_start) + + # Show figures for single-component mode + report_timing_once() + plt.show() + + + + # ===================== Three-component combined plotting ===================== + if process_as_three_comp and len(all_component_data) == 3: + _plot3_wall_start = time.perf_counter() + _plot3_cpu_start = time.process_time() + print(f"\\n{'='*70}") + print(f"Creating combined three-component plot...") + print(f"{'='*70}\\n") + + fig = None + gs = None if show_record_section_plot: - # Top panel: record section - ax = fig.add_subplot(gs[0, idx]) - - all_rows.sort(key=lambda t: t[0]) - t_masked = t_abs[mask] - - if len(all_rows) > 0 and np.any(mask): - A = np.vstack([row[2][mask] for row in all_rows]) - dvec = np.array([row[0] for row in all_rows], dtype=float) - - # y-edges - if len(dvec) == 1: - y_edges = np.array([dvec[0] - 0.5, dvec[0] + 0.5]) - else: - mids = 0.5 * (dvec[1:] + dvec[:-1]) - y_edges = np.empty(len(dvec) + 1) - y_edges[1:-1] = mids - y_edges[0] = dvec[0] - (mids[0] - dvec[0]) - y_edges[-1] = dvec[-1] + (dvec[-1] - mids[-1]) - - # t-edges - if len(t_masked) == 1: - t_edges = np.array([t_masked[0] - 0.5 / sample_rate, - t_masked[0] + 0.5 / sample_rate]) - else: - tmids = 0.5 * (t_masked[1:] + t_masked[:-1]) - t_edges = np.empty(len(t_masked) + 1) - t_edges[1:-1] = tmids - t_edges[0] = t_masked[0] - (tmids[0] - t_masked[0]) - t_edges[-1] = t_masked[-1] + (t_masked[-1] - tmids[-1]) - - ax.pcolormesh(t_edges, y_edges, A, cmap='gray', - shading='auto', vmin=-1.0, vmax=1.0) - - ax.set_xlim(start_time, end_time) - if idx == 0: - ax.set_ylabel('Epicentral distance (km)', fontsize=11) - ax.set_title(f'{comp_titles[idx]}', fontsize=12, fontweight='bold') - ax.grid(alpha=0.2) - - # Vertical reference line - if t_ref is not None: - ax.axvline(x=t_ref, color='r', lw=2, alpha=0.6, linestyle='--', zorder=6) - # Cross-correlation window bounds - try: - draw_correlation_markers( - ax, - start_time, - win_start, - win_end, - sample_rate, - move_limit_sec, - npts, - ) - except Exception as e: - print(f"[WARN] Failed to draw correlation window bounds (top {comp_name}): {e}") - - # Legend for window bounds (only once, top-left panel) - if idx == 0: + fig = plt.figure(figsize=(18, 9)) + set_figure_title(fig, f"{eve_id} {align_phase} 3-comp record section") + gs = fig.add_gridspec(2, 3, height_ratios=[3, 1], hspace=0.3, wspace=0.25) + + comp_order = ['DPZ', 'R', 'T'] + comp_titles = ['Vertical (Z)', 'Radial (R)', 'Transverse (T)'] + + # Get common parameters + first_data = all_component_data[comp_order[0]] + eve_id = first_data['eve_id'] + align_phase = first_data['align_phase'] + start_time = first_data['start_time'] + end_time = first_data['end_time'] + + # Pre-compute stations with zero R–T shift difference (for optional stacking) + zero_rt_diff_stations = None + stack_by_comp = {} + t_abs = first_data['t_abs'] + mask = first_data['mask'] + sample_rate_env = first_data['sample_rate'] + origin_env = first_data.get('origin') + + for idx, comp_name in enumerate(comp_order): + if comp_name not in all_component_data: + print(f"Warning: {comp_name} data not found") + continue + + data = all_component_data[comp_name] + all_rows = data['all_rows'] + stack_vec = data['stack_vec'] + t_abs = data['t_abs'] + mask = data['mask'] + sample_rate = data['sample_rate'] + win_start = data['win_start'] + win_end = data['win_end'] + move_limit_sec = data['move_limit_sec'] + npts = data['npts'] + t_ref = data.get('t_ref') + + if show_record_section_plot: + # Top panel: record section + ax = fig.add_subplot(gs[0, idx]) + + all_rows.sort(key=lambda t: t[0]) + t_masked = t_abs[mask] + + if len(all_rows) > 0 and np.any(mask): + A = np.vstack([row[2][mask] for row in all_rows]) + dvec = np.array([row[0] for row in all_rows], dtype=float) + + # y-edges + if len(dvec) == 1: + y_edges = np.array([dvec[0] - 0.5, dvec[0] + 0.5]) + else: + mids = 0.5 * (dvec[1:] + dvec[:-1]) + y_edges = np.empty(len(dvec) + 1) + y_edges[1:-1] = mids + y_edges[0] = dvec[0] - (mids[0] - dvec[0]) + y_edges[-1] = dvec[-1] + (dvec[-1] - mids[-1]) + + # t-edges + if len(t_masked) == 1: + t_edges = np.array([t_masked[0] - 0.5 / sample_rate, + t_masked[0] + 0.5 / sample_rate]) + else: + tmids = 0.5 * (t_masked[1:] + t_masked[:-1]) + t_edges = np.empty(len(t_masked) + 1) + t_edges[1:-1] = tmids + t_edges[0] = t_masked[0] - (tmids[0] - t_masked[0]) + t_edges[-1] = t_masked[-1] + (t_masked[-1] - tmids[-1]) + + ax.pcolormesh(t_edges, y_edges, A, cmap='gray', + shading='auto', vmin=-1.0, vmax=1.0) + + ax.set_xlim(start_time, end_time) + if idx == 0: + ax.set_ylabel('Epicentral distance (km)', fontsize=11) + ax.set_title(f'{comp_titles[idx]}', fontsize=12, fontweight='bold') + ax.grid(alpha=0.2) + + # Vertical reference line + if t_ref is not None: + ax.axvline(x=t_ref, color='r', lw=2, alpha=0.6, linestyle='--', zorder=6) + # Cross-correlation window bounds try: - n_pass_window = int(data.get('n_pass_window', 0)) - legend_handles = [ - Line2D([0], [0], color='y', lw=2, label='Correlation window'), - Line2D([0], [0], color='g', lw=2, label='Correlation search (±move_limit_sec)'), - Line2D([0], [0], color='none', label=f'Pass r_win: {n_pass_window}'), - ] - ax.legend( - handles=legend_handles, - loc='upper left', - bbox_to_anchor=(1.02, 1.0), - borderaxespad=0.0, - fontsize=9, + draw_correlation_markers( + ax, + start_time, + win_start, + win_end, + sample_rate, + move_limit_sec, + npts, ) except Exception as e: - print(f"[WARN] Failed to add legend (top {comp_name}): {e}") - - # Bottom panel: stack - ax2 = fig.add_subplot(gs[1, idx]) - ax2.plot(t_abs[mask], stack_vec[mask], color='C3', lw=2) - ax2.axhline(0.0, color='k', lw=0.6) - ax2.set_xlim(start_time, end_time) - ax2.set_xlabel('Time since origin (s)', fontsize=11) - if idx == 0: - ax2.set_ylabel('Stack (norm.)', fontsize=11) - ax2.set_ylim(-1.1, 1.1) - ax2.grid(alpha=0.2) - - if t_ref is not None: - ax2.axvline(x=t_ref, color='r', lw=2, alpha=0.6, linestyle='--', zorder=6) - # Cross-correlation window bounds - try: - draw_correlation_markers( - ax2, - start_time, - win_start, - win_end, - sample_rate, - move_limit_sec, - npts, - ) - except Exception as e: - print(f"[WARN] Failed to draw correlation window bounds (bottom {comp_name}): {e}") - - stack_by_comp[comp_name] = stack_vec - - if all(comp in stack_by_comp for comp in comp_order): - save_path = Path(path_prefix + "output") - save_dir = save_path / eve_id - save_dir.mkdir(parents=True, exist_ok=True) - - for comp_name in comp_order: - stack_vec = stack_by_comp[comp_name] - tr = Trace(data=stack_vec.astype(np.float32, copy=False)) - tr.stats.starttime = origin_env + start_time - tr.stats.sampling_rate = float(sample_rate_env) - tr.stats.station = "STACK" - tr.stats.channel = comp_name - st_out = Stream(traces=[tr]) - out_file = save_dir / f"{eve_id}_{comp_name}_stack.mseed" - st_out.write(str(out_file), format="MSEED") - print(f"✓ Wrote stack mseed: {out_file}") - - if show_record_section_plot: - fig.suptitle(f'Event {eve_id} - Aligned {align_phase} waveforms (3 components)', - fontsize=14, fontweight='bold') - - # Save combined figure - save_path = Path(path_prefix + "output") - save_dir = save_path / eve_id - save_dir.mkdir(parents=True, exist_ok=True) - save_file = save_dir / f"{eve_id}_3comp_{align_phase}.png" - fig.savefig(save_file, dpi=300, bbox_inches='tight') - print(f"\n✓ Three-component plot saved to: {save_file}") - print(f"\n✓ Three-component plot created successfully!\n") - # plt.show() # defer until end - - # ===================== Log10 envelope of 3-component stack ===================== - try: - if all(comp in stack_by_comp for comp in comp_order): - z = stack_by_comp['DPZ'] - r = stack_by_comp['R'] - t = stack_by_comp['T'] - env_z = np.abs(hilbert(z)) - env_r = np.abs(hilbert(r)) - env_t = np.abs(hilbert(t)) - env_rms = np.sqrt((env_z ** 2 + env_r ** 2 + env_t ** 2) / 3.0) - std_sec = 1.0 - std_samples = max(1.0, float(sample_rate_env) * std_sec) - win_samples = max(3, int(round(6.0 * std_samples))) - gauss = gaussian(win_samples, std_samples) - gauss = gauss / np.sum(gauss) - env_rms_smooth = np.convolve(env_rms, gauss, mode='same') - log_env = np.log10(np.maximum(env_rms_smooth, 1e-12)) - - fig_env, ax_env = plt.subplots(figsize=(12, 4.5)) - set_figure_title(fig_env, f"{eve_id} 3-comp log10 envelope") - ax_env.plot(t_abs[mask], log_env[mask], color='k', lw=1.5) - ax_env.set_xlim(start_time, end_time) - ax_env.set_xlabel('Time since origin (s)', fontsize=11) - ax_env.set_ylabel('log10 envelope', fontsize=11) - ax_env.set_title( - f'Event {eve_id} - log10 RMS envelope of 3-component stack', - fontsize=12, - fontweight='bold', - ) - ax_env.grid(alpha=0.2) - add_catalog_event_lines(ax_env, origin_env, catalog_local, start_time, end_time) - fig_env.subplots_adjust(bottom=0.28) - - if origin_env is not None: + print(f"[WARN] Failed to draw correlation window bounds (top {comp_name}): {e}") + + # Legend for window bounds (only once, top-left panel) + if idx == 0: + try: + n_pass_window = int(data.get('n_pass_window', 0)) + legend_handles = [ + Line2D([0], [0], color='y', lw=2, label='Correlation window'), + Line2D([0], [0], color='g', lw=2, label='Correlation search (±move_limit_sec)'), + Line2D([0], [0], color='none', label=f'Pass r_win: {n_pass_window}'), + ] + ax.legend( + handles=legend_handles, + loc='upper left', + bbox_to_anchor=(1.02, 1.0), + borderaxespad=0.0, + fontsize=9, + ) + except Exception as e: + print(f"[WARN] Failed to add legend (top {comp_name}): {e}") + + # Bottom panel: stack + ax2 = fig.add_subplot(gs[1, idx]) + ax2.plot(t_abs[mask], stack_vec[mask], color='C3', lw=2) + ax2.axhline(0.0, color='k', lw=0.6) + ax2.set_xlim(start_time, end_time) + ax2.set_xlabel('Time since origin (s)', fontsize=11) + if idx == 0: + ax2.set_ylabel('Stack (norm.)', fontsize=11) + ax2.set_ylim(-1.1, 1.1) + ax2.grid(alpha=0.2) + + if t_ref is not None: + ax2.axvline(x=t_ref, color='r', lw=2, alpha=0.6, linestyle='--', zorder=6) + # Cross-correlation window bounds try: - add_utc_time_axis(ax_env, origin_env) + draw_correlation_markers( + ax2, + start_time, + win_start, + win_end, + sample_rate, + move_limit_sec, + npts, + ) except Exception as e: - print(f"[WARN] Failed to add UTC time axis (envelope): {e}") - - env_file = save_dir / f"{eve_id}_3comp_log10_envelope_{align_phase}.png" - fig_env.savefig(env_file, dpi=300, bbox_inches='tight') - print(f"✓ Log10 envelope plot saved to: {env_file}") - except Exception as e: - print(f"[WARN] Failed to create log10 envelope plot: {e}") - - # No R–T zero-diff station list saved. - # ===================== Stack compare plot: all aligned vs r_min-selected ===================== - print("Creating stack comparison plot (all aligned vs r_min-selected)...") - - # Figure layout: 3 rows, 1 column (Z / R / T) — vertical arrangement - fig_cmp, axes_cmp = plt.subplots(3, 1, figsize=(9, 12), sharex=True, sharey=True) - set_figure_title(fig_cmp, f"{eve_id} stack compare") - comp_order = ['DPZ', 'R', 'T'] - comp_titles_cmp = ['Z stack', 'R stack', 'T stack'] - utc_tz = timezone.utc - - for j, comp_name in enumerate(comp_order): - axc = axes_cmp[j] - if comp_name not in all_component_data: - axc.set_axis_off() - continue - - data = all_component_data[comp_name] - t_abs = data['t_abs'] - mask = data['mask'] - start_time = data['start_time'] - end_time = data['end_time'] - p_time = data.get('p_traveltime') - s_time = data.get('s_traveltime') - - tr_map = data.get('aligned_traces_by_station', {}) - all_stations = sorted(tr_map.keys(), key=lambda s: int(s)) - - # Black: stack of all aligned traces - stack_black = np.zeros_like(t_abs) - if len(all_stations) > 0: - bank_all = [tr_map[sta] for sta in all_stations] - stack_black = np.mean(np.vstack(bank_all), axis=0) - ms = np.max(np.abs(stack_black)) or 1.0 - stack_black = stack_black / ms - - # Red: stack of traces that pass r_min thresholds (selected_ids) - sel_ids = data.get('selected_ids', []) - sel_ids = [s for s in sel_ids if s in tr_map] - n_pass_window = int(data.get('n_pass_window', len(sel_ids))) - stack_red = stack_black - if len(sel_ids) > 0: - bank_sel = [tr_map[sta] for sta in sel_ids] - stack_red = np.mean(np.vstack(bank_sel), axis=0) - ms = np.max(np.abs(stack_red)) or 1.0 - stack_red = stack_red / ms - - axc.plot(t_abs[mask], stack_black[mask], color='k', lw=2, label='All aligned traces') - axc.plot( - t_abs[mask], - stack_red[mask], - color='r', - lw=2, - label=f'Pass r_win N={n_pass_window}', - ) - axc.axhline(0.0, color='k', lw=0.6, alpha=0.6) - if p_time is not None: - axc.axvline(x=p_time, color='b', lw=1.5, alpha=0.7, linestyle='--', label='P arrival') - if s_time is not None: - axc.axvline(x=s_time, color='g', lw=1.5, alpha=0.7, linestyle='--', label='S arrival') - axc.set_xlim(start_time, end_time) - axc.set_ylim(-1.1, 1.1) - axc.grid(alpha=0.2) - axc.set_title(comp_titles_cmp[j], fontsize=12, fontweight='bold') - axc.set_xlabel('Time since origin (s)', fontsize=11) - if j != 2: - axc.set_xlabel('') - if j == 0: - axc.set_ylabel('Stack (norm.)', fontsize=11) - axc.legend(loc='upper right', fontsize=9) - - if j == 2: - try: - origin_utc = data.get('origin') - if origin_utc is not None: - add_utc_time_axis(axc, origin_utc, tick_tz=utc_tz) - except Exception as e: - print(f"[WARN] Failed to add UTC time axis: {e}") - - fig_cmp.suptitle( - f'Event {eve_id} - Stack compare (black: all aligned; red: pass r_min thresholds)', - fontsize=13, - fontweight='bold' - ) - plt.tight_layout() - - # Save comparison figure - cmp_file = save_dir / f"{eve_id}_rtfilter_stack_compare_{align_phase}.png" - fig_cmp.savefig(cmp_file, dpi=300, bbox_inches='tight') - print(f"✓ Stack comparison plot saved to: {cmp_file}") - # plt.show() - - # ===================== Individual seismograms (20 traces per subplot, 5 panels per figure, 3 components) ===================== - if show_individual_seismograms: + print(f"[WARN] Failed to draw correlation window bounds (bottom {comp_name}): {e}") + + stack_by_comp[comp_name] = stack_vec + + if all(comp in stack_by_comp for comp in comp_order): + save_path = Path(path_prefix + "output") + save_dir = save_path / eve_id + save_dir.mkdir(parents=True, exist_ok=True) + + for comp_name in comp_order: + stack_vec = stack_by_comp[comp_name] + tr = Trace(data=stack_vec.astype(np.float32, copy=False)) + tr.stats.starttime = origin_env + start_time + tr.stats.sampling_rate = float(sample_rate_env) + tr.stats.station = "STACK" + tr.stats.channel = comp_name + st_out = Stream(traces=[tr]) + out_file = save_dir / f"{eve_id}_{comp_name}_stack.mseed" + st_out.write(str(out_file), format="MSEED") + print(f"✓ Wrote stack mseed: {out_file}") + + if show_record_section_plot: + fig.suptitle(f'Event {eve_id} - Aligned {align_phase} waveforms (3 components)', + fontsize=14, fontweight='bold') + + # Save combined figure + save_path = Path(path_prefix + "output") + save_dir = save_path / eve_id + save_dir.mkdir(parents=True, exist_ok=True) + save_file = save_dir / f"{eve_id}_3comp_{align_phase}.png" + fig.savefig(save_file, dpi=300, bbox_inches='tight') + print(f"\n✓ Three-component plot saved to: {save_file}") + print(f"\n✓ Three-component plot created successfully!\n") + # plt.show() # defer until end + + # ===================== Log10 envelope of 3-component stack ===================== try: - for comp_name, comp_title in zip(['DPZ', 'R', 'T'], ['Z', 'R', 'T']): - if comp_name not in all_component_data: - continue - - data = all_component_data[comp_name] - all_rows = data.get('all_rows', []) - all_rows = sorted(all_rows, key=lambda t: int(t[1])) - t_abs = data['t_abs'] - mask = data['mask'] - sample_rate = data['sample_rate'] - win_start = data['win_start'] - win_end = data['win_end'] - move_limit_sec = data['move_limit_sec'] - npts = data['npts'] - - n_traces = len(all_rows) - if n_traces == 0: - continue - - n_per = 20 - panels_per_fig = 5 - n_panels = int(np.ceil(n_traces / n_per)) - n_figs = int(np.ceil(n_panels / panels_per_fig)) - - for fig_idx in range(n_figs): - panel_start = fig_idx * panels_per_fig - panel_end = min((fig_idx + 1) * panels_per_fig, n_panels) - panels_in_fig = panel_end - panel_start - - fig_ind, axes_ind = plt.subplots( - panels_in_fig, - 1, - figsize=(10, 2.2 * panels_in_fig), - sharex=True, - sharey=False, - ) - set_figure_title( - fig_ind, - f"{eve_id} {comp_title} individual seismograms fig {fig_idx + 1}", - ) - if panels_in_fig == 1: - axes_ind = [axes_ind] - - for p in range(panels_in_fig): - axp = axes_ind[p] - global_panel = panel_start + p - start_idx = global_panel * n_per - end_idx = min((global_panel + 1) * n_per, n_traces) - subset = all_rows[start_idx:end_idx] - - # Thresholding windows - t_win_start = start_time + (win_start / sample_rate) - t_win_end = start_time + (win_end / sample_rate) - t_explore_start = max(start_time, t_win_start - move_limit_sec) - t_explore_end = min(start_time + (npts / sample_rate), t_win_end + move_limit_sec) - axp.axvline(x=t_win_start, color='y', lw=1.2, alpha=0.9) - axp.axvline(x=t_win_end, color='y', lw=1.2, alpha=0.9) - axp.axvline(x=t_explore_start, color='g', lw=1.2, alpha=0.9) - axp.axvline(x=t_explore_end, color='g', lw=1.2, alpha=0.9) - - for idx_in_subset, (_, station_id, y) in enumerate(subset): - i = (len(subset) - 1) - idx_in_subset - passed_win = station_id in pass_window_ids - trace_color = 'k' if passed_win else 'red' - axp.plot( - t_abs[mask], - y[mask] + i, - color=trace_color, - lw=0.7, - ) - axp.text( - t_abs[mask][0], - i, - station_id, - fontsize=6, - va='center', - ) - - # Reference stack above traces - ref_offset = len(subset) + 1 - stack_ref = data.get('stack_vec', None) - if stack_ref is not None: - axp.plot( - t_abs[mask], - stack_ref[mask] + ref_offset, - color='C3', - lw=1.2, - ) - - axp.set_ylim(-1, len(subset) + 2) - axp.grid(alpha=0.2) - axp.set_ylabel('Trace index') - - axes_ind[-1].set_xlabel('Time since origin (s)') - fig_ind.suptitle( - f"Event {eve_id} {comp_title}: individual seismograms " - f"(20 per panel, fig {fig_idx + 1}/{n_figs})", - fontsize=12, - fontweight='bold', - ) - plt.tight_layout() - - ind_file = save_dir / ( - f"{eve_id}_{comp_title}_individual_seismograms_{align_phase}_fig{fig_idx + 1}.png" - ) - fig_ind.savefig(ind_file, dpi=300, bbox_inches='tight') - print(f"✓ Individual seismograms plot saved to: {ind_file}") + if all(comp in stack_by_comp for comp in comp_order): + z = stack_by_comp['DPZ'] + r = stack_by_comp['R'] + t = stack_by_comp['T'] + env_z = np.abs(hilbert(z)) + env_r = np.abs(hilbert(r)) + env_t = np.abs(hilbert(t)) + env_rms = np.sqrt((env_z ** 2 + env_r ** 2 + env_t ** 2) / 3.0) + std_sec = 1.0 + std_samples = max(1.0, float(sample_rate_env) * std_sec) + win_samples = max(3, int(round(6.0 * std_samples))) + gauss = gaussian(win_samples, std_samples) + gauss = gauss / np.sum(gauss) + env_rms_smooth = np.convolve(env_rms, gauss, mode='same') + log_env = np.log10(np.maximum(env_rms_smooth, 1e-12)) + + fig_env, ax_env = plt.subplots(figsize=(12, 4.5)) + set_figure_title(fig_env, f"{eve_id} 3-comp log10 envelope") + ax_env.plot(t_abs[mask], log_env[mask], color='k', lw=1.5) + ax_env.set_xlim(start_time, end_time) + ax_env.set_xlabel('Time since origin (s)', fontsize=11) + ax_env.set_ylabel('log10 envelope', fontsize=11) + ax_env.set_title( + f'Event {eve_id} - log10 RMS envelope of 3-component stack', + fontsize=12, + fontweight='bold', + ) + ax_env.grid(alpha=0.2) + add_catalog_event_lines(ax_env, origin_env, catalog_local, start_time, end_time) + fig_env.subplots_adjust(bottom=0.28) + + if origin_env is not None: + try: + add_utc_time_axis(ax_env, origin_env) + except Exception as e: + print(f"[WARN] Failed to add UTC time axis (envelope): {e}") + + env_file = save_dir / f"{eve_id}_3comp_log10_envelope_{align_phase}.png" + fig_env.savefig(env_file, dpi=300, bbox_inches='tight') + print(f"✓ Log10 envelope plot saved to: {env_file}") except Exception as e: - print(f"[WARN] Failed to create individual seismograms plots (3 components): {e}") - - # ===================== Station maps: pass r_win (3 components) ===================== - try: - fig_map, axes_map = plt.subplots(1, 3, figsize=(14, 4.5), sharex=True, sharey=True) - set_figure_title(fig_map, f"{eve_id} station pass map (3-comp)") + print(f"[WARN] Failed to create log10 envelope plot: {e}") + + # No R–T zero-diff station list saved. + # ===================== Stack compare plot: all aligned vs r_min-selected ===================== + print("Creating stack comparison plot (all aligned vs r_min-selected)...") + + # Figure layout: 3 rows, 1 column (Z / R / T) — vertical arrangement + fig_cmp, axes_cmp = plt.subplots(3, 1, figsize=(9, 12), sharex=True, sharey=True) + set_figure_title(fig_cmp, f"{eve_id} stack compare") comp_order = ['DPZ', 'R', 'T'] - comp_titles_map = ['Z', 'R', 'T'] - + comp_titles_cmp = ['Z stack', 'R stack', 'T stack'] + utc_tz = timezone.utc + for j, comp_name in enumerate(comp_order): - axm = axes_map[j] + axc = axes_cmp[j] if comp_name not in all_component_data: - axm.set_axis_off() + axc.set_axis_off() continue - + data = all_component_data[comp_name] - station_ll = data.get('station_ll', {}) + t_abs = data['t_abs'] + mask = data['mask'] + start_time = data['start_time'] + end_time = data['end_time'] + p_time = data.get('p_traveltime') + s_time = data.get('s_traveltime') + tr_map = data.get('aligned_traces_by_station', {}) all_stations = sorted(tr_map.keys(), key=lambda s: int(s)) - pass_set = set(data.get('pass_window_ids', [])) - - pass_lats = [station_ll[s][0] for s in all_stations if s in pass_set and s in station_ll] - pass_lons = [station_ll[s][1] for s in all_stations if s in pass_set and s in station_ll] - fail_lats = [station_ll[s][0] for s in all_stations if s not in pass_set and s in station_ll] - fail_lons = [station_ll[s][1] for s in all_stations if s not in pass_set and s in station_ll] - - if len(fail_lons) > 0: - axm.scatter(fail_lons, fail_lats, s=16, c='0.7', label='Fail') - if len(pass_lons) > 0: - axm.scatter(pass_lons, pass_lats, s=20, c='C3', label='Pass') - - axm.set_title(comp_titles_map[j], fontsize=12, fontweight='bold') + + # Black: stack of all aligned traces + stack_black = np.zeros_like(t_abs) + if len(all_stations) > 0: + bank_all = [tr_map[sta] for sta in all_stations] + stack_black = np.mean(np.vstack(bank_all), axis=0) + ms = np.max(np.abs(stack_black)) or 1.0 + stack_black = stack_black / ms + + # Red: stack of traces that pass r_min thresholds (selected_ids) + sel_ids = data.get('selected_ids', []) + sel_ids = [s for s in sel_ids if s in tr_map] + n_pass_window = int(data.get('n_pass_window', len(sel_ids))) + stack_red = stack_black + if len(sel_ids) > 0: + bank_sel = [tr_map[sta] for sta in sel_ids] + stack_red = np.mean(np.vstack(bank_sel), axis=0) + ms = np.max(np.abs(stack_red)) or 1.0 + stack_red = stack_red / ms + + axc.plot(t_abs[mask], stack_black[mask], color='k', lw=2, label='All aligned traces') + axc.plot( + t_abs[mask], + stack_red[mask], + color='r', + lw=2, + label=f'Pass r_win N={n_pass_window}', + ) + axc.axhline(0.0, color='k', lw=0.6, alpha=0.6) + if p_time is not None: + axc.axvline(x=p_time, color='b', lw=1.5, alpha=0.7, linestyle='--', label='P arrival') + if s_time is not None: + axc.axvline(x=s_time, color='g', lw=1.5, alpha=0.7, linestyle='--', label='S arrival') + axc.set_xlim(start_time, end_time) + axc.set_ylim(-1.1, 1.1) + axc.grid(alpha=0.2) + axc.set_title(comp_titles_cmp[j], fontsize=12, fontweight='bold') + axc.set_xlabel('Time since origin (s)', fontsize=11) + if j != 2: + axc.set_xlabel('') if j == 0: - axm.set_ylabel('Latitude') - axm.grid(alpha=0.3) - axm.set_xlabel('Longitude') - axm.legend(loc='upper right', fontsize=8) - - fig_map.suptitle( - f'Event {eve_id} - Stations passing thresholds ({align_phase})', + axc.set_ylabel('Stack (norm.)', fontsize=11) + axc.legend(loc='upper right', fontsize=9) + + if j == 2: + try: + origin_utc = data.get('origin') + if origin_utc is not None: + add_utc_time_axis(axc, origin_utc, tick_tz=utc_tz) + except Exception as e: + print(f"[WARN] Failed to add UTC time axis: {e}") + + fig_cmp.suptitle( + f'Event {eve_id} - Stack compare (black: all aligned; red: pass r_min thresholds)', fontsize=13, fontweight='bold' ) plt.tight_layout() - - map_file = save_dir / f"{eve_id}_station_pass_map_{align_phase}.png" - fig_map.savefig(map_file, dpi=300, bbox_inches='tight') - print(f"✓ Station pass/fail map saved to: {map_file}") - except Exception as e: - print(f"[WARN] Failed to create station pass/fail map (3 components): {e}") - - # ===================== Shift comparison plot: Radial vs Transverse ===================== - if 'R' in all_component_data and 'T' in all_component_data: - print("Creating shift comparison plot (Radial vs Transverse)...") - print( - "Shift comparison parameters: " - f"align_phase={align_phase}, start_time={start_time}, end_time={end_time}, " - f"win_pre={win_pre}, win_post={win_post}, " - f"move_limit_sec={move_limit_sec}" - ) - - r_shifts = all_component_data['R']['station_shifts'] - t_shifts = all_component_data['T']['station_shifts'] - r_corr = all_component_data['R']['station_corr'] - t_corr = all_component_data['T']['station_corr'] - r_calc = all_component_data['R'].get('calc_shifts', {}) - t_calc = all_component_data['T'].get('calc_shifts', {}) + + # Save comparison figure + cmp_file = save_dir / f"{eve_id}_rtfilter_stack_compare_{align_phase}.png" + fig_cmp.savefig(cmp_file, dpi=300, bbox_inches='tight') + print(f"✓ Stack comparison plot saved to: {cmp_file}") + # plt.show() + + # ===================== Individual seismograms (20 traces per subplot, 5 panels per figure, 3 components) ===================== + if show_individual_seismograms: + try: + for comp_name, comp_title in zip(['DPZ', 'R', 'T'], ['Z', 'R', 'T']): + if comp_name not in all_component_data: + continue + + data = all_component_data[comp_name] + all_rows = data.get('all_rows', []) + all_rows = sorted(all_rows, key=lambda t: int(t[1])) + t_abs = data['t_abs'] + mask = data['mask'] + sample_rate = data['sample_rate'] + win_start = data['win_start'] + win_end = data['win_end'] + move_limit_sec = data['move_limit_sec'] + npts = data['npts'] + + n_traces = len(all_rows) + if n_traces == 0: + continue + + n_per = 20 + panels_per_fig = 5 + n_panels = int(np.ceil(n_traces / n_per)) + n_figs = int(np.ceil(n_panels / panels_per_fig)) + + for fig_idx in range(n_figs): + panel_start = fig_idx * panels_per_fig + panel_end = min((fig_idx + 1) * panels_per_fig, n_panels) + panels_in_fig = panel_end - panel_start + + fig_ind, axes_ind = plt.subplots( + panels_in_fig, + 1, + figsize=(10, 2.2 * panels_in_fig), + sharex=True, + sharey=False, + ) + set_figure_title( + fig_ind, + f"{eve_id} {comp_title} individual seismograms fig {fig_idx + 1}", + ) + if panels_in_fig == 1: + axes_ind = [axes_ind] + + for p in range(panels_in_fig): + axp = axes_ind[p] + global_panel = panel_start + p + start_idx = global_panel * n_per + end_idx = min((global_panel + 1) * n_per, n_traces) + subset = all_rows[start_idx:end_idx] + + # Thresholding windows + t_win_start = start_time + (win_start / sample_rate) + t_win_end = start_time + (win_end / sample_rate) + t_explore_start = max(start_time, t_win_start - move_limit_sec) + t_explore_end = min(start_time + (npts / sample_rate), t_win_end + move_limit_sec) + axp.axvline(x=t_win_start, color='y', lw=1.2, alpha=0.9) + axp.axvline(x=t_win_end, color='y', lw=1.2, alpha=0.9) + axp.axvline(x=t_explore_start, color='g', lw=1.2, alpha=0.9) + axp.axvline(x=t_explore_end, color='g', lw=1.2, alpha=0.9) + + for idx_in_subset, (_, station_id, y) in enumerate(subset): + i = (len(subset) - 1) - idx_in_subset + passed_win = station_id in pass_window_ids + trace_color = 'k' if passed_win else 'red' + axp.plot( + t_abs[mask], + y[mask] + i, + color=trace_color, + lw=0.7, + ) + axp.text( + t_abs[mask][0], + i, + station_id, + fontsize=6, + va='center', + ) + + # Reference stack above traces + ref_offset = len(subset) + 1 + stack_ref = data.get('stack_vec', None) + if stack_ref is not None: + axp.plot( + t_abs[mask], + stack_ref[mask] + ref_offset, + color='C3', + lw=1.2, + ) + + axp.set_ylim(-1, len(subset) + 2) + axp.grid(alpha=0.2) + axp.set_ylabel('Trace index') + + axes_ind[-1].set_xlabel('Time since origin (s)') + fig_ind.suptitle( + f"Event {eve_id} {comp_title}: individual seismograms " + f"(20 per panel, fig {fig_idx + 1}/{n_figs})", + fontsize=12, + fontweight='bold', + ) + plt.tight_layout() + + ind_file = save_dir / ( + f"{eve_id}_{comp_title}_individual_seismograms_{align_phase}_fig{fig_idx + 1}.png" + ) + fig_ind.savefig(ind_file, dpi=300, bbox_inches='tight') + print(f"✓ Individual seismograms plot saved to: {ind_file}") + except Exception as e: + print(f"[WARN] Failed to create individual seismograms plots (3 components): {e}") + + # ===================== Station maps: pass r_win (3 components) ===================== + try: + fig_map, axes_map = plt.subplots(1, 3, figsize=(14, 4.5), sharex=True, sharey=True) + set_figure_title(fig_map, f"{eve_id} station pass map (3-comp)") + comp_order = ['DPZ', 'R', 'T'] + comp_titles_map = ['Z', 'R', 'T'] + + for j, comp_name in enumerate(comp_order): + axm = axes_map[j] + if comp_name not in all_component_data: + axm.set_axis_off() + continue + + data = all_component_data[comp_name] + station_ll = data.get('station_ll', {}) + tr_map = data.get('aligned_traces_by_station', {}) + all_stations = sorted(tr_map.keys(), key=lambda s: int(s)) + pass_set = set(data.get('pass_window_ids', [])) + + pass_lats = [station_ll[s][0] for s in all_stations if s in pass_set and s in station_ll] + pass_lons = [station_ll[s][1] for s in all_stations if s in pass_set and s in station_ll] + fail_lats = [station_ll[s][0] for s in all_stations if s not in pass_set and s in station_ll] + fail_lons = [station_ll[s][1] for s in all_stations if s not in pass_set and s in station_ll] + + if len(fail_lons) > 0: + axm.scatter(fail_lons, fail_lats, s=16, c='0.7', label='Fail') + if len(pass_lons) > 0: + axm.scatter(pass_lons, pass_lats, s=20, c='C3', label='Pass') + + axm.set_title(comp_titles_map[j], fontsize=12, fontweight='bold') + if j == 0: + axm.set_ylabel('Latitude') + axm.grid(alpha=0.3) + axm.set_xlabel('Longitude') + axm.legend(loc='upper right', fontsize=8) + + fig_map.suptitle( + f'Event {eve_id} - Stations passing thresholds ({align_phase})', + fontsize=13, + fontweight='bold' + ) + plt.tight_layout() + + map_file = save_dir / f"{eve_id}_station_pass_map_{align_phase}.png" + fig_map.savefig(map_file, dpi=300, bbox_inches='tight') + print(f"✓ Station pass/fail map saved to: {map_file}") + except Exception as e: + print(f"[WARN] Failed to create station pass/fail map (3 components): {e}") + + # ===================== Shift comparison plot: Radial vs Transverse ===================== + if 'R' in all_component_data and 'T' in all_component_data: + print("Creating shift comparison plot (Radial vs Transverse)...") + print( + "Shift comparison parameters: " + f"align_phase={align_phase}, start_time={start_time}, end_time={end_time}, " + f"win_pre={win_pre}, win_post={win_post}, " + f"move_limit_sec={move_limit_sec}" + ) + + r_shifts = all_component_data['R']['station_shifts'] + t_shifts = all_component_data['T']['station_shifts'] + r_corr = all_component_data['R']['station_corr'] + t_corr = all_component_data['T']['station_corr'] + r_calc = all_component_data['R'].get('calc_shifts', {}) + t_calc = all_component_data['T'].get('calc_shifts', {}) + + # Find common stations (require predicted shifts for residual plotting) + common_stations = set(r_shifts.keys()) & set(t_shifts.keys()) + common_stations = common_stations & set(r_calc.keys()) & set(t_calc.keys()) + common_corr_stations = set(r_corr.keys()) & set(t_corr.keys()) + + if len(common_stations) > 0: + # Extract shifts in seconds (remove predicted shift per station) + stations = sorted(common_stations, key=lambda s: int(s)) + r_lags = np.array([r_shifts[sta]['lag_seconds'] - r_calc[sta] for sta in stations], dtype=float) + t_lags = np.array([t_shifts[sta]['lag_seconds'] - t_calc[sta] for sta in stations], dtype=float) + station_nums = np.array([int(sta) for sta in stations], dtype=int) + + pass_r = set(all_component_data['R'].get('pass_window_ids', [])) + pass_t = set(all_component_data['T'].get('pass_window_ids', [])) + pass_mask = np.array([(sta in pass_r) and (sta in pass_t) for sta in stations], dtype=bool) + fail_mask = ~pass_mask + + # Create comparison figure + fig_shift, axes = plt.subplots(2, 3, figsize=(18, 10)) + set_figure_title(fig_shift, f"{eve_id} shift comparison") + (ax1, ax2, ax5), (ax3, ax4, ax6) = axes + + # Panel 1: Scatter plot R vs T (residuals after predicted shift removal) + ax1.scatter(r_lags[pass_mask], t_lags[pass_mask], alpha=0.6, s=20, color='k', label='Pass r_win') + if np.any(fail_mask): + ax1.scatter(r_lags[fail_mask], t_lags[fail_mask], alpha=0.8, s=24, color='red', label='Fail r_win') + ax1.plot([min(r_lags + t_lags), max(r_lags + t_lags)], + [min(r_lags + t_lags), max(r_lags + t_lags)], + 'r--', alpha=0.5, label='1:1 line') + ax1.set_xlabel('Radial residual shift (s)', fontsize=11) + ax1.set_ylabel('Transverse residual shift (s)', fontsize=11) + ax1.set_title('Radial vs Transverse Residuals', fontsize=12, fontweight='bold') + ax1.axvline(-move_limit_sec, color='0.4', linestyle=':', linewidth=1.2) + ax1.axvline(move_limit_sec, color='0.4', linestyle=':', linewidth=1.2) + ax1.axhline(-move_limit_sec, color='0.4', linestyle=':', linewidth=1.2) + ax1.axhline(move_limit_sec, color='0.4', linestyle=':', linewidth=1.2) + ax1.grid(alpha=0.3) + ax1.legend() + ax1.set_aspect('equal', adjustable='box') + + # Panel 2: Difference histogram (residuals) + diff_lags = np.array(r_lags) - np.array(t_lags) + # Fraction of stations with zero R–T shift difference + # (shifts are derived from integer-sample lags; use tiny atol for float safety) + zero_diff_frac = float(np.mean(np.isclose(diff_lags, 0.0, atol=1e-12))) + ax2.hist(diff_lags, bins=30, alpha=0.7, edgecolor='black') + ax2.axvline(0, color='r', linestyle='--', linewidth=2, label='Zero difference') + ax2.axvline(np.median(diff_lags), color='g', linestyle='--', linewidth=2, + label=f'Median = {np.median(diff_lags):.3f}s') + ax2.set_xlabel('R residual - T residual (s)', fontsize=11) + ax2.set_ylabel('Count', fontsize=11) + ax2.set_title('Residual Difference Distribution', fontsize=12, fontweight='bold') + ax2.legend() + ax2.grid(alpha=0.3) + + # Panel 3: Max correlation R vs T + if len(common_corr_stations) > 0: + corr_stations = sorted(common_corr_stations, key=lambda s: int(s)) + r_corr_vals = np.array([r_corr[sta] for sta in corr_stations], dtype=float) + t_corr_vals = np.array([t_corr[sta] for sta in corr_stations], dtype=float) + pass_mask_corr = np.array([(sta in pass_r) and (sta in pass_t) for sta in corr_stations], dtype=bool) + fail_mask_corr = ~pass_mask_corr + ax5.scatter(r_corr_vals[pass_mask_corr], t_corr_vals[pass_mask_corr], alpha=0.6, s=20, color='k', label='Pass r_win') + if np.any(fail_mask_corr): + ax5.scatter(r_corr_vals[fail_mask_corr], t_corr_vals[fail_mask_corr], alpha=0.8, s=24, color='red', label='Fail r_win') + ax5.plot([0, 1], [0, 1], 'r--', alpha=0.5, label='1:1 line') + ax5.set_xlabel('Radial max corr', fontsize=11) + ax5.set_ylabel('Transverse max corr', fontsize=11) + ax5.set_title('Max Correlation: R vs T', fontsize=12, fontweight='bold') + ax5.grid(alpha=0.3) + ax5.legend() + ax5.set_aspect('equal', adjustable='box') + else: + ax5.text(0.5, 0.5, 'No common corr stations', ha='center', va='center') + ax5.set_axis_off() + + # Panel 3: Residual shifts vs station number + ax3.plot(station_nums, r_lags, 'o-', label='Radial', alpha=0.5, markersize=4, color='0.4') + ax3.plot(station_nums, t_lags, 's-', label='Transverse', alpha=0.5, markersize=4, color='0.4') + if np.any(fail_mask): + ax3.scatter(station_nums[fail_mask], r_lags[fail_mask], color='red', s=24, marker='o', label='Fail r_win') + ax3.scatter(station_nums[fail_mask], t_lags[fail_mask], color='red', s=24, marker='s') + ax3.set_xlabel('Station number', fontsize=11) + ax3.set_ylabel('Residual shift (s)', fontsize=11) + ax3.set_title('Residuals vs Station', fontsize=12, fontweight='bold') + ax3.legend() + ax3.grid(alpha=0.3) + + # Panel 4: Statistics + ax4.axis('off') + stats_text = f"""Shift Comparison Statistics + + Number of stations: {len(common_stations)} + Radial shifts: + Mean: {np.mean(r_lags):.4f} s + Std: {np.std(r_lags):.4f} s + Range: [{np.min(r_lags):.4f}, {np.max(r_lags):.4f}] s + + Transverse shifts: + Mean: {np.mean(t_lags):.4f} s + Std: {np.std(t_lags):.4f} s + Range: [{np.min(t_lags):.4f}, {np.max(t_lags):.4f}] s + + Difference (R - T): + Mean: {np.mean(diff_lags):.4f} s + Median: {np.median(diff_lags):.4f} s + Std: {np.std(diff_lags):.4f} s + Zero-difference fraction: {zero_diff_frac*100:.1f}% - # Find common stations (require predicted shifts for residual plotting) - common_stations = set(r_shifts.keys()) & set(t_shifts.keys()) - common_stations = common_stations & set(r_calc.keys()) & set(t_calc.keys()) - common_corr_stations = set(r_corr.keys()) & set(t_corr.keys()) - - if len(common_stations) > 0: - # Extract shifts in seconds (remove predicted shift per station) - stations = sorted(common_stations, key=lambda s: int(s)) - r_lags = np.array([r_shifts[sta]['lag_seconds'] - r_calc[sta] for sta in stations], dtype=float) - t_lags = np.array([t_shifts[sta]['lag_seconds'] - t_calc[sta] for sta in stations], dtype=float) - station_nums = np.array([int(sta) for sta in stations], dtype=int) - - pass_r = set(all_component_data['R'].get('pass_window_ids', [])) - pass_t = set(all_component_data['T'].get('pass_window_ids', [])) - pass_mask = np.array([(sta in pass_r) and (sta in pass_t) for sta in stations], dtype=bool) - fail_mask = ~pass_mask + Frequency content (bandpass): + {min_freq:.2f}–{max_freq:.2f} Hz""" + stats_text += ( + f"\n\nParameters:\n" + f" align_phase: {align_phase}\n" + f" start_time: {start_time}\n" + f" end_time: {end_time}\n" + f" win_pre: {win_pre}\n" + f" win_post: {win_post}\n" + f" move_limit_sec: {move_limit_sec}" + ) + ax4.text(0.1, 0.5, stats_text, fontsize=10, family='monospace', + verticalalignment='center') + + # Panel 6: leave blank + ax6.axis('off') - # Create comparison figure - fig_shift, axes = plt.subplots(2, 3, figsize=(18, 10)) - set_figure_title(fig_shift, f"{eve_id} shift comparison") - (ax1, ax2, ax5), (ax3, ax4, ax6) = axes - - # Panel 1: Scatter plot R vs T (residuals after predicted shift removal) - ax1.scatter(r_lags[pass_mask], t_lags[pass_mask], alpha=0.6, s=20, color='k', label='Pass r_win') - if np.any(fail_mask): - ax1.scatter(r_lags[fail_mask], t_lags[fail_mask], alpha=0.8, s=24, color='red', label='Fail r_win') - ax1.plot([min(r_lags + t_lags), max(r_lags + t_lags)], - [min(r_lags + t_lags), max(r_lags + t_lags)], - 'r--', alpha=0.5, label='1:1 line') - ax1.set_xlabel('Radial residual shift (s)', fontsize=11) - ax1.set_ylabel('Transverse residual shift (s)', fontsize=11) - ax1.set_title('Radial vs Transverse Residuals', fontsize=12, fontweight='bold') - ax1.axvline(-move_limit_sec, color='0.4', linestyle=':', linewidth=1.2) - ax1.axvline(move_limit_sec, color='0.4', linestyle=':', linewidth=1.2) - ax1.axhline(-move_limit_sec, color='0.4', linestyle=':', linewidth=1.2) - ax1.axhline(move_limit_sec, color='0.4', linestyle=':', linewidth=1.2) - ax1.grid(alpha=0.3) - ax1.legend() - ax1.set_aspect('equal', adjustable='box') + fig_shift.suptitle(f'Event {eve_id} - Shift & Correlation Comparison', + fontsize=14, fontweight='bold') + plt.tight_layout() - # Panel 2: Difference histogram (residuals) - diff_lags = np.array(r_lags) - np.array(t_lags) - # Fraction of stations with zero R–T shift difference - # (shifts are derived from integer-sample lags; use tiny atol for float safety) - zero_diff_frac = float(np.mean(np.isclose(diff_lags, 0.0, atol=1e-12))) - ax2.hist(diff_lags, bins=30, alpha=0.7, edgecolor='black') - ax2.axvline(0, color='r', linestyle='--', linewidth=2, label='Zero difference') - ax2.axvline(np.median(diff_lags), color='g', linestyle='--', linewidth=2, - label=f'Median = {np.median(diff_lags):.3f}s') - ax2.set_xlabel('R residual - T residual (s)', fontsize=11) - ax2.set_ylabel('Count', fontsize=11) - ax2.set_title('Residual Difference Distribution', fontsize=12, fontweight='bold') - ax2.legend() - ax2.grid(alpha=0.3) - - # Panel 3: Max correlation R vs T - if len(common_corr_stations) > 0: - corr_stations = sorted(common_corr_stations, key=lambda s: int(s)) - r_corr_vals = np.array([r_corr[sta] for sta in corr_stations], dtype=float) - t_corr_vals = np.array([t_corr[sta] for sta in corr_stations], dtype=float) - pass_mask_corr = np.array([(sta in pass_r) and (sta in pass_t) for sta in corr_stations], dtype=bool) - fail_mask_corr = ~pass_mask_corr - ax5.scatter(r_corr_vals[pass_mask_corr], t_corr_vals[pass_mask_corr], alpha=0.6, s=20, color='k', label='Pass r_win') - if np.any(fail_mask_corr): - ax5.scatter(r_corr_vals[fail_mask_corr], t_corr_vals[fail_mask_corr], alpha=0.8, s=24, color='red', label='Fail r_win') - ax5.plot([0, 1], [0, 1], 'r--', alpha=0.5, label='1:1 line') - ax5.set_xlabel('Radial max corr', fontsize=11) - ax5.set_ylabel('Transverse max corr', fontsize=11) - ax5.set_title('Max Correlation: R vs T', fontsize=12, fontweight='bold') - ax5.grid(alpha=0.3) - ax5.legend() - ax5.set_aspect('equal', adjustable='box') + # Save shift comparison plot + shift_file = save_dir / f"{eve_id}_shift_comparison_{align_phase}.png" + fig_shift.savefig(shift_file, dpi=300, bbox_inches='tight') + print(f"✓ Shift comparison plot saved to: {shift_file}") else: - ax5.text(0.5, 0.5, 'No common corr stations', ha='center', va='center') - ax5.set_axis_off() - - # Panel 3: Residual shifts vs station number - ax3.plot(station_nums, r_lags, 'o-', label='Radial', alpha=0.5, markersize=4, color='0.4') - ax3.plot(station_nums, t_lags, 's-', label='Transverse', alpha=0.5, markersize=4, color='0.4') + print("Warning: No common stations found between R and T components") + + # ===================== Estimated vs calculated shift plot (3 components) ===================== + print("Creating estimated vs calculated shift plot (3 components)...") + + fig_ec, axes_ec = plt.subplots(1, 3, figsize=(15, 4.2), sharex=True, sharey=True) + set_figure_title(fig_ec, f"{eve_id} est vs calc shifts (3-comp)") + comp_order = ['DPZ', 'R', 'T'] + comp_titles_ec = ['Z', 'R', 'T'] + + for j, comp_name in enumerate(comp_order): + axc = axes_ec[j] + if comp_name not in all_component_data: + axc.set_axis_off() + continue + + data = all_component_data[comp_name] + station_shifts = data.get('station_shifts', {}) + calc_shifts = data.get('calc_shifts', {}) + pass_set = set(data.get('pass_window_ids', [])) + + common_sta = set(calc_shifts.keys()) & set(station_shifts.keys()) + if len(common_sta) == 0: + axc.text(0.5, 0.5, 'No common stations', ha='center', va='center') + axc.set_axis_off() + continue + + stations = sorted(common_sta, key=lambda s: int(s)) + est_shift = np.array([station_shifts[s]['lag_seconds'] for s in stations], dtype=float) + calc_shift = np.array([calc_shifts[s] for s in stations], dtype=float) + pass_mask = np.array([s in pass_set for s in stations], dtype=bool) + fail_mask = ~pass_mask + + axc.scatter(calc_shift[pass_mask], est_shift[pass_mask], s=18, alpha=0.6, color='k', label='Pass r_win') if np.any(fail_mask): - ax3.scatter(station_nums[fail_mask], r_lags[fail_mask], color='red', s=24, marker='o', label='Fail r_win') - ax3.scatter(station_nums[fail_mask], t_lags[fail_mask], color='red', s=24, marker='s') - ax3.set_xlabel('Station number', fontsize=11) - ax3.set_ylabel('Residual shift (s)', fontsize=11) - ax3.set_title('Residuals vs Station', fontsize=12, fontweight='bold') - ax3.legend() - ax3.grid(alpha=0.3) - - # Panel 4: Statistics - ax4.axis('off') - stats_text = f"""Shift Comparison Statistics - - Number of stations: {len(common_stations)} - Radial shifts: - Mean: {np.mean(r_lags):.4f} s - Std: {np.std(r_lags):.4f} s - Range: [{np.min(r_lags):.4f}, {np.max(r_lags):.4f}] s - - Transverse shifts: - Mean: {np.mean(t_lags):.4f} s - Std: {np.std(t_lags):.4f} s - Range: [{np.min(t_lags):.4f}, {np.max(t_lags):.4f}] s - - Difference (R - T): - Mean: {np.mean(diff_lags):.4f} s - Median: {np.median(diff_lags):.4f} s - Std: {np.std(diff_lags):.4f} s - Zero-difference fraction: {zero_diff_frac*100:.1f}% - - Frequency content (bandpass): - {min_freq:.2f}–{max_freq:.2f} Hz""" - stats_text += ( - f"\n\nParameters:\n" - f" align_phase: {align_phase}\n" - f" start_time: {start_time}\n" - f" end_time: {end_time}\n" - f" win_pre: {win_pre}\n" - f" win_post: {win_post}\n" - f" move_limit_sec: {move_limit_sec}" + axc.scatter(calc_shift[fail_mask], est_shift[fail_mask], s=22, alpha=0.8, color='red', label='Fail r_win') + + minv = float(min(np.min(calc_shift), np.min(est_shift))) + maxv = float(max(np.max(calc_shift), np.max(est_shift))) + axc.plot([minv, maxv], [minv, maxv], 'r--', lw=1.2, alpha=0.7) + axc.plot( + [minv, maxv], + [minv + move_limit_sec, maxv + move_limit_sec], + color='0.4', + linestyle=':', + lw=1.2, + ) + axc.plot( + [minv, maxv], + [minv - move_limit_sec, maxv - move_limit_sec], + color='0.4', + linestyle=':', + lw=1.2, + ) + + axc.set_title(comp_titles_ec[j], fontsize=12, fontweight='bold') + axc.grid(alpha=0.3) + axc.set_xlabel('Calculated shift (s)', fontsize=10) + if j == 0: + axc.set_ylabel('Estimated shift (s)', fontsize=10) + axc.legend(loc='upper left', fontsize=8) + + fig_ec.suptitle( + f'Event {eve_id} - Estimated vs Calculated shifts ({align_phase})', + fontsize=13, + fontweight='bold' ) - ax4.text(0.1, 0.5, stats_text, fontsize=10, family='monospace', - verticalalignment='center') - - # Panel 6: leave blank - ax6.axis('off') - - fig_shift.suptitle(f'Event {eve_id} - Shift & Correlation Comparison', - fontsize=14, fontweight='bold') plt.tight_layout() - - # Save shift comparison plot - shift_file = save_dir / f"{eve_id}_shift_comparison_{align_phase}.png" - fig_shift.savefig(shift_file, dpi=300, bbox_inches='tight') - print(f"✓ Shift comparison plot saved to: {shift_file}") - else: - print("Warning: No common stations found between R and T components") - - # ===================== Estimated vs calculated shift plot (3 components) ===================== - print("Creating estimated vs calculated shift plot (3 components)...") - - fig_ec, axes_ec = plt.subplots(1, 3, figsize=(15, 4.2), sharex=True, sharey=True) - set_figure_title(fig_ec, f"{eve_id} est vs calc shifts (3-comp)") - comp_order = ['DPZ', 'R', 'T'] - comp_titles_ec = ['Z', 'R', 'T'] - - for j, comp_name in enumerate(comp_order): - axc = axes_ec[j] - if comp_name not in all_component_data: - axc.set_axis_off() - continue - - data = all_component_data[comp_name] - station_shifts = data.get('station_shifts', {}) - calc_shifts = data.get('calc_shifts', {}) - pass_set = set(data.get('pass_window_ids', [])) - - common_sta = set(calc_shifts.keys()) & set(station_shifts.keys()) - if len(common_sta) == 0: - axc.text(0.5, 0.5, 'No common stations', ha='center', va='center') - axc.set_axis_off() - continue - - stations = sorted(common_sta, key=lambda s: int(s)) - est_shift = np.array([station_shifts[s]['lag_seconds'] for s in stations], dtype=float) - calc_shift = np.array([calc_shifts[s] for s in stations], dtype=float) - pass_mask = np.array([s in pass_set for s in stations], dtype=bool) - fail_mask = ~pass_mask - - axc.scatter(calc_shift[pass_mask], est_shift[pass_mask], s=18, alpha=0.6, color='k', label='Pass r_win') - if np.any(fail_mask): - axc.scatter(calc_shift[fail_mask], est_shift[fail_mask], s=22, alpha=0.8, color='red', label='Fail r_win') - - minv = float(min(np.min(calc_shift), np.min(est_shift))) - maxv = float(max(np.max(calc_shift), np.max(est_shift))) - axc.plot([minv, maxv], [minv, maxv], 'r--', lw=1.2, alpha=0.7) - axc.plot( - [minv, maxv], - [minv + move_limit_sec, maxv + move_limit_sec], - color='0.4', - linestyle=':', - lw=1.2, - ) - axc.plot( - [minv, maxv], - [minv - move_limit_sec, maxv - move_limit_sec], - color='0.4', - linestyle=':', - lw=1.2, - ) - - axc.set_title(comp_titles_ec[j], fontsize=12, fontweight='bold') - axc.grid(alpha=0.3) - axc.set_xlabel('Calculated shift (s)', fontsize=10) - if j == 0: - axc.set_ylabel('Estimated shift (s)', fontsize=10) - axc.legend(loc='upper left', fontsize=8) - - fig_ec.suptitle( - f'Event {eve_id} - Estimated vs Calculated shifts ({align_phase})', - fontsize=13, - fontweight='bold' - ) - plt.tight_layout() - - estcalc_file = save_dir / f"{eve_id}_est_vs_calc_shift_{align_phase}.png" - fig_ec.savefig(estcalc_file, dpi=300, bbox_inches='tight') - print(f"✓ Estimated vs calculated shift plot saved to: {estcalc_file}") - - add_stage_timing("plot_three_component", _plot3_wall_start, _plot3_cpu_start) + + estcalc_file = save_dir / f"{eve_id}_est_vs_calc_shift_{align_phase}.png" + fig_ec.savefig(estcalc_file, dpi=300, bbox_inches='tight') + print(f"✓ Estimated vs calculated shift plot saved to: {estcalc_file}") + + add_stage_timing("plot_three_component", _plot3_wall_start, _plot3_cpu_start) + + # Show all figures together (three-component + shift comparison) + # plt.show() + + # ===================== Show all figures together at the end ===================== + report_timing_once() + print("\a\a\a") + plt.show() - # Show all figures together (three-component + shift comparison) - # plt.show() +def main() -> None: + run_pipeline() - # ===================== Show all figures together at the end ===================== - report_timing_once() - print("\a\a\a") - plt.show() +if __name__ == "__main__": + main() From bea5675615c74eea5f2cb64e13232b493a6aa3a3 Mon Sep 17 00:00:00 2001 From: John Emilio Vidale Date: Sat, 30 May 2026 16:11:05 -0700 Subject: [PATCH 2/9] Extract core alignment utilities into align_utils module --- align_stack.py | 160 ++++--------------------------------------------- align_utils.py | 138 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 147 deletions(-) create mode 100644 align_utils.py diff --git a/align_stack.py b/align_stack.py index dae11d4..09c3af2 100644 --- a/align_stack.py +++ b/align_stack.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt from matplotlib.lines import Line2D from pathlib import Path -from datetime import timedelta, timezone +from datetime import timezone import time from obspy import read, UTCDateTime, Stream, Trace @@ -13,6 +13,18 @@ from scipy.signal import hilbert from scipy.signal.windows import gaussian +from align_utils import ( + add_catalog_event_lines, + add_utc_time_axis, + compute_lag, + correlation_time_bounds, + draw_correlation_markers, + ensure_utc_datetime, + get_component_selection, + set_figure_title, + shift_left_zeropad, +) + min_freq, max_freq = 3.0, 10.0 # Bandpass filter (Hz) start_time, end_time = -10.0, 20 # Plotting time window (seconds since origin) # start_time, end_time = -1990.0, 3690.0 # Plotting time window (seconds since origin) @@ -95,64 +107,6 @@ def report_stage_timing() -> None: model = TauPyModel(model="iasp91") # ===================== Helper functions ===================== -def compute_lag( - ref: np.ndarray, - d: np.ndarray, - win_start: int, - win_end: int, - move_limit_samples: int, -) -> int: - """Compute integer lag (samples) by maximizing correlation within a short window. - - This matches the original implementation: - - ref_window = ref[win_start:win_end] - - d_window = d[win_start-move_limit_samples : win_end+move_limit_samples] - - corr = np.correlate(d_window, ref_window, mode='valid') - - lag = argmax(corr) - move_limit_samples - - Returns: - Best lag (integer samples). Positive lag advances the target waveform. - """ - ref_window = ref[win_start:win_end] - d_window = d[win_start - move_limit_samples : win_end + move_limit_samples] - corr = np.correlate(d_window, ref_window, mode="valid") - return int(np.argmax(corr) - move_limit_samples) - - -def shift_left_zeropad(x: np.ndarray, n: int) -> np.ndarray: - """Shift 1D array left by n samples with zero padding (no wrap-around). - - Equivalent to np.roll(x, -n) but WITHOUT circular wrap. - - n > 0: advance in time - - n < 0: delay - """ - x = np.asarray(x) - y = np.zeros_like(x) - - if n == 0: - y[:] = x - return y - - if n > 0: - if n >= x.size: - return y - y[:-n] = x[n:] - return y - - # n < 0 - n = -n - if n >= x.size: - return y - y[n:] = x[:-n] - return y - - -def set_figure_title(fig, title: str) -> None: - """Set a descriptive window title if the backend supports it.""" - try: - fig.canvas.manager.set_window_title(title) - except Exception: - pass def report_timing_once() -> None: @@ -167,94 +121,6 @@ def report_timing_once() -> None: _timing_reported = True -def add_catalog_event_lines(ax, origin_time, catalog_df, tmin, tmax) -> None: - """Draw vertical lines for each catalog event time on a time-since-origin axis.""" - if origin_time is None or catalog_df is None: - return - if "origin_time" not in catalog_df.columns: - print("[WARN] Catalog missing 'origin_time' column; no event lines drawn.") - return - - color_map = {0: "red", 1: "black", 2: "green"} - for _, row in catalog_df.iterrows(): - try: - evt_time = UTCDateTime(str(row["origin_time"])) - except Exception: - continue - dt = float(evt_time - origin_time) - if dt < tmin or dt > tmax: - continue - skip_val = row.get("skip", 0) - try: - skip_int = int(skip_val) - except Exception: - skip_int = 0 - color = color_map.get(skip_int, "red") - ax.axvline(x=dt, color=color, lw=1.1, alpha=0.8, zorder=6) - - -def get_component_selection(all_channels_mode: bool, comp: str): - """Return (channels, process_as_three_comp, selected_components).""" - if all_channels_mode: - return ["DPZ", "DP1", "DP2"], True, ["Z", "R", "T"] - if comp == "Z": - return ["DPZ"], False, ["Z"] - if comp in ("R", "T"): - # Single-component R/T still reads both horizontals for rotation. - return ["DP1"], False, [comp] - raise ValueError("component must be 'Z', 'R', or 'T'") - - -def ensure_utc_datetime(dt_obj): - """Return a timezone-aware UTC datetime for printing/labeling.""" - if dt_obj.tzinfo is None: - return dt_obj.replace(tzinfo=timezone.utc) - return dt_obj.astimezone(timezone.utc) - - -def correlation_time_bounds(start_t, win_start_samp, win_end_samp, samp_rate, move_sec, npts): - """Compute correlation window and search bounds in seconds since origin.""" - t_win_start = start_t + (win_start_samp / samp_rate) - t_win_end = start_t + (win_end_samp / samp_rate) - t_explore_start = max(start_t, t_win_start - move_sec) - t_explore_end = min(start_t + (npts / samp_rate), t_win_end + move_sec) - return t_win_start, t_win_end, t_explore_start, t_explore_end - - -def draw_correlation_markers(ax, start_t, win_start_samp, win_end_samp, samp_rate, move_sec, npts): - """Draw yellow (window) and green (search) vertical bounds on one axis.""" - t_win_start, t_win_end, t_explore_start, t_explore_end = correlation_time_bounds( - start_t, win_start_samp, win_end_samp, samp_rate, move_sec, npts - ) - ax.axvline(x=t_win_start, color="y", lw=2, alpha=0.9, zorder=7) - ax.axvline(x=t_win_end, color="y", lw=2, alpha=0.9, zorder=7) - ax.axvline(x=t_explore_start, color="g", lw=2, alpha=0.9, zorder=7) - ax.axvline(x=t_explore_end, color="g", lw=2, alpha=0.9, zorder=7) - - -def add_utc_time_axis(ax, origin_time, tick_tz=timezone.utc, label_size: int = 10) -> None: - """Add a bottom UTC axis that mirrors the primary x-axis ticks.""" - if origin_time is None: - return - origin_dt_utc = ensure_utc_datetime(origin_time.datetime) - ax_time = ax.twiny() - ax_time.set_xlim(ax.get_xlim()) - ax_time.xaxis.set_label_position("bottom") - ax_time.xaxis.set_ticks_position("bottom") - ax_time.spines["bottom"].set_position(("outward", 36)) - ax_time.spines["top"].set_visible(False) - - ticks = ax.get_xticks() - labels = [ - (origin_dt_utc + timedelta(seconds=float(t))).astimezone(tick_tz).strftime("%H:%M:%S") - for t in ticks - ] - ax_time.set_xticks(ticks) - ax_time.set_xticklabels(labels) - date_str = origin_dt_utc.date().isoformat() - ax_time.set_xlabel(f"UTC time ({date_str})", fontsize=label_size) - - def plot_stage_stacks( eve_id: str, plot_comp: str, diff --git a/align_utils.py b/align_utils.py new file mode 100644 index 0000000..b943f82 --- /dev/null +++ b/align_utils.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from datetime import timedelta, timezone + +import numpy as np +from obspy import UTCDateTime + + +def compute_lag( + ref: np.ndarray, + d: np.ndarray, + win_start: int, + win_end: int, + move_limit_samples: int, +) -> int: + """Compute integer lag (samples) by maximizing correlation within a short window.""" + ref_window = ref[win_start:win_end] + d_window = d[win_start - move_limit_samples : win_end + move_limit_samples] + corr = np.correlate(d_window, ref_window, mode="valid") + return int(np.argmax(corr) - move_limit_samples) + + +def shift_left_zeropad(x: np.ndarray, n: int) -> np.ndarray: + """Shift 1D array left by n samples with zero padding (no wrap-around).""" + x = np.asarray(x) + y = np.zeros_like(x) + + if n == 0: + y[:] = x + return y + + if n > 0: + if n >= x.size: + return y + y[:-n] = x[n:] + return y + + n = -n + if n >= x.size: + return y + y[n:] = x[:-n] + return y + + +def ensure_utc_datetime(dt_obj): + """Return a timezone-aware UTC datetime for printing/labeling.""" + if dt_obj.tzinfo is None: + return dt_obj.replace(tzinfo=timezone.utc) + return dt_obj.astimezone(timezone.utc) + + +def correlation_time_bounds(start_t, win_start_samp, win_end_samp, samp_rate, move_sec, npts): + """Compute correlation window and search bounds in seconds since origin.""" + t_win_start = start_t + (win_start_samp / samp_rate) + t_win_end = start_t + (win_end_samp / samp_rate) + t_explore_start = max(start_t, t_win_start - move_sec) + t_explore_end = min(start_t + (npts / samp_rate), t_win_end + move_sec) + return t_win_start, t_win_end, t_explore_start, t_explore_end + + +def draw_correlation_markers(ax, start_t, win_start_samp, win_end_samp, samp_rate, move_sec, npts): + """Draw yellow (window) and green (search) vertical bounds on one axis.""" + t_win_start, t_win_end, t_explore_start, t_explore_end = correlation_time_bounds( + start_t, win_start_samp, win_end_samp, samp_rate, move_sec, npts + ) + ax.axvline(x=t_win_start, color="y", lw=2, alpha=0.9, zorder=7) + ax.axvline(x=t_win_end, color="y", lw=2, alpha=0.9, zorder=7) + ax.axvline(x=t_explore_start, color="g", lw=2, alpha=0.9, zorder=7) + ax.axvline(x=t_explore_end, color="g", lw=2, alpha=0.9, zorder=7) + + +def set_figure_title(fig, title: str) -> None: + """Set a descriptive window title if the backend supports it.""" + try: + fig.canvas.manager.set_window_title(title) + except Exception: + pass + + +def get_component_selection(all_channels_mode: bool, comp: str): + """Return (channels, process_as_three_comp, selected_components).""" + if all_channels_mode: + return ["DPZ", "DP1", "DP2"], True, ["Z", "R", "T"] + if comp == "Z": + return ["DPZ"], False, ["Z"] + if comp in ("R", "T"): + # Single-component R/T still reads both horizontals for rotation. + return ["DP1"], False, [comp] + raise ValueError("component must be 'Z', 'R', or 'T'") + + +def add_catalog_event_lines(ax, origin_time, catalog_df, tmin, tmax) -> None: + """Draw vertical lines for each catalog event time on a time-since-origin axis.""" + if origin_time is None or catalog_df is None: + return + if "origin_time" not in catalog_df.columns: + print("[WARN] Catalog missing 'origin_time' column; no event lines drawn.") + return + + color_map = {0: "red", 1: "black", 2: "green"} + for _, row in catalog_df.iterrows(): + try: + evt_time = UTCDateTime(str(row["origin_time"])) + except Exception: + continue + dt = float(evt_time - origin_time) + if dt < tmin or dt > tmax: + continue + skip_val = row.get("skip", 0) + try: + skip_int = int(skip_val) + except Exception: + skip_int = 0 + color = color_map.get(skip_int, "red") + ax.axvline(x=dt, color=color, lw=1.1, alpha=0.8, zorder=6) + + +def add_utc_time_axis(ax, origin_time, tick_tz=timezone.utc, label_size: int = 10) -> None: + """Add a bottom UTC axis that mirrors the primary x-axis ticks.""" + if origin_time is None: + return + origin_dt_utc = ensure_utc_datetime(origin_time.datetime) + ax_time = ax.twiny() + ax_time.set_xlim(ax.get_xlim()) + ax_time.xaxis.set_label_position("bottom") + ax_time.xaxis.set_ticks_position("bottom") + ax_time.spines["bottom"].set_position(("outward", 36)) + ax_time.spines["top"].set_visible(False) + + ticks = ax.get_xticks() + labels = [ + (origin_dt_utc + timedelta(seconds=float(t))).astimezone(tick_tz).strftime("%H:%M:%S") + for t in ticks + ] + ax_time.set_xticks(ticks) + ax_time.set_xticklabels(labels) + date_str = origin_dt_utc.date().isoformat() + ax_time.set_xlabel(f"UTC time ({date_str})", fontsize=label_size) From 9be3b881cac0bbf5efac3d25a94cb6cae2f7e5e3 Mon Sep 17 00:00:00 2001 From: John Emilio Vidale Date: Sat, 30 May 2026 16:18:04 -0700 Subject: [PATCH 3/9] Add output regression comparison script --- scripts/compare_outputs.py | 205 +++++++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 scripts/compare_outputs.py diff --git a/scripts/compare_outputs.py b/scripts/compare_outputs.py new file mode 100644 index 0000000..7e50e8e --- /dev/null +++ b/scripts/compare_outputs.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +"""Compare two output directories for regression checks. + +Checks: +- File presence by relative path +- PNG file size deltas +- MSEED basic stats (trace count, npts, sampling_rate, max_abs) + +Exit code is 0 when no differences are found, else 1. +""" + +from __future__ import annotations + +import argparse +import hashlib +import math +from dataclasses import dataclass +from pathlib import Path + +from obspy import read + + +@dataclass +class MseedStats: + trace_count: int + npts_total: int + sampling_rates: tuple[float, ...] + max_abs: float + + +@dataclass +class ComparisonResult: + ok: bool + lines: list[str] + + +def sha256_file(path: Path, chunk_size: int = 1 << 20) -> str: + digest = hashlib.sha256() + with path.open("rb") as f: + while True: + chunk = f.read(chunk_size) + if not chunk: + break + digest.update(chunk) + return digest.hexdigest() + + +def list_files(root: Path) -> dict[str, Path]: + files: dict[str, Path] = {} + for p in root.rglob("*"): + if p.is_file(): + files[str(p.relative_to(root))] = p + return files + + +def mseed_stats(path: Path) -> MseedStats: + st = read(str(path)) + sampling_rates = tuple(float(tr.stats.sampling_rate) for tr in st) + npts_total = int(sum(int(tr.stats.npts) for tr in st)) + max_abs = 0.0 + for tr in st: + if tr.data.size == 0: + continue + candidate = float(abs(tr.data).max()) + if candidate > max_abs: + max_abs = candidate + return MseedStats( + trace_count=len(st), + npts_total=npts_total, + sampling_rates=sampling_rates, + max_abs=max_abs, + ) + + +def compare_dirs( + baseline: Path, + candidate: Path, + png_size_tol_pct: float, + mseed_amp_tol: float, + check_hash: bool, +) -> ComparisonResult: + lines: list[str] = [] + ok = True + + base_files = list_files(baseline) + cand_files = list_files(candidate) + + base_set = set(base_files.keys()) + cand_set = set(cand_files.keys()) + + missing = sorted(base_set - cand_set) + extra = sorted(cand_set - base_set) + + if missing: + ok = False + lines.append("Missing files in candidate:") + lines.extend(f" - {m}" for m in missing) + if extra: + ok = False + lines.append("Extra files in candidate:") + lines.extend(f" - {e}" for e in extra) + + common = sorted(base_set & cand_set) + + for rel in common: + b = base_files[rel] + c = cand_files[rel] + + if rel.lower().endswith(".png"): + b_size = b.stat().st_size + c_size = c.stat().st_size + denom = max(1, b_size) + pct = 100.0 * abs(c_size - b_size) / denom + if pct > png_size_tol_pct: + ok = False + lines.append( + f"PNG size delta too large: {rel} baseline={b_size} candidate={c_size} delta={pct:.2f}%" + ) + if check_hash and sha256_file(b) != sha256_file(c): + lines.append(f"PNG hash changed: {rel}") + + elif rel.lower().endswith(".mseed"): + bs = mseed_stats(b) + cs = mseed_stats(c) + + if bs.trace_count != cs.trace_count: + ok = False + lines.append( + f"MSEED trace count changed: {rel} baseline={bs.trace_count} candidate={cs.trace_count}" + ) + if bs.npts_total != cs.npts_total: + ok = False + lines.append( + f"MSEED npts_total changed: {rel} baseline={bs.npts_total} candidate={cs.npts_total}" + ) + if bs.sampling_rates != cs.sampling_rates: + ok = False + lines.append( + f"MSEED sampling_rates changed: {rel} baseline={bs.sampling_rates} candidate={cs.sampling_rates}" + ) + + if not math.isclose(bs.max_abs, cs.max_abs, rel_tol=0.0, abs_tol=mseed_amp_tol): + ok = False + lines.append( + f"MSEED max_abs changed: {rel} baseline={bs.max_abs:.6g} candidate={cs.max_abs:.6g}" + ) + + if ok: + lines.append("No regressions detected under current comparison rules.") + + return ComparisonResult(ok=ok, lines=lines) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("baseline", type=Path, help="Baseline output directory") + parser.add_argument("candidate", type=Path, help="Candidate output directory") + parser.add_argument( + "--png-size-tol-pct", + type=float, + default=2.0, + help="Allowed PNG size delta percentage (default: 2.0)", + ) + parser.add_argument( + "--mseed-amp-abs-tol", + type=float, + default=1e-6, + help="Allowed absolute tolerance for MSEED max amplitude (default: 1e-6)", + ) + parser.add_argument( + "--check-hash", + action="store_true", + help="Also report content hash differences for PNGs", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + baseline = args.baseline.resolve() + candidate = args.candidate.resolve() + + if not baseline.is_dir(): + print(f"Baseline directory not found: {baseline}") + return 2 + if not candidate.is_dir(): + print(f"Candidate directory not found: {candidate}") + return 2 + + result = compare_dirs( + baseline=baseline, + candidate=candidate, + png_size_tol_pct=float(args.png_size_tol_pct), + mseed_amp_tol=float(args.mseed_amp_abs_tol), + check_hash=bool(args.check_hash), + ) + + for line in result.lines: + print(line) + + return 0 if result.ok else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) From ea521f5df3e53e927dab3073bd01cdc12be39e45 Mon Sep 17 00:00:00 2001 From: John Emilio Vidale Date: Sat, 30 May 2026 16:25:21 -0700 Subject: [PATCH 4/9] Ignore Python bytecode and cache files --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4d1db3d --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +# Python bytecode/cache +__pycache__/ +*.py[cod] +*$py.class + +# macOS +.DS_Store From 2ce2d786808376d8e7619214c46cb3b1406fb96e Mon Sep 17 00:00:00 2001 From: John Emilio Vidale Date: Sat, 30 May 2026 16:27:12 -0700 Subject: [PATCH 5/9] Extract timing helpers into align_utils TimingState --- align_stack.py | 69 +++++++++++--------------------------------------- align_utils.py | 48 +++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 54 deletions(-) diff --git a/align_stack.py b/align_stack.py index 09c3af2..9a21ff0 100644 --- a/align_stack.py +++ b/align_stack.py @@ -15,14 +15,17 @@ from align_utils import ( add_catalog_event_lines, + add_stage_timing, add_utc_time_axis, compute_lag, correlation_time_bounds, draw_correlation_markers, ensure_utc_datetime, get_component_selection, + report_timing_once, set_figure_title, shift_left_zeropad, + TimingState, ) min_freq, max_freq = 3.0, 10.0 # Bandpass filter (Hz) @@ -59,37 +62,7 @@ # Timing (cpu and wall) -_start_cpu_time = time.process_time() -_start_wall_time = time.perf_counter() -_timing_reported = False -_stage_wall_times = {} -_stage_cpu_times = {} -_stage_counts = {} - - -def add_stage_timing(stage_name: str, wall_start: float, cpu_start: float) -> None: - """Accumulate elapsed wall/cpu time for a named processing stage.""" - wall_dt = time.perf_counter() - wall_start - cpu_dt = time.process_time() - cpu_start - _stage_wall_times[stage_name] = _stage_wall_times.get(stage_name, 0.0) + wall_dt - _stage_cpu_times[stage_name] = _stage_cpu_times.get(stage_name, 0.0) + cpu_dt - _stage_counts[stage_name] = _stage_counts.get(stage_name, 0) + 1 - - -def report_stage_timing() -> None: - """Print stage-level timing summary sorted by wall time.""" - if not _stage_wall_times: - return - total_wall = sum(_stage_wall_times.values()) - print("\033[36mStage timing breakdown (wall/cpu):\033[0m") - for name, wall_sec in sorted(_stage_wall_times.items(), key=lambda kv: kv[1], reverse=True): - cpu_sec = _stage_cpu_times.get(name, 0.0) - calls = _stage_counts.get(name, 0) - frac = (100.0 * wall_sec / total_wall) if total_wall > 0 else 0.0 - print( - f" {name:<28} wall={wall_sec:7.2f}s cpu={cpu_sec:7.2f}s " - f"calls={calls:3d} ({frac:5.1f}%)" - ) +timing_state = TimingState() catalog_local_file = info_root / "catalog_local_hand.xlsx" @@ -109,18 +82,6 @@ def report_stage_timing() -> None: # ===================== Helper functions ===================== -def report_timing_once() -> None: - """Report cpu and wall time before showing plots.""" - global _timing_reported - if _timing_reported: - return - cpu_sec = time.process_time() - _start_cpu_time - wall_sec = time.perf_counter() - _start_wall_time - print(f"\033[31mTiming: cpu={cpu_sec:.2f}s wall={wall_sec:.2f}s\033[0m") - report_stage_timing() - _timing_reported = True - - def plot_stage_stacks( eve_id: str, plot_comp: str, @@ -491,7 +452,7 @@ def read_waveforms_for_event( if stanum % 100 == 0: print(f"Event {eve_id}: processed {stanum} stations...") first_chan_for_sta = False - add_stage_timing("waveform_read_slice", _read_wall_start, _read_cpu_start) + add_stage_timing(timing_state, "waveform_read_slice", _read_wall_start, _read_cpu_start) # ---- Sort by distance ---- st_all.sort(keys=["dist_km"]) @@ -668,7 +629,7 @@ def compute_alignment_products( if t_sta is not None: calc_shifts[station_id] = t_sta - t_ref - add_stage_timing("taup_station_shifts", _taup_wall_start, _taup_cpu_start) + add_stage_timing(timing_state, "taup_station_shifts", _taup_wall_start, _taup_cpu_start) # ===================== Stage 1: align to reference -> aligned_stack ===================== aligned_stack = np.zeros(npts) @@ -693,7 +654,7 @@ def compute_alignment_products( ) aligned_stack += shift_left_zeropad(d, lag1) - add_stage_timing("align_stage1", _stage1_wall_start, _stage1_cpu_start) + add_stage_timing(timing_state, "align_stage1", _stage1_wall_start, _stage1_cpu_start) win = aligned_stack[win_start:win_end] mx = np.max(np.abs(win)) if win.size > 0 else 0.0 @@ -757,7 +718,7 @@ def compute_alignment_products( selected_ids.add(station_id) else: print(f" Rejected {station_id}: r_win={r_window:.2f}") - add_stage_timing("align_stage2_screen", _stage2_wall_start, _stage2_cpu_start) + add_stage_timing(timing_state, "align_stage2_screen", _stage2_wall_start, _stage2_cpu_start) win = selected_aligned_stack[win_start:win_end] mx = np.max(np.abs(win)) if win.size > 0 else 0.0 @@ -816,7 +777,7 @@ def compute_alignment_products( aligned_bank.append(y) else: rejected_rows.append((dist_km, station_id, y)) - add_stage_timing("align_stage3_finalize", _stage3_wall_start, _stage3_cpu_start) + add_stage_timing(timing_state, "align_stage3_finalize", _stage3_wall_start, _stage3_cpu_start) selected_rows.sort(key=lambda t: t[0]) rejected_rows.sort(key=lambda t: t[0]) @@ -954,7 +915,7 @@ def run_pipeline() -> None: rotated_traces.append(trT) plot_comp = "T" - add_stage_timing("rotate_to_rt", _rotate_wall_start, _rotate_cpu_start) + add_stage_timing(timing_state, "rotate_to_rt", _rotate_wall_start, _rotate_cpu_start) st_comp = Stream(traces=rotated_traces) else: @@ -996,7 +957,7 @@ def run_pipeline() -> None: corners=4, zerophase=True, ) - add_stage_timing("preprocess_filter", _pre_wall_start, _pre_cpu_start) + add_stage_timing(timing_state, "preprocess_filter", _pre_wall_start, _pre_cpu_start) alignment = compute_alignment_products( st_comp=st_comp, @@ -1361,10 +1322,10 @@ def run_pipeline() -> None: except Exception as e: print(f"[WARN] Failed to create station pass/fail maps: {e}") - add_stage_timing("plot_and_save", _plot_wall_start, _plot_cpu_start) + add_stage_timing(timing_state, "plot_and_save", _plot_wall_start, _plot_cpu_start) # Show figures for single-component mode - report_timing_once() + report_timing_once(timing_state) plt.show() @@ -2080,13 +2041,13 @@ def run_pipeline() -> None: fig_ec.savefig(estcalc_file, dpi=300, bbox_inches='tight') print(f"✓ Estimated vs calculated shift plot saved to: {estcalc_file}") - add_stage_timing("plot_three_component", _plot3_wall_start, _plot3_cpu_start) + add_stage_timing(timing_state, "plot_three_component", _plot3_wall_start, _plot3_cpu_start) # Show all figures together (three-component + shift comparison) # plt.show() # ===================== Show all figures together at the end ===================== - report_timing_once() + report_timing_once(timing_state) print("\a\a\a") plt.show() diff --git a/align_utils.py b/align_utils.py index b943f82..cbda487 100644 --- a/align_utils.py +++ b/align_utils.py @@ -1,11 +1,59 @@ from __future__ import annotations from datetime import timedelta, timezone +from dataclasses import dataclass, field +import time import numpy as np from obspy import UTCDateTime +@dataclass +class TimingState: + start_cpu_time: float = field(default_factory=time.process_time) + start_wall_time: float = field(default_factory=time.perf_counter) + timing_reported: bool = False + stage_wall_times: dict[str, float] = field(default_factory=dict) + stage_cpu_times: dict[str, float] = field(default_factory=dict) + stage_counts: dict[str, int] = field(default_factory=dict) + + +def add_stage_timing(state: TimingState, stage_name: str, wall_start: float, cpu_start: float) -> None: + """Accumulate elapsed wall/cpu time for a named processing stage.""" + wall_dt = time.perf_counter() - wall_start + cpu_dt = time.process_time() - cpu_start + state.stage_wall_times[stage_name] = state.stage_wall_times.get(stage_name, 0.0) + wall_dt + state.stage_cpu_times[stage_name] = state.stage_cpu_times.get(stage_name, 0.0) + cpu_dt + state.stage_counts[stage_name] = state.stage_counts.get(stage_name, 0) + 1 + + +def report_stage_timing(state: TimingState) -> None: + """Print stage-level timing summary sorted by wall time.""" + if not state.stage_wall_times: + return + total_wall = sum(state.stage_wall_times.values()) + print("\033[36mStage timing breakdown (wall/cpu):\033[0m") + for name, wall_sec in sorted(state.stage_wall_times.items(), key=lambda kv: kv[1], reverse=True): + cpu_sec = state.stage_cpu_times.get(name, 0.0) + calls = state.stage_counts.get(name, 0) + frac = (100.0 * wall_sec / total_wall) if total_wall > 0 else 0.0 + print( + f" {name:<28} wall={wall_sec:7.2f}s cpu={cpu_sec:7.2f}s " + f"calls={calls:3d} ({frac:5.1f}%)" + ) + + +def report_timing_once(state: TimingState) -> None: + """Report cpu and wall time before showing plots.""" + if state.timing_reported: + return + cpu_sec = time.process_time() - state.start_cpu_time + wall_sec = time.perf_counter() - state.start_wall_time + print(f"\033[31mTiming: cpu={cpu_sec:.2f}s wall={wall_sec:.2f}s\033[0m") + report_stage_timing(state) + state.timing_reported = True + + def compute_lag( ref: np.ndarray, d: np.ndarray, From b2d53b975fcee88604dadf133635a2ea2a5f69e5 Mon Sep 17 00:00:00 2001 From: John Emilio Vidale Date: Sat, 30 May 2026 17:06:37 -0700 Subject: [PATCH 6/9] Extract payload and output-dir helpers into align_utils --- align_stack.py | 90 ++------------------------------------------------ align_utils.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 87 deletions(-) diff --git a/align_stack.py b/align_stack.py index 9a21ff0..23fb026 100644 --- a/align_stack.py +++ b/align_stack.py @@ -17,12 +17,15 @@ add_catalog_event_lines, add_stage_timing, add_utc_time_axis, + build_component_output_payload, compute_lag, correlation_time_bounds, draw_correlation_markers, ensure_utc_datetime, get_component_selection, + make_event_output_dir, report_timing_once, + resolve_component_key, set_figure_title, shift_left_zeropad, TimingState, @@ -253,85 +256,6 @@ def plot_record_section_and_stack( return fig -def resolve_component_key(channel: str, sel_comp: str) -> str: - """Resolve storage key used in three-component aggregate output.""" - if channel == "DPZ": - return "DPZ" - if sel_comp == "R": - return "R" - if sel_comp == "T": - return "T" - return channel - - -def build_component_output_payload( - record_fig, - selected_rows: list, - rejected_rows: list, - stack_vec: np.ndarray, - t_abs: np.ndarray, - mask: np.ndarray, - sample_rate: float, - win_start: int, - win_end: int, - move_limit_sec_value: float, - move_limit_samples: int, - npts: int, - start_t: float, - end_t: float, - eve_id: str, - align_phase_name: str, - origin, - station_shifts: dict, - station_corr: dict, - calc_shifts: dict, - n_pass_window: int, - pass_window_ids: set, - snippet_by_station: dict, - ref_window: np.ndarray, - p_traveltime, - s_traveltime, - name2ll: dict, - selected_ids: set, - aligned_traces_by_station: dict, - t_ref, -): - """Create deep-copied component payload used by three-component plotting.""" - payload = { - "fig": record_fig, - "all_rows": [(r[0], r[1], r[2].copy()) for r in (selected_rows + rejected_rows)], - "stack_vec": stack_vec.copy(), - "t_abs": t_abs.copy(), - "mask": mask.copy(), - "sample_rate": sample_rate, - "win_start": win_start, - "win_end": win_end, - "move_limit_sec": move_limit_sec_value, - "move_limit_samples": move_limit_samples, - "npts": npts, - "start_time": start_t, - "end_time": end_t, - "eve_id": eve_id, - "align_phase": align_phase_name, - "origin": origin, - "station_shifts": station_shifts.copy(), - "station_corr": station_corr.copy(), - "calc_shifts": calc_shifts.copy(), - "n_pass_window": int(n_pass_window), - "pass_window_ids": sorted(list(pass_window_ids), key=lambda s: int(s)), - "snippet_by_station": {k: v.copy() for k, v in snippet_by_station.items()}, - "ref_window": ref_window.copy(), - "p_traveltime": None if p_traveltime is None else float(p_traveltime), - "s_traveltime": None if s_traveltime is None else float(s_traveltime), - "station_ll": {k: (float(v[0]), float(v[1])) for k, v in name2ll.items()}, - # Stations that passed Stage-2 screening for this component - "selected_ids": sorted(list(selected_ids), key=lambda s: int(s)), - "aligned_traces_by_station": {k: v.copy() for k, v in aligned_traces_by_station.items()}, - "t_ref": t_ref, - } - return payload - - def load_event_metadata(eve_id: str, info_dir: Path): """Load event row and return key metadata for one event id.""" eve_info = pd.read_csv(info_dir / "catalog_20220930_8events.csv") @@ -343,14 +267,6 @@ def load_event_metadata(eve_id: str, info_dir: Path): return event_depth, eve_lat, eve_lon, origin -def make_event_output_dir(base_prefix: str, eve_id: str) -> Path: - """Create and return output directory for one event.""" - save_path = Path(base_prefix + "output") - save_dir = save_path / eve_id - save_dir.mkdir(parents=True, exist_ok=True) - return save_dir - - def load_station_lookup(info_dir: Path): """Read station coordinates and return station->(lat, lon) lookup.""" station_file = info_dir / "stations.txt" diff --git a/align_utils.py b/align_utils.py index cbda487..c939f89 100644 --- a/align_utils.py +++ b/align_utils.py @@ -3,6 +3,7 @@ from datetime import timedelta, timezone from dataclasses import dataclass, field import time +from pathlib import Path import numpy as np from obspy import UTCDateTime @@ -54,6 +55,93 @@ def report_timing_once(state: TimingState) -> None: state.timing_reported = True +def resolve_component_key(channel: str, sel_comp: str) -> str: + """Resolve storage key used in three-component aggregate output.""" + if channel == "DPZ": + return "DPZ" + if sel_comp == "R": + return "R" + if sel_comp == "T": + return "T" + return channel + + +def build_component_output_payload( + record_fig, + selected_rows: list, + rejected_rows: list, + stack_vec: np.ndarray, + t_abs: np.ndarray, + mask: np.ndarray, + sample_rate: float, + win_start: int, + win_end: int, + move_limit_sec_value: float, + move_limit_samples: int, + npts: int, + start_t: float, + end_t: float, + eve_id: str, + align_phase_name: str, + origin, + station_shifts: dict, + station_corr: dict, + calc_shifts: dict, + n_pass_window: int, + pass_window_ids: set, + snippet_by_station: dict, + ref_window: np.ndarray, + p_traveltime, + s_traveltime, + name2ll: dict, + selected_ids: set, + aligned_traces_by_station: dict, + t_ref, +): + """Create deep-copied component payload used by three-component plotting.""" + payload = { + "fig": record_fig, + "all_rows": [(r[0], r[1], r[2].copy()) for r in (selected_rows + rejected_rows)], + "stack_vec": stack_vec.copy(), + "t_abs": t_abs.copy(), + "mask": mask.copy(), + "sample_rate": sample_rate, + "win_start": win_start, + "win_end": win_end, + "move_limit_sec": move_limit_sec_value, + "move_limit_samples": move_limit_samples, + "npts": npts, + "start_time": start_t, + "end_time": end_t, + "eve_id": eve_id, + "align_phase": align_phase_name, + "origin": origin, + "station_shifts": station_shifts.copy(), + "station_corr": station_corr.copy(), + "calc_shifts": calc_shifts.copy(), + "n_pass_window": int(n_pass_window), + "pass_window_ids": sorted(list(pass_window_ids), key=lambda s: int(s)), + "snippet_by_station": {k: v.copy() for k, v in snippet_by_station.items()}, + "ref_window": ref_window.copy(), + "p_traveltime": None if p_traveltime is None else float(p_traveltime), + "s_traveltime": None if s_traveltime is None else float(s_traveltime), + "station_ll": {k: (float(v[0]), float(v[1])) for k, v in name2ll.items()}, + # Stations that passed Stage-2 screening for this component + "selected_ids": sorted(list(selected_ids), key=lambda s: int(s)), + "aligned_traces_by_station": {k: v.copy() for k, v in aligned_traces_by_station.items()}, + "t_ref": t_ref, + } + return payload + + +def make_event_output_dir(base_prefix: str, eve_id: str) -> Path: + """Create and return output directory for one event.""" + save_path = Path(base_prefix + "output") + save_dir = save_path / eve_id + save_dir.mkdir(parents=True, exist_ok=True) + return save_dir + + def compute_lag( ref: np.ndarray, d: np.ndarray, From 60554075860573cb7e1de0a825a49959d397d19a Mon Sep 17 00:00:00 2001 From: John Emilio Vidale Date: Sat, 30 May 2026 17:08:04 -0700 Subject: [PATCH 7/9] Extract event/station metadata loaders into align_utils --- align_stack.py | 28 ++-------------------------- align_utils.py | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/align_stack.py b/align_stack.py index 23fb026..4695450 100644 --- a/align_stack.py +++ b/align_stack.py @@ -23,6 +23,8 @@ draw_correlation_markers, ensure_utc_datetime, get_component_selection, + load_event_metadata, + load_station_lookup, make_event_output_dir, report_timing_once, resolve_component_key, @@ -256,32 +258,6 @@ def plot_record_section_and_stack( return fig -def load_event_metadata(eve_id: str, info_dir: Path): - """Load event row and return key metadata for one event id.""" - eve_info = pd.read_csv(info_dir / "catalog_20220930_8events.csv") - row = eve_info.loc[eve_info["evid"] == eve_id].iloc[0] - event_depth = float(row["depth"]) - eve_lat = float(row["latitude"]) - eve_lon = float(row["longitude"]) - origin = UTCDateTime(str(row["origin_time"])) - return event_depth, eve_lat, eve_lon, origin - - -def load_station_lookup(info_dir: Path): - """Read station coordinates and return station->(lat, lon) lookup.""" - station_file = info_dir / "stations.txt" - sta_info = np.genfromtxt( - station_file, - dtype=[("name", "U10"), ("lat", "f8"), ("lon", "f8")], - usecols=(0, 1, 2), - comments="#", - ) - sta_name = np.array([s.decode() if hasattr(s, "decode") else s for s in sta_info["name"]]) - sta_lat = sta_info["lat"] - sta_lon = sta_info["lon"] - return {sta_name[i]: (sta_lat[i], sta_lon[i]) for i in range(len(sta_name))} - - def read_waveforms_for_event( eve_id: str, channel: str, diff --git a/align_utils.py b/align_utils.py index c939f89..20b5ab0 100644 --- a/align_utils.py +++ b/align_utils.py @@ -6,6 +6,7 @@ from pathlib import Path import numpy as np +import pandas as pd from obspy import UTCDateTime @@ -142,6 +143,32 @@ def make_event_output_dir(base_prefix: str, eve_id: str) -> Path: return save_dir +def load_event_metadata(eve_id: str, info_dir: Path): + """Load event row and return key metadata for one event id.""" + eve_info = pd.read_csv(info_dir / "catalog_20220930_8events.csv") + row = eve_info.loc[eve_info["evid"] == eve_id].iloc[0] + event_depth = float(row["depth"]) + eve_lat = float(row["latitude"]) + eve_lon = float(row["longitude"]) + origin = UTCDateTime(str(row["origin_time"])) + return event_depth, eve_lat, eve_lon, origin + + +def load_station_lookup(info_dir: Path): + """Read station coordinates and return station->(lat, lon) lookup.""" + station_file = info_dir / "stations.txt" + sta_info = np.genfromtxt( + station_file, + dtype=[("name", "U10"), ("lat", "f8"), ("lon", "f8")], + usecols=(0, 1, 2), + comments="#", + ) + sta_name = np.array([s.decode() if hasattr(s, "decode") else s for s in sta_info["name"]]) + sta_lat = sta_info["lat"] + sta_lon = sta_info["lon"] + return {sta_name[i]: (sta_lat[i], sta_lon[i]) for i in range(len(sta_name))} + + def compute_lag( ref: np.ndarray, d: np.ndarray, From 048477f1baa243dbc26d4e140c87879acc5f20ef Mon Sep 17 00:00:00 2001 From: John Emilio Vidale Date: Sat, 30 May 2026 17:13:37 -0700 Subject: [PATCH 8/9] Extract reference-station helpers into align_utils --- align_stack.py | 61 ++------------------------------------------------ align_utils.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 59 deletions(-) diff --git a/align_stack.py b/align_stack.py index 4695450..71efd7c 100644 --- a/align_stack.py +++ b/align_stack.py @@ -28,6 +28,8 @@ make_event_output_dir, report_timing_once, resolve_component_key, + select_reference_trace, + print_reference_summary, set_figure_title, shift_left_zeropad, TimingState, @@ -375,65 +377,6 @@ def read_waveforms_for_event( return st_window, raw_limits_by_station -def select_reference_trace(st_comp: Stream, name2ll: dict): - """Pick reference station closest to array center and return (id, trace).""" - st_comp.sort(keys=["dist_km"]) - if len(st_comp) == 0: - return None, None - - station_ids = sorted({str(tr.stats.station) for tr in st_comp}) - center_lats = [name2ll[s][0] for s in station_ids if s in name2ll] - center_lons = [name2ll[s][1] for s in station_ids if s in name2ll] - - ref_station_id = None - if len(center_lats) > 0 and len(center_lons) > 0: - center_lat = float(np.mean(center_lats)) - center_lon = float(np.mean(center_lons)) - best_dist_m = None - for sid in station_ids: - if sid not in name2ll: - continue - slat, slon = name2ll[sid] - dist_m, _, _ = gps2dist_azimuth(center_lat, center_lon, slat, slon) - if best_dist_m is None or dist_m < best_dist_m: - best_dist_m = dist_m - ref_station_id = sid - - if ref_station_id is None: - ref_station_id = str(st_comp[0].stats.station) - - ref_trace = st_comp.select(station=ref_station_id) - if len(ref_trace) == 0: - ref_trace = [st_comp[0]] - ref_station_id = str(ref_trace[0].stats.station) - return ref_station_id, ref_trace[0] - - -def print_reference_summary(ref_station_id: str, ref_trace: Trace, raw_limits_by_station: dict): - """Print reference station and data-window summary.""" - ref_trace_dur = float(ref_trace.stats.npts) / float(ref_trace.stats.sampling_rate) - print( - f" Reference station (auto): {ref_station_id} (closest to array center) " - f"dist_km={ref_trace.stats.dist_km:.2f} dur_s={ref_trace_dur:.2f}" - ) - print(f" Epicentral distance ≈ {ref_trace.stats.dist_km:.1f} km") - - ref_start_dt = ensure_utc_datetime(ref_trace.stats.starttime.datetime) - ref_end_dt = ensure_utc_datetime(ref_trace.stats.endtime.datetime) - print( - "\033[32m Reference seismogram UTC window: " - f"{ref_start_dt.isoformat()} to {ref_end_dt.isoformat()}\033[0m" - ) - if ref_station_id in raw_limits_by_station: - raw_start, raw_end = raw_limits_by_station[ref_station_id] - raw_start_dt = ensure_utc_datetime(raw_start.datetime) - raw_end_dt = ensure_utc_datetime(raw_end.datetime) - print( - "\033[32m Reference read UTC window: " - f"{raw_start_dt.isoformat()} to {raw_end_dt.isoformat()}\033[0m" - ) - - def compute_phase_travel_times(model_obj, event_depth: float, ref_trace: Trace, origin_time, align_phase_name: str): """Compute P/S travel times at reference station and selected alignment phase time.""" ref_deg = float(ref_trace.stats.dist_deg) diff --git a/align_utils.py b/align_utils.py index 20b5ab0..56e0417 100644 --- a/align_utils.py +++ b/align_utils.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd +from obspy.geodetics import gps2dist_azimuth from obspy import UTCDateTime @@ -169,6 +170,65 @@ def load_station_lookup(info_dir: Path): return {sta_name[i]: (sta_lat[i], sta_lon[i]) for i in range(len(sta_name))} +def select_reference_trace(st_comp, name2ll: dict): + """Pick reference station closest to array center and return (id, trace).""" + st_comp.sort(keys=["dist_km"]) + if len(st_comp) == 0: + return None, None + + station_ids = sorted({str(tr.stats.station) for tr in st_comp}) + center_lats = [name2ll[s][0] for s in station_ids if s in name2ll] + center_lons = [name2ll[s][1] for s in station_ids if s in name2ll] + + ref_station_id = None + if len(center_lats) > 0 and len(center_lons) > 0: + center_lat = float(np.mean(center_lats)) + center_lon = float(np.mean(center_lons)) + best_dist_m = None + for sid in station_ids: + if sid not in name2ll: + continue + slat, slon = name2ll[sid] + dist_m, _, _ = gps2dist_azimuth(center_lat, center_lon, slat, slon) + if best_dist_m is None or dist_m < best_dist_m: + best_dist_m = dist_m + ref_station_id = sid + + if ref_station_id is None: + ref_station_id = str(st_comp[0].stats.station) + + ref_trace = st_comp.select(station=ref_station_id) + if len(ref_trace) == 0: + ref_trace = [st_comp[0]] + ref_station_id = str(ref_trace[0].stats.station) + return ref_station_id, ref_trace[0] + + +def print_reference_summary(ref_station_id: str, ref_trace, raw_limits_by_station: dict): + """Print reference station and data-window summary.""" + ref_trace_dur = float(ref_trace.stats.npts) / float(ref_trace.stats.sampling_rate) + print( + f" Reference station (auto): {ref_station_id} (closest to array center) " + f"dist_km={ref_trace.stats.dist_km:.2f} dur_s={ref_trace_dur:.2f}" + ) + print(f" Epicentral distance ~= {ref_trace.stats.dist_km:.1f} km") + + ref_start_dt = ensure_utc_datetime(ref_trace.stats.starttime.datetime) + ref_end_dt = ensure_utc_datetime(ref_trace.stats.endtime.datetime) + print( + "\033[32m Reference seismogram UTC window: " + f"{ref_start_dt.isoformat()} to {ref_end_dt.isoformat()}\033[0m" + ) + if ref_station_id in raw_limits_by_station: + raw_start, raw_end = raw_limits_by_station[ref_station_id] + raw_start_dt = ensure_utc_datetime(raw_start.datetime) + raw_end_dt = ensure_utc_datetime(raw_end.datetime) + print( + "\033[32m Reference read UTC window: " + f"{raw_start_dt.isoformat()} to {raw_end_dt.isoformat()}\033[0m" + ) + + def compute_lag( ref: np.ndarray, d: np.ndarray, From 58a56c38ff8012b5ef7fce1ca0d5880e3652a9d6 Mon Sep 17 00:00:00 2001 From: John Emilio Vidale Date: Sat, 30 May 2026 17:14:32 -0700 Subject: [PATCH 9/9] Extract phase travel-time helper into align_utils --- align_stack.py | 33 +-------------------------------- align_utils.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/align_stack.py b/align_stack.py index 71efd7c..fc762f2 100644 --- a/align_stack.py +++ b/align_stack.py @@ -18,6 +18,7 @@ add_stage_timing, add_utc_time_axis, build_component_output_payload, + compute_phase_travel_times, compute_lag, correlation_time_bounds, draw_correlation_markers, @@ -377,38 +378,6 @@ def read_waveforms_for_event( return st_window, raw_limits_by_station -def compute_phase_travel_times(model_obj, event_depth: float, ref_trace: Trace, origin_time, align_phase_name: str): - """Compute P/S travel times at reference station and selected alignment phase time.""" - ref_deg = float(ref_trace.stats.dist_deg) - tts = model_obj.get_travel_times( - source_depth_in_km=event_depth, - distance_in_degree=ref_deg, - phase_list=["p", "P", "s", "S"], - ) - - p_traveltime = None - s_traveltime = None - p_arrival_time = None - s_arrival_time = None - - for tt in reversed(tts): - if tt.phase.name.upper() == "P": - p_traveltime = float(tt.time) - p_arrival_time = origin_time + p_traveltime - if tt.phase.name.upper() == "S": - s_traveltime = float(tt.time) - s_arrival_time = origin_time + s_traveltime - - if align_phase_name == "P" and p_traveltime is not None: - phase_traveltime = float(p_traveltime) - elif align_phase_name == "S" and s_traveltime is not None: - phase_traveltime = float(s_traveltime) - else: - phase_traveltime = None - - return p_traveltime, s_traveltime, p_arrival_time, s_arrival_time, phase_traveltime - - def compute_alignment_products( st_comp: Stream, ref_trace: Trace, diff --git a/align_utils.py b/align_utils.py index 56e0417..037832a 100644 --- a/align_utils.py +++ b/align_utils.py @@ -229,6 +229,38 @@ def print_reference_summary(ref_station_id: str, ref_trace, raw_limits_by_statio ) +def compute_phase_travel_times(model_obj, event_depth: float, ref_trace, origin_time, align_phase_name: str): + """Compute P/S travel times at reference station and selected alignment phase time.""" + ref_deg = float(ref_trace.stats.dist_deg) + tts = model_obj.get_travel_times( + source_depth_in_km=event_depth, + distance_in_degree=ref_deg, + phase_list=["p", "P", "s", "S"], + ) + + p_traveltime = None + s_traveltime = None + p_arrival_time = None + s_arrival_time = None + + for tt in reversed(tts): + if tt.phase.name.upper() == "P": + p_traveltime = float(tt.time) + p_arrival_time = origin_time + p_traveltime + if tt.phase.name.upper() == "S": + s_traveltime = float(tt.time) + s_arrival_time = origin_time + s_traveltime + + if align_phase_name == "P" and p_traveltime is not None: + phase_traveltime = float(p_traveltime) + elif align_phase_name == "S" and s_traveltime is not None: + phase_traveltime = float(s_traveltime) + else: + phase_traveltime = None + + return p_traveltime, s_traveltime, p_arrival_time, s_arrival_time, phase_traveltime + + def compute_lag( ref: np.ndarray, d: np.ndarray,