diff --git a/CHANGELOG.md b/CHANGELOG.md index 67950e8b..5fda7fb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased] +### Added +- waveform_report - New `qualang_tools.waveform_report` subpackage with enhanced waveform visualization: vertical timeline markers, configurable colors and filtering, and Plotly figures via `create_waveform_plot_with_markers` and `extract_timing_markers`. ## [0.21.1] - 2026-01-16 ### Fixed diff --git a/qualang_tools/__init__.py b/qualang_tools/__init__.py index 428bc074..8036b10f 100644 --- a/qualang_tools/__init__.py +++ b/qualang_tools/__init__.py @@ -10,6 +10,7 @@ "macros", "multi_user", "units", + "waveform_report", "external_frameworks", "callable_from_qua", "wirer", diff --git a/qualang_tools/addons/calibration/calibrations.py b/qualang_tools/addons/calibration/calibrations.py index 730bf8a9..d1b34a7f 100644 --- a/qualang_tools/addons/calibration/calibrations.py +++ b/qualang_tools/addons/calibration/calibrations.py @@ -14,7 +14,6 @@ import numpy as np from scipy import signal - available_variables = ["frequency", "amplitude", "duration"] u = unit() diff --git a/qualang_tools/analysis/discriminator.py b/qualang_tools/analysis/discriminator.py index a40fa3f0..305cabc8 100644 --- a/qualang_tools/analysis/discriminator.py +++ b/qualang_tools/analysis/discriminator.py @@ -71,8 +71,7 @@ def two_state_discriminator(Ig, Qg, Ie, Qe, b_print=True, b_plot=True): fidelity = 100 * (gg + ee) / 2 if b_print: - print( - f""" + print(f""" Fidelity Matrix: ----------------- | {gg:.3f} | {ge:.3f} | @@ -82,8 +81,7 @@ def two_state_discriminator(Ig, Qg, Ie, Qe, b_print=True, b_plot=True): IQ plane rotated by: {180 / np.pi * angle:.1f}{chr(176)} Threshold: {threshold:.3e} Fidelity: {fidelity:.1f}% - """ - ) + """) if b_plot: fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) diff --git a/qualang_tools/callable_from_qua/_callable_from_qua.py b/qualang_tools/callable_from_qua/_callable_from_qua.py index f0ef5406..5bbd83e0 100644 --- a/qualang_tools/callable_from_qua/_callable_from_qua.py +++ b/qualang_tools/callable_from_qua/_callable_from_qua.py @@ -11,7 +11,6 @@ from packaging.version import Version import qm - # TODO: Remove this if block when we drop support for qm < 1.2.2 (and move the import that is currently in the # else block to the top) if Version(qm.__version__) < Version("1.2.2"): diff --git a/qualang_tools/characterization/two_qubit_rb/two_qubit_rb/TwoQubitRBDebugger.py b/qualang_tools/characterization/two_qubit_rb/two_qubit_rb/TwoQubitRBDebugger.py index 42eab6bd..70d34ed9 100644 --- a/qualang_tools/characterization/two_qubit_rb/two_qubit_rb/TwoQubitRBDebugger.py +++ b/qualang_tools/characterization/two_qubit_rb/two_qubit_rb/TwoQubitRBDebugger.py @@ -10,7 +10,6 @@ from .TwoQubitRB import TwoQubitRb from .verification import SequenceTracker - phased_xz_command_sequences = { r"I \otimes I": [720], # Identity on both qubits r"I \otimes Z": [732], # Z on qubit 1, Identity on qubit 2 diff --git a/qualang_tools/config/server/upload.py b/qualang_tools/config/server/upload.py index 7fe8ed1e..d01243c6 100644 --- a/qualang_tools/config/server/upload.py +++ b/qualang_tools/config/server/upload.py @@ -132,22 +132,19 @@ def enter_empty_config(click): def init_edits_file(): with open(os.path.join(UPLOAD_DIRECTORY, "config_edits.py"), "w") as fp: - fp.write( - """ + fp.write(""" from qualang_tools.config import * from qualang_tools.config.server.config_editor import config_editor from config_initial import configuration setup = ConfigBuilder() -""" - ) +""") def init_empty_initial_config_file(): with open(os.path.join(UPLOAD_DIRECTORY, "config_initial.py"), "w") as fp: - fp.write( - """ + fp.write(""" configuration = { "version": 1, "controllers": {}, @@ -159,16 +156,13 @@ def init_empty_initial_config_file(): "oscillators": {}, "mixers": {}, } -""" - ) +""") def init_final_config_file(): with open(os.path.join(UPLOAD_DIRECTORY, "config_final.py"), "w") as fp: - fp.write( - """ + fp.write(""" from config_edits import configuration, setup configuration = setup.build(configuration) -""" - ) +""") diff --git a/qualang_tools/plot/__init__.py b/qualang_tools/plot/__init__.py index a0f5306a..6bad9fc4 100644 --- a/qualang_tools/plot/__init__.py +++ b/qualang_tools/plot/__init__.py @@ -1,7 +1,6 @@ from qualang_tools.plot.plot import * from qualang_tools.plot.fitting import * - __all__ = [ "interrupt_on_close", "Fit", diff --git a/qualang_tools/results/data_handler/data_folder_tools.py b/qualang_tools/results/data_handler/data_folder_tools.py index 171d2c69..0d8e721b 100644 --- a/qualang_tools/results/data_handler/data_folder_tools.py +++ b/qualang_tools/results/data_handler/data_folder_tools.py @@ -9,7 +9,6 @@ import re from datetime import datetime - __all__ = ["DEFAULT_FOLDER_PATTERN", "extract_data_folder_properties", "get_latest_data_folder", "create_data_folder"] diff --git a/qualang_tools/results/data_handler/data_handler.py b/qualang_tools/results/data_handler/data_handler.py index c518a88b..1a74ae10 100644 --- a/qualang_tools/results/data_handler/data_handler.py +++ b/qualang_tools/results/data_handler/data_handler.py @@ -13,7 +13,6 @@ get_latest_data_folder, ) - __all__ = ["save_data", "DataHandler"] NODE_FILENAME = "node.json" diff --git a/qualang_tools/results/data_handler/data_processors/data_processor.py b/qualang_tools/results/data_handler/data_processors/data_processor.py index 48f8a8c4..b10115b9 100644 --- a/qualang_tools/results/data_handler/data_processors/data_processor.py +++ b/qualang_tools/results/data_handler/data_processors/data_processor.py @@ -1,7 +1,6 @@ from abc import ABC from pathlib import Path - __all__ = ["DataProcessor"] diff --git a/qualang_tools/results/data_handler/data_processors/simulator_controller_samples_saver.py b/qualang_tools/results/data_handler/data_processors/simulator_controller_samples_saver.py index 7ec32ff3..88ce92b0 100644 --- a/qualang_tools/results/data_handler/data_processors/simulator_controller_samples_saver.py +++ b/qualang_tools/results/data_handler/data_processors/simulator_controller_samples_saver.py @@ -5,7 +5,6 @@ from .helpers import copy_nested_dict, iterate_nested_dict, update_nested_dict from .data_processor import DataProcessor - logger = logging.getLogger(__name__) diff --git a/qualang_tools/waveform_report/__init__.py b/qualang_tools/waveform_report/__init__.py new file mode 100644 index 00000000..1f7a87e4 --- /dev/null +++ b/qualang_tools/waveform_report/__init__.py @@ -0,0 +1,13 @@ +from qualang_tools.waveform_report.enhanced_report import ( + VerticalMarkerConfig, + TimingMarker, + create_waveform_plot_with_markers, + extract_timing_markers, +) + +__all__ = [ + "VerticalMarkerConfig", + "TimingMarker", + "create_waveform_plot_with_markers", + "extract_timing_markers", +] diff --git a/qualang_tools/waveform_report/enhanced_report.py b/qualang_tools/waveform_report/enhanced_report.py new file mode 100644 index 00000000..9986f322 --- /dev/null +++ b/qualang_tools/waveform_report/enhanced_report.py @@ -0,0 +1,582 @@ +from __future__ import annotations + +""" +Enhanced Waveform Report with Vertical Timeline Markers + +This module wraps the QM SDK's waveform report functionality and adds +vertical timeline markers that span all subplot rows, making it easier +to correlate timing across different output channels. + +Usage: + from qualang_tools.waveform_report import create_waveform_plot_with_markers, VerticalMarkerConfig + + fig = create_waveform_plot_with_markers( + waveform_report, + samples, + marker_config=VerticalMarkerConfig(), + plot=True, + save_dir="./output" + ) +""" + +# NOTE: The QM SDK HTML/Plotly output format is not guaranteed stable. +# Keep the parsing logic defensive and update if the SDK output changes. + +# === IMPORTS === +from collections import defaultdict +from dataclasses import dataclass +import logging +from typing import List, Dict, Any, Literal +import warnings +import plotly.graph_objects as go +import json +import re +from pathlib import Path +import shutil +import tempfile + +logger = logging.getLogger(__name__) +# Note: This module uses warnings.warn() for suppressible user-facing issues +# (e.g., SDK parsing failures) and logger for informational messages. + +__all__ = [ + "VerticalMarkerConfig", + "TimingMarker", + "create_waveform_plot_with_markers", + "extract_timing_markers", +] + +# === CONFIGURATION === +MarkerType = Literal["start", "end"] +OperationType = Literal["analog", "digital", "adc"] + + +@dataclass +class VerticalMarkerConfig: + """Configuration for vertical timeline markers. + + Attributes: + show_analog_markers: Include markers for analog waveforms + show_digital_markers: Include markers for digital waveforms + show_adc_markers: Include markers for ADC acquisitions + start_line_color: RGBA color for operation start markers + end_line_color: RGBA color for operation end markers + line_width: Width of marker lines in pixels + line_dash: Line style - "solid", "dot", "dash", "longdash", "dashdot" + min_duration_ns: Ignore operations shorter than this (nanoseconds) + merge_threshold_ns: Merge markers within this threshold (nanoseconds) + elements_to_include: Only show markers for these elements (None = all) + elements_to_exclude: Hide markers for these elements (None = none) + show_hover_info: Enable hover tooltips on markers + hover_marker_size: Size of invisible hover markers in pixels (default 15) + """ + + show_analog_markers: bool = True + show_digital_markers: bool = False + show_adc_markers: bool = True + start_line_color: str = "rgba(34, 139, 34, 0.6)" # Forest green + end_line_color: str = "rgba(178, 34, 34, 0.6)" # Firebrick red + line_width: float = 1.0 + line_dash: str = "dot" + min_duration_ns: float = 0.0 + merge_threshold_ns: float = 4.0 + elements_to_include: list[str] | None = None + elements_to_exclude: list[str] | None = None + show_hover_info: bool = True + hover_marker_size: int = 15 + hover_points_per_marker: int = 20 + + def color_for(self, marker_type: MarkerType) -> str: + return self.start_line_color if marker_type == "start" else self.end_line_color + + def should_include(self, element: str) -> bool: + if self.elements_to_include is not None: + if element not in self.elements_to_include: + return False + + if self.elements_to_exclude is not None: + if element in self.elements_to_exclude: + return False + + return True + + +# === DATA CLASSES === +_OPERATION_TYPE_LABELS = { + "analog": "Pulse", + "digital": "Digital", + "adc": "Integration", +} + + +@dataclass(frozen=True) +class TimingMarker: + """Represents a single timing marker (start or end of an operation). + + Attributes: + timestamp_ns: Time position in nanoseconds + marker_type: Either "start" or "end" + operation_type: Either "analog", "digital", or "adc" + element: Name of the quantum element + pulse_name: Name of the pulse/operation + controller: Controller name (e.g., "con1") + fem: FEM module number + output_ports: List of physical output port numbers + """ + + timestamp_ns: float + marker_type: MarkerType # "start" or "end" + operation_type: OperationType # "analog", "digital", "adc" + element: str + pulse_name: str + controller: str + fem: int + output_ports: List[int] + + @property + def hover_text(self) -> str: + """Generate hover tooltip text for this marker.""" + type_label = _OPERATION_TYPE_LABELS.get(self.operation_type, self.operation_type) + return f"{self.marker_type.capitalize()} [{type_label}]: " f"{self.pulse_name} ({self.element})" + + +def _merge_nearby_markers(markers: List[TimingMarker], threshold_ns: float) -> List[TimingMarker]: + """Merge markers that are within threshold_ns of each other. + + Markers are only merged if they have the same marker_type (start/end) + AND the same operation_type (analog/digital/adc). This prevents a + digital marker from absorbing a nearby ADC (integration) marker, which + would cause the merged result to lose the [Integration] label. + + When merged, the first marker's metadata is kept but the timestamp + is averaged. + + Args: + markers: Sorted list of timing markers + threshold_ns: Maximum time difference for merging (nanoseconds) + + Returns: + New sorted list with nearby markers merged + """ + if not markers or threshold_ns <= 0: + return markers + + # Separate by (marker_type, operation_type) so we never merge across + # different operation types (e.g. digital + adc). + groups: Dict[tuple, List[TimingMarker]] = defaultdict(list) + for m in markers: + groups[(m.marker_type, m.operation_type)].append(m) + + def merge_group(group: List[TimingMarker]) -> List[TimingMarker]: + """Merge markers within a single type group.""" + if not group: + return [] + + # Sort by timestamp + sorted_group = sorted(group, key=lambda m: m.timestamp_ns) + merged = [] + + i = 0 + while i < len(sorted_group): + # Start a new cluster with this marker + cluster = [sorted_group[i]] + j = i + 1 + + # Add all markers within threshold + while j < len(sorted_group): + if sorted_group[j].timestamp_ns - cluster[-1].timestamp_ns <= threshold_ns: + cluster.append(sorted_group[j]) + j += 1 + else: + break + + # Create merged marker using first marker's metadata + # but with averaged timestamp + avg_timestamp = sum(m.timestamp_ns for m in cluster) / len(cluster) + merged_marker = TimingMarker( + timestamp_ns=avg_timestamp, + marker_type=cluster[0].marker_type, + operation_type=cluster[0].operation_type, + element=cluster[0].element, + pulse_name=cluster[0].pulse_name, + controller=cluster[0].controller, + fem=cluster[0].fem, + output_ports=cluster[0].output_ports, + ) + merged.append(merged_marker) + + i = j + + return merged + + # Merge each group separately + result = [] + for group in groups.values(): + result.extend(merge_group(group)) + + # Re-sort by timestamp + result.sort(key=lambda m: m.timestamp_ns) + + return result + + +def _get_y_axis_range(fig: go.Figure, yaxis_ref: str) -> tuple: + """Get the (y_min, y_max) range of a y-axis for hover trace positioning. + + Attempts to read the axis range from the figure layout. If the range + is not explicitly set (auto-ranged), computes the range from the + trace data on that axis. Falls back to (-1, 1) if no data is available. + + Args: + fig: The Plotly figure to inspect. + yaxis_ref: Y-axis reference string (e.g., 'y', 'y2', 'y3'). + + Returns: + Tuple of (y_min, y_max) for the axis range. + """ + # Map 'y' -> 'yaxis', 'y2' -> 'yaxis2', etc. + layout_key = "yaxis" if yaxis_ref == "y" else f"yaxis{yaxis_ref[1:]}" + axis_obj = fig.layout[layout_key] + + if axis_obj is not None and axis_obj.range is not None: + return (axis_obj.range[0], axis_obj.range[1]) + + # Fallback: compute from trace data on this axis + y_values = [] + for trace in fig.data: + trace_yaxis = getattr(trace, "yaxis", None) or "y" + if trace_yaxis == yaxis_ref and hasattr(trace, "y") and trace.y is not None: + for v in trace.y: + if isinstance(v, (int, float)): + y_values.append(v) + + if y_values: + return (min(y_values), max(y_values)) + + return (-1, 1) + + +def _clean_pulse_name(name: str | None) -> str: + name = name or "unknown" + return name.removeprefix("OriginPulseName=") + + +def _parse_plotly_figure_from_html(html: str) -> go.Figure: + """Parse a Plotly Figure from the SDK-generated HTML.""" + plotly_call_match = re.search(r'Plotly\.newPlot\s*\(\s*["\'][^"\']+["\']\s*,\s*', html) + if not plotly_call_match: + raise RuntimeError("Could not find Plotly.newPlot call in HTML.") + + data_start = plotly_call_match.end() + decoder = json.JSONDecoder() + try: + data, data_end_idx = decoder.raw_decode(html, data_start) + separator_match = re.search(r"\s*,\s*", html[data_end_idx:]) + if not separator_match: + raise RuntimeError("Could not locate layout separator after data.") + layout_start = data_end_idx + separator_match.end() + layout, _ = decoder.raw_decode(html, layout_start) + except (json.JSONDecodeError, ValueError) as exc: + raise RuntimeError(f"Could not parse Plotly figure from HTML: {exc}") from exc + + return go.Figure(data=data, layout=layout) + + +def _load_sdk_figure(report: Any, samples: Any, controllers: List[str] | None) -> go.Figure: + """Generate and parse the SDK HTML into a Plotly Figure.""" + try: + tmpdir = tempfile.mkdtemp(prefix="enhanced_wfr_") + save_path_for_sdk = str(Path(tmpdir) / "report") + report.create_plot(samples=samples, controllers=controllers, plot=False, save_path=save_path_for_sdk) + + html_files = list(Path(tmpdir).glob("*.html")) + if not html_files: + raise RuntimeError("SDK did not generate an HTML file.") + if len(html_files) > 1: + warnings.warn(f"SDK generated multiple HTML files; using {html_files[0].name}.") + + html_content = html_files[0].read_text(encoding="utf-8") + return _parse_plotly_figure_from_html(html_content) + finally: + if "tmpdir" in locals(): + shutil.rmtree(tmpdir, ignore_errors=True) + + +def _extract_source_markers( + waveforms: List[Dict[str, Any]], + config: VerticalMarkerConfig, + *, + op_type: OperationType, + elem_field: str, + name_field: str, + ports_field: str, + time_start: str, + time_end: str | None = None, + time_length: str | None = None, +) -> List[TimingMarker]: + if (time_end is None) == (time_length is None): + raise ValueError("Exactly one of time_end or time_length must be provided.") + + markers: List[TimingMarker] = [] + for wf in waveforms: + element = wf.get(elem_field, "") + if not config.should_include(element): + continue + + start_time = wf.get(time_start, 0) + if time_end is not None: + end_time = wf.get(time_end, 0) + duration = end_time - start_time + else: + duration = wf.get(time_length, 0) + end_time = start_time + duration + + if duration < config.min_duration_ns: + continue + + pulse_name = _clean_pulse_name(wf.get(name_field)) + controller = wf.get("controller", "") + fem = wf.get("fem", 0) + ports = wf.get(ports_field, []) + + markers.append( + TimingMarker( + timestamp_ns=start_time, + marker_type="start", + operation_type=op_type, + element=element, + pulse_name=pulse_name, + controller=controller, + fem=fem, + output_ports=ports, + ) + ) + markers.append( + TimingMarker( + timestamp_ns=end_time, + marker_type="end", + operation_type=op_type, + element=element, + pulse_name=pulse_name, + controller=controller, + fem=fem, + output_ports=ports, + ) + ) + + return markers + + +# === CORE FUNCTIONS === +def extract_timing_markers(waveform_dict: Dict[str, Any], config: VerticalMarkerConfig) -> List[TimingMarker]: + """Extract timing markers from waveform report dictionary. + + Args: + waveform_dict: Output from waveform_report.to_dict() + config: Configuration for filtering markers + + Returns: + Sorted list of TimingMarker objects + + Reference - waveform_dict structure: + { + "analog_waveforms": [ + {"timestamp": 224, "length": 40, "element": "q1_xy", + "pulse_name": "OriginPulseName=x180", "controller": "con1", + "fem": 1, "output_ports": [1], ...} + ], + "digital_waveforms": [...], + "adc_acquisitions": [ + {"start_time": 500, "end_time": 1500, "quantum_element": "rr1", + "controller": "con1", "fem": 1, "adc_ports": [1, 2], ...} + ] + } + """ + markers: List[TimingMarker] = [] + + if config.show_analog_markers: + markers.extend( + _extract_source_markers( + waveform_dict.get("analog_waveforms", []), + config, + op_type="analog", + elem_field="element", + name_field="pulse_name", + ports_field="output_ports", + time_start="timestamp", + time_length="length", + ) + ) + + if config.show_digital_markers: + markers.extend( + _extract_source_markers( + waveform_dict.get("digital_waveforms", []), + config, + op_type="digital", + elem_field="element", + name_field="pulse_name", + ports_field="output_ports", + time_start="timestamp", + time_length="length", + ) + ) + + if config.show_adc_markers: + markers.extend( + _extract_source_markers( + waveform_dict.get("adc_acquisitions", []), + config, + op_type="adc", + elem_field="quantum_element", + name_field="process", + ports_field="adc_ports", + time_start="start_time", + time_end="end_time", + ) + ) + + markers.sort(key=lambda m: m.timestamp_ns) + return markers + + +def _make_marker_shapes(markers: List[TimingMarker], config: VerticalMarkerConfig) -> List[Dict[str, Any]]: + """Generate Plotly shape dictionaries for vertical lines.""" + return [ + { + "type": "line", + "xref": "x", + "yref": "paper", + "x0": marker.timestamp_ns, + "x1": marker.timestamp_ns, + "y0": 0, + "y1": 1, + "line": { + "color": config.color_for(marker.marker_type), + "width": config.line_width, + "dash": config.line_dash, + }, + **({"name": marker.hover_text} if config.show_hover_info else {}), + } + for marker in markers + ] + + +def _add_hover_traces( + fig: go.Figure, + markers: List[TimingMarker], + config: VerticalMarkerConfig, +) -> None: + """Add batched invisible scatter traces for hover tooltips on vertical markers.""" + if not config.show_hover_info or not markers: + return + + y_to_x_axis: Dict[str, str] = {} + y_axes = set() + for trace in fig.data: + yaxis_ref = getattr(trace, "yaxis", None) or "y" + xaxis_ref = getattr(trace, "xaxis", None) or "x" + y_axes.add(yaxis_ref) + if yaxis_ref not in y_to_x_axis: + y_to_x_axis[yaxis_ref] = xaxis_ref + if not y_axes: + y_axes = {"y"} + + points_per_marker = max(1, int(config.hover_points_per_marker)) + + for yaxis_ref in sorted(y_axes, key=lambda s: int(s[1:]) if len(s) > 1 else 0): + y_min, y_max = _get_y_axis_range(fig, yaxis_ref) + xaxis_ref = y_to_x_axis.get(yaxis_ref, "x") + + if points_per_marker == 1: + y_points = [(y_min + y_max) / 2] + else: + y_points = [y_min + i * (y_max - y_min) / (points_per_marker - 1) for i in range(points_per_marker)] + + expanded_x = [] + expanded_y = [] + expanded_texts = [] + expanded_colors = [] + for marker in markers: + color = config.color_for(marker.marker_type) + for y_pt in y_points: + expanded_x.append(marker.timestamp_ns) + expanded_y.append(y_pt) + expanded_texts.append(marker.hover_text) + expanded_colors.append(color) + + batch_trace = go.Scatter( + x=expanded_x, + y=expanded_y, + mode="markers", + marker=dict(size=config.hover_marker_size, opacity=0, color=expanded_colors), + hoverinfo="text", + hovertext=expanded_texts, + showlegend=False, + xaxis=xaxis_ref, + yaxis=yaxis_ref, + ) + fig.add_trace(batch_trace) + + +# === MAIN API === +def create_waveform_plot_with_markers( + waveform_report, + samples=None, + marker_config: VerticalMarkerConfig | None = None, + controllers: list[str] | None = None, + plot: bool = True, + save_dir: str | None = None, +) -> go.Figure: + """Create an enhanced waveform visualization with vertical timing markers. + + This function wraps the QM SDK's waveform report visualization and adds + vertical lines marking operation start/end times across all subplots. + + Args: + waveform_report: WaveformReport from job.get_simulated_waveform_report() + samples: SimulatorSamples from job.get_simulated_samples() (optional) + marker_config: Configuration for vertical markers (uses defaults if None) + controllers: List of controllers to include (None = all) + plot: Whether to display the plot in browser + save_dir: Directory to save the HTML file (None = don't save) + + Returns: + Plotly Figure object for further customization + """ + config = marker_config or VerticalMarkerConfig() + + waveform_dict = waveform_report.to_dict() + markers = extract_timing_markers(waveform_dict, config) + if config.merge_threshold_ns > 0: + markers = _merge_nearby_markers(markers, config.merge_threshold_ns) + + try: + fig = _load_sdk_figure(waveform_report, samples, controllers) + except Exception as exc: + warnings.warn(f"Could not load SDK HTML: {exc}") + fig = go.Figure() + fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines")) + fig.update_layout( + title="Waveform Report (marker overlay only - could not parse SDK output)", xaxis_title="Time (ns)" + ) + + shapes = _make_marker_shapes(markers, config) + existing_shapes = list(fig.layout.shapes) if fig.layout.shapes else [] + fig.update_layout(shapes=existing_shapes + shapes) + + _add_hover_traces(fig, markers, config) + + if plot: + fig.show() + + if save_dir: + save_dir_path = Path(save_dir) + save_dir_path.mkdir(parents=True, exist_ok=True) + job_id = getattr(waveform_report, "job_id", "unknown") + filename = f"enhanced_waveform_report_{job_id}.html" + output_path = save_dir_path / filename + fig.write_html(str(output_path)) + logger.info("Enhanced waveform report saved to: %s", output_path) + + return fig diff --git a/qualang_tools/wirer/connectivity/__init__.py b/qualang_tools/wirer/connectivity/__init__.py index fcba10f9..459508d0 100644 --- a/qualang_tools/wirer/connectivity/__init__.py +++ b/qualang_tools/wirer/connectivity/__init__.py @@ -1,4 +1,3 @@ from .connectivity import Connectivity - __all__ = ["Connectivity"] diff --git a/tests/waveform_report/test_enhanced_waveform_report.py b/tests/waveform_report/test_enhanced_waveform_report.py new file mode 100644 index 00000000..3bda476f --- /dev/null +++ b/tests/waveform_report/test_enhanced_waveform_report.py @@ -0,0 +1,628 @@ +""" +Replacement tests for enhanced_waveform_report.py +Run with: python -m pytest tests/waveform_report/test_enhanced_waveform_report.py -v +""" + +import json +import os +import shutil +import tempfile +import warnings + +import pytest +import plotly.graph_objects as go + +import qualang_tools.waveform_report.enhanced_report as ewr +from qualang_tools.waveform_report import ( + VerticalMarkerConfig, + TimingMarker, + create_waveform_plot_with_markers, + extract_timing_markers, +) +from qualang_tools.waveform_report.enhanced_report import ( + _merge_nearby_markers, + _make_marker_shapes, + _get_y_axis_range, + _parse_plotly_figure_from_html, + _clean_pulse_name, +) + + +def _m(**overrides): + base = { + "timestamp_ns": 100.0, + "marker_type": "start", + "operation_type": "analog", + "element": "q1_xy", + "pulse_name": "x180", + "controller": "con1", + "fem": 1, + "output_ports": [1], + } + base.update(overrides) + return TimingMarker(**base) + + +def _wf(**overrides): + base = { + "timestamp": 100, + "length": 40, + "element": "q1_xy", + "pulse_name": "OriginPulseName=x180", + "controller": "con1", + "fem": 1, + "output_ports": [1], + } + base.update(overrides) + return base + + +def _wf_dict(analog=None, digital=None, adc=None): + return { + "analog_waveforms": analog or [], + "digital_waveforms": digital or [], + "adc_acquisitions": adc or [], + } + + +def _hover_traces(fig): + """Filter a figure's traces to only invisible hover-tooltip scatter traces.""" + return [t for t in fig.data if t.hoverinfo == "text" and t.marker is not None and t.marker.opacity == 0] + + +def _run_orchestrator(report, **overrides): + """Call create_waveform_plot_with_markers with safe defaults.""" + defaults = dict( + samples=None, + marker_config=None, + controllers=None, + plot=False, + save_dir=None, + ) + defaults.update(overrides) + return create_waveform_plot_with_markers(report, **defaults) + + +class DummyReport: + """Mock WaveformReport that mimics QM SDK's create_plot save_path behavior.""" + + def __init__(self, waveform_dict, data=None, layout=None, raise_on_create=None): + self._waveform_dict = waveform_dict + self.job_id = "dummy_job" + self._data = data or [{"x": [0, 1], "y": [0, 1], "type": "scatter"}] + self._layout = layout or {"title": "Dummy Plot"} + self._raise_on_create = raise_on_create + + def to_dict(self): + return self._waveform_dict + + def create_plot(self, samples=None, controllers=None, plot=True, save_path=None): + if self._raise_on_create is not None: + raise self._raise_on_create + + if save_path is None: + return + + save_dir = os.path.dirname(save_path) + if not save_dir: + save_dir = "." + + html = ( + "
" + "" + ) + + output_path = os.path.join(save_dir, f"waveform_report_con1_{self.job_id}.html") + with open(output_path, "w", encoding="utf-8") as f: + f.write(html) + + +@pytest.fixture +def single_analog_wf_dict(): + """Waveform dict with one analog pulse (q1_xy, x180, 100-140ns).""" + return _wf_dict(analog=[_wf()]) + + +@pytest.fixture +def single_analog_report(single_analog_wf_dict): + """DummyReport wrapping a single analog pulse.""" + return DummyReport(single_analog_wf_dict) + + +def test_config_defaults(): + config = VerticalMarkerConfig() + assert config.show_analog_markers is True + assert config.show_digital_markers is False + assert config.start_line_color == "rgba(34, 139, 34, 0.6)" + assert config.end_line_color == "rgba(178, 34, 34, 0.6)" + + +@pytest.mark.parametrize( + ("element", "include", "exclude", "expected"), + [ + ("q1_xy", None, None, True), + ("q1_xy", ["q1_xy", "q2_xy"], None, True), + ("rr1", ["q1_xy", "q2_xy"], None, False), + ("q2_xy", ["q1_xy", "q2_xy"], ["q2_xy"], False), + ], +) +def test_should_include(element, include, exclude, expected): + config = VerticalMarkerConfig( + elements_to_include=include, + elements_to_exclude=exclude, + ) + assert config.should_include(element) is expected + + +def test_color_for(): + config = VerticalMarkerConfig( + start_line_color="rgba(0, 0, 255, 0.5)", + end_line_color="rgba(255, 0, 0, 0.5)", + ) + assert config.color_for("start") == "rgba(0, 0, 255, 0.5)" + assert config.color_for("end") == "rgba(255, 0, 0, 0.5)" + + +@pytest.mark.parametrize( + ("marker_type", "op_type", "element", "pulse", "expected"), + [ + ("start", "analog", "q1_xy", "x180", "Start [Pulse]: x180 (q1_xy)"), + ("end", "digital", "rr1", "ON", "End [Digital]: ON (rr1)"), + ("start", "adc", "rr1", "acquisition", "Start [Integration]: acquisition (rr1)"), + ], +) +def test_hover_text(marker_type, op_type, element, pulse, expected): + marker = _m( + marker_type=marker_type, + operation_type=op_type, + element=element, + pulse_name=pulse, + ) + assert marker.hover_text == expected + + +@pytest.mark.parametrize( + ("name", "expected"), + [ + (None, "unknown"), + ("", "unknown"), + ("OriginPulseName=x180", "x180"), + ("x180", "x180"), + ], +) +def test_clean_pulse_name(name, expected): + assert _clean_pulse_name(name) == expected + + +def test_extract_analog(): + waveform_dict = _wf_dict(analog=[_wf()]) + + markers = extract_timing_markers(waveform_dict, VerticalMarkerConfig()) + + assert len(markers) == 2 + assert markers[0].timestamp_ns == 100 + assert markers[1].timestamp_ns == 140 + assert markers[0].pulse_name == "x180" + + +def test_extract_adc(): + waveform_dict = _wf_dict( + adc=[ + { + "start_time": 500, + "end_time": 1500, + "quantum_element": "rr1", + "process": "full", + "controller": "con1", + "fem": 1, + "adc_ports": [1, 2], + } + ] + ) + + markers = extract_timing_markers(waveform_dict, VerticalMarkerConfig()) + + assert len(markers) == 2 + assert markers[0].timestamp_ns == 500 + assert markers[0].operation_type == "adc" + assert markers[0].element == "rr1" + assert markers[0].output_ports == [1, 2] + assert markers[1].timestamp_ns == 1500 + + +def test_extract_digital_opt_in(): + waveform_dict = _wf_dict( + digital=[ + _wf( + timestamp=200, + length=100, + element="rr1", + pulse_name="ON", + output_ports=[5], + ) + ] + ) + + markers = extract_timing_markers(waveform_dict, VerticalMarkerConfig()) + assert markers == [] + + config = VerticalMarkerConfig(show_digital_markers=True) + markers = extract_timing_markers(waveform_dict, config) + assert len(markers) == 2 + assert markers[0].operation_type == "digital" + + +@pytest.mark.parametrize( + ("include", "exclude", "expected_elements"), + [ + (["q1_xy"], None, {"q1_xy"}), + (None, ["q1_xy"], {"q2_xy"}), + (["q1_xy", "q2_xy"], ["q2_xy"], {"q1_xy"}), + ], +) +def test_element_filtering(include, exclude, expected_elements): + waveform_dict = _wf_dict( + analog=[ + _wf(timestamp=100, element="q1_xy", output_ports=[1]), + _wf(timestamp=200, element="q2_xy", output_ports=[2]), + ] + ) + + config = VerticalMarkerConfig( + elements_to_include=include, + elements_to_exclude=exclude, + ) + markers = extract_timing_markers(waveform_dict, config) + + assert {m.element for m in markers} == expected_elements + + +def test_duration_filtering(): + waveform_dict = _wf_dict( + analog=[ + _wf(length=5, pulse_name="short"), + _wf(timestamp=200, pulse_name="long"), + ] + ) + + config = VerticalMarkerConfig(min_duration_ns=10.0) + markers = extract_timing_markers(waveform_dict, config) + + assert len(markers) == 2 + assert markers[0].pulse_name == "long" + + +def test_empty_input(): + assert extract_timing_markers({}, VerticalMarkerConfig()) == [] + + +def test_sorted_output(): + waveform_dict = _wf_dict( + analog=[ + _wf(timestamp=300, element="q2_xy", output_ports=[2]), + _wf(), + ] + ) + + markers = extract_timing_markers(waveform_dict, VerticalMarkerConfig()) + timestamps = [m.timestamp_ns for m in markers] + assert timestamps == sorted(timestamps) + + +def test_null_pulse_name_regression(): + waveform_dict = _wf_dict(analog=[_wf(pulse_name=None)]) + + markers = extract_timing_markers(waveform_dict, VerticalMarkerConfig()) + assert markers[0].pulse_name == "unknown" + + +def test_merge_within_threshold(): + markers = [ + _m(timestamp_ns=100, marker_type="start"), + _m(timestamp_ns=102, marker_type="start"), + _m(timestamp_ns=104, marker_type="start"), + ] + + merged = _merge_nearby_markers(markers, threshold_ns=5.0) + + assert len(merged) == 1 + assert merged[0].timestamp_ns == 102.0 + + +def test_merge_preserves_distant(): + markers = [ + _m(timestamp_ns=100, marker_type="start"), + _m(timestamp_ns=200, marker_type="start"), + ] + + merged = _merge_nearby_markers(markers, threshold_ns=5.0) + + assert len(merged) == 2 + assert [m.timestamp_ns for m in merged] == [100, 200] + + +def test_merge_groups_by_marker_type(): + markers = [ + _m(timestamp_ns=100, marker_type="start"), + _m(timestamp_ns=102, marker_type="end"), + ] + + merged = _merge_nearby_markers(markers, threshold_ns=5.0) + + assert len(merged) == 2 + assert {m.marker_type for m in merged} == {"start", "end"} + + +def test_merge_groups_by_operation_type(): + markers = [ + _m(timestamp_ns=100, marker_type="start", operation_type="analog"), + _m(timestamp_ns=102, marker_type="start", operation_type="digital"), + ] + + merged = _merge_nearby_markers(markers, threshold_ns=5.0) + + assert len(merged) == 2 + assert {m.operation_type for m in merged} == {"analog", "digital"} + + +def test_merge_preserves_metadata(): + markers = [ + _m( + timestamp_ns=100, + element="q1_xy", + pulse_name="x180", + controller="con1", + fem=1, + output_ports=[1], + ), + _m( + timestamp_ns=102, + element="q2_xy", + pulse_name="y90", + controller="con2", + fem=2, + output_ports=[2, 3], + ), + ] + + merged = _merge_nearby_markers(markers, threshold_ns=5.0) + + assert len(merged) == 1 + assert merged[0].element == "q1_xy" + assert merged[0].pulse_name == "x180" + assert merged[0].controller == "con1" + assert merged[0].fem == 1 + assert merged[0].output_ports == [1] + + +@pytest.mark.parametrize( + ("markers", "threshold_ns", "expected_len"), + [ + ([], 5.0, 0), + ([_m(timestamp_ns=100), _m(timestamp_ns=102)], 0, 2), + ([_m(timestamp_ns=100), _m(timestamp_ns=102)], -1.0, 2), + ], +) +def test_merge_edge_cases(markers, threshold_ns, expected_len): + merged = _merge_nearby_markers(markers, threshold_ns=threshold_ns) + assert len(merged) == expected_len + + +def test_shape_structure_and_colors(): + markers = [ + _m(timestamp_ns=100, marker_type="start"), + _m(timestamp_ns=140, marker_type="end"), + ] + config = VerticalMarkerConfig( + start_line_color="rgba(0, 0, 255, 0.5)", + end_line_color="rgba(255, 0, 0, 0.5)", + show_hover_info=True, + ) + shapes = _make_marker_shapes(markers, config) + + assert len(shapes) == 2 + for shape, marker in zip(shapes, markers): + assert shape["type"] == "line" + assert shape["xref"] == "x" + assert shape["yref"] == "paper" + assert shape["x0"] == marker.timestamp_ns + assert shape["x1"] == marker.timestamp_ns + assert shape["y0"] == 0 + assert shape["y1"] == 1 + assert isinstance(shape["line"], dict) + assert shape["name"] == marker.hover_text + + assert shapes[0]["line"]["color"] == "rgba(0, 0, 255, 0.5)" + assert shapes[1]["line"]["color"] == "rgba(255, 0, 0, 0.5)" + + +def test_shape_hover_name_absent_when_disabled(): + config = VerticalMarkerConfig(show_hover_info=False) + marker = _m(timestamp_ns=100, marker_type="start") + shapes = _make_marker_shapes([marker], config) + + assert "name" not in shapes[0] + + +def test_parse_valid_html(): + html = ( + "" + "" + ) + fig = _parse_plotly_figure_from_html(html) + assert isinstance(fig, go.Figure) + assert len(fig.data) == 1 + assert fig.layout.title.text == "Plot" + + +def test_parse_invalid_html(): + with pytest.raises(RuntimeError): + _parse_plotly_figure_from_html("not a plotly html") + + +def test_explicit_range(): + fig = go.Figure() + fig.update_layout(yaxis=dict(range=[-2, 2])) + + assert _get_y_axis_range(fig, "y") == (-2, 2) + + +def test_computed_from_traces(): + fig = go.Figure() + fig.add_trace(go.Scatter(x=[0, 1], y=[10, 20])) + + assert _get_y_axis_range(fig, "y") == (10, 20) + + +def test_empty_figure_default(): + fig = go.Figure() + assert _get_y_axis_range(fig, "y") == (-1, 1) + + +def test_returns_figure_with_shapes(single_analog_report): + fig = _run_orchestrator(single_analog_report) + + assert isinstance(fig, go.Figure) + assert fig.layout.shapes is not None + assert len(fig.layout.shapes) == 2 + + config = VerticalMarkerConfig() + colors = [shape["line"]["color"] for shape in fig.layout.shapes] + assert config.start_line_color in colors + assert config.end_line_color in colors + + +def test_hover_traces_added(single_analog_report): + config = VerticalMarkerConfig(show_hover_info=True) + fig = _run_orchestrator(single_analog_report, marker_config=config) + + hover_traces = _hover_traces(fig) + assert len(hover_traces) >= 1 + + +def test_hover_disabled_no_traces(single_analog_report): + config = VerticalMarkerConfig(show_hover_info=False) + fig = _run_orchestrator(single_analog_report, marker_config=config) + + hover_traces = _hover_traces(fig) + assert len(hover_traces) == 0 + + +def test_hover_correct_text(single_analog_report): + config = VerticalMarkerConfig(show_hover_info=True, merge_threshold_ns=0) + fig = _run_orchestrator(single_analog_report, marker_config=config) + + hover_traces = _hover_traces(fig) + all_texts = [] + for trace in hover_traces: + all_texts.extend(trace.hovertext) + + assert "Start [Pulse]: x180 (q1_xy)" in all_texts + assert "End [Pulse]: x180 (q1_xy)" in all_texts + + +def test_hover_multi_subplot(monkeypatch): + waveform_dict = _wf_dict(analog=[_wf()]) + data = [ + {"x": [0, 100, 200], "y": [0.1, 0.5, 0.2], "type": "scatter", "yaxis": "y"}, + {"x": [0, 100, 200], "y": [-0.3, 0.0, 0.3], "type": "scatter", "yaxis": "y2"}, + {"x": [0, 100, 200], "y": [1, 0, 1], "type": "scatter", "yaxis": "y3"}, + ] + layout = { + "title": "Dummy Plot", + "yaxis": {"domain": [0.7, 1.0]}, + "yaxis2": {"domain": [0.35, 0.65]}, + "yaxis3": {"domain": [0.0, 0.3]}, + } + waveform_report = DummyReport(waveform_dict, data=data, layout=layout) + + def _stub_load_sdk_figure(report, samples, controllers): + return go.Figure(data=data, layout=layout) + + monkeypatch.setattr(ewr, "_load_sdk_figure", _stub_load_sdk_figure) + + config = VerticalMarkerConfig(show_hover_info=True, merge_threshold_ns=0) + fig = _run_orchestrator(waveform_report, marker_config=config) + + hover_traces = _hover_traces(fig) + assert len(hover_traces) == 3 + + yaxes_used = {getattr(t, "yaxis", "y") or "y" for t in hover_traces} + assert yaxes_used == {"y", "y2", "y3"} + + for trace in hover_traces: + xax = getattr(trace, "xaxis", None) or "x" + assert xax == "x" + + +def test_hover_points_per_marker(single_analog_report): + config = VerticalMarkerConfig( + show_hover_info=True, + merge_threshold_ns=0, + hover_points_per_marker=5, + ) + fig = _run_orchestrator(single_analog_report, marker_config=config) + + hover_traces = _hover_traces(fig) + assert len(hover_traces) == 1 + assert len(hover_traces[0].x) == 10 + assert len(hover_traces[0].y) == 10 + assert min(hover_traces[0].y) == 0 + assert max(hover_traces[0].y) == 1 + + +def test_fallback_on_sdk_failure(): + waveform_report = DummyReport(_wf_dict(), raise_on_create=RuntimeError("boom")) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + fig = _run_orchestrator(waveform_report) + + assert isinstance(fig, go.Figure) + assert fig.layout.title.text == ("Waveform Report (marker overlay only - could not parse SDK output)") + assert any("Could not load SDK HTML:" in str(w.message) for w in caught) + + +def test_temp_dir_cleaned_up(monkeypatch): + original_mkdtemp = tempfile.mkdtemp + created = {} + rmtree_calls = [] + + def _mkdtemp(prefix="tmp", **kwargs): + path = original_mkdtemp(prefix=prefix) + created["path"] = path + return path + + monkeypatch.setattr(ewr.tempfile, "mkdtemp", _mkdtemp) + + def _rmtree(path, ignore_errors=False): + rmtree_calls.append(path) + try: + os.rmdir(path) + except OSError: + if not ignore_errors: + raise + + monkeypatch.setattr(ewr.shutil, "rmtree", _rmtree) + + waveform_report = DummyReport(_wf_dict()) + monkeypatch.setattr(waveform_report, "create_plot", lambda **kw: None) + + fig = _run_orchestrator(waveform_report) + + assert isinstance(fig, go.Figure) + + created_path = created.get("path") + assert created_path is not None + assert created_path in rmtree_calls + if os.path.exists(created_path): + shutil.rmtree(created_path, ignore_errors=True) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])