From cfd0b6e66dd9972a96941b26c07dade06bd8050f Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Sun, 8 Mar 2026 12:07:04 -0300 Subject: [PATCH 01/10] feat: add Mermaid diagrams, transition tables, and __format__ support - Add MermaidRenderer that converts DiagramGraph IR to Mermaid stateDiagram-v2 source (compound, parallel, history, guards, etc.) - Add TransitionTableRenderer for markdown and RST table output - Add MermaidGraphMachine facade mirroring DotGraphMachine - Add __format__ to StateChart and StateMachineMetaclass supporting dot, mermaid, md/markdown, and rst format specs - Extend CLI with --format option (mermaid, md, rst) and stdout support - Add :format: option to Sphinx statemachine-diagram directive with sphinxcontrib-mermaid integration - Update docs with new sections and doctests --- docs/conf.py | 1 + docs/diagram.md | 117 ++++ docs/releases/3.1.0.md | 19 + pyproject.toml | 1 + statemachine/contrib/diagram/__init__.py | 68 +- .../contrib/diagram/renderers/mermaid.py | 225 +++++++ .../contrib/diagram/renderers/table.py | 105 +++ statemachine/contrib/diagram/sphinx_ext.py | 47 ++ statemachine/factory.py | 24 + statemachine/statemachine.py | 24 + tests/test_contrib_diagram.py | 275 ++++++++ tests/test_mermaid_renderer.py | 620 ++++++++++++++++++ tests/test_transition_table.py | 201 ++++++ uv.lock | 36 + 14 files changed, 1758 insertions(+), 5 deletions(-) create mode 100644 statemachine/contrib/diagram/renderers/mermaid.py create mode 100644 statemachine/contrib/diagram/renderers/table.py create mode 100644 tests/test_mermaid_renderer.py create mode 100644 tests/test_transition_table.py diff --git a/docs/conf.py b/docs/conf.py index 18846738..59e518b5 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -52,6 +52,7 @@ "sphinx_gallery.gen_gallery", "sphinx_copybutton", "statemachine.contrib.diagram.sphinx_ext", + "sphinxcontrib.mermaid", ] autosectionlabel_prefix_document = True diff --git a/docs/diagram.md b/docs/diagram.md index 6b09b1ba..9aa19ffc 100644 --- a/docs/diagram.md +++ b/docs/diagram.md @@ -110,6 +110,118 @@ send events before rendering: python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine diagram.png --events cycle cycle cycle ``` +Use `--format` to produce **Mermaid source** or a **transition table** instead +of a Graphviz image: + +```bash +# Mermaid stateDiagram-v2 +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.mmd --format mermaid + +# Markdown transition table +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.md --format md + +# RST transition table +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.rst --format rst +``` + +Use `-` as the output file to write to stdout (handy for piping): + +```bash +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine - --format mermaid +``` + + +## Text representations with `format()` + +State machines support Python's built-in `format()` protocol for quick text +output — no diagram imports needed: + +```py +>>> from tests.examples.traffic_light_machine import TrafficLightMachine +>>> sm = TrafficLightMachine() +>>> print(f"{sm:mermaid}") +stateDiagram-v2 + direction LR + state "Green" as green + state "Yellow" as yellow + state "Red" as red + [*] --> green + green --> yellow : cycle + yellow --> red : cycle + red --> green : cycle + + classDef active fill:#40E0D0,stroke:#333 + green:::active + + +>>> print(f"{sm:md}") +| State | Event | Guard | Target | +| ------ | ----- | ----- | ------ | +| Green | cycle | | Yellow | +| Yellow | cycle | | Red | +| Red | cycle | | Green | + + +``` + +Works on **classes** too (no active-state highlighting): + +```py +>>> print(f"{TrafficLightMachine:mermaid}") +stateDiagram-v2 + direction LR + state "Green" as green + state "Yellow" as yellow + state "Red" as red + [*] --> green + green --> yellow : cycle + yellow --> red : cycle + red --> green : cycle + + +``` + +Supported format specs: `dot`, `mermaid`, `md` (or `markdown`), `rst`. +An empty spec falls back to `repr()`. + +The `dot` format returns the Graphviz DOT language source (same output as +`sm._graph().to_string()`): + +```py +>>> print(f"{sm:dot}") # doctest: +ELLIPSIS +digraph TrafficLightMachine { +... +} + +``` + + +## Mermaid output + +The `MermaidGraphMachine` facade generates +[Mermaid `stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) +source text from any state machine — no external dependencies required: + +```py +>>> from statemachine.contrib.diagram import MermaidGraphMachine +>>> from tests.examples.traffic_light_machine import TrafficLightMachine +>>> print(MermaidGraphMachine(TrafficLightMachine).get_mermaid()) +stateDiagram-v2 + direction LR + state "Green" as green + state "Yellow" as yellow + state "Red" as red + [*] --> green + green --> yellow : cycle + yellow --> red : cycle + red --> green : cycle + + +``` + +Compound states, parallel regions, history pseudo-states, guards, and +active-state highlighting are all supported. + ## Sphinx directive @@ -190,6 +302,11 @@ The directive supports the same layout options as the standard `image` and : Events to send in sequence. When present, the machine is instantiated and each event is sent before rendering. +`:format:` *(string)* +: Output format. Use `mermaid` to render via + [sphinxcontrib-mermaid](https://github.com/mgaitan/sphinxcontrib-mermaid) + instead of Graphviz SVG. Default: DOT/SVG. + **Image/figure options:** `:caption:` *(string)* diff --git a/docs/releases/3.1.0.md b/docs/releases/3.1.0.md index 34c6a3f5..714450f4 100644 --- a/docs/releases/3.1.0.md +++ b/docs/releases/3.1.0.md @@ -56,6 +56,25 @@ instantiate the machine and send events before rendering, highlighting the current active state — matching the Sphinx directive's `:events:` option. See {ref}`diagram:Command line` for details. + +### Mermaid diagram support + +State machines can now be rendered as +[Mermaid `stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) +source text — no Graphviz installation required. + +Three ways to use it: + +- **`format()` / f-strings:** `f"{sm:mermaid}"`, `f"{sm:md}"`, `f"{sm:rst}"` — + works on both instances and classes. +- **CLI:** `python -m statemachine.contrib.diagram MyMachine - --format mermaid` +- **Sphinx directive:** `:format: mermaid` renders via `sphinxcontrib-mermaid`. + +A new `TransitionTableRenderer` produces markdown or RST transition tables +from the same diagram IR. See {ref}`diagram:Text representations with format()` +and {ref}`diagram:Mermaid output` for details. + + ### Bugfixes in 3.1.0 - Fixes silent misuse of `Event()` with multiple positional arguments. Passing more than one diff --git a/pyproject.toml b/pyproject.toml index b05ccfb3..40b3b9c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dev = [ "sphinx-autobuild; python_version >'3.8'", "furo >=2024.5.6; python_version >'3.8'", "sphinx-copybutton >=0.5.2; python_version >'3.8'", + "sphinxcontrib-mermaid; python_version >'3.8'", "pdbr>=0.8.9; python_version >'3.8'", "babel >=2.16.0; python_version >='3.8'", "pytest-xdist>=3.6.1", diff --git a/statemachine/contrib/diagram/__init__.py b/statemachine/contrib/diagram/__init__.py index a2e31a71..f7ce12d6 100644 --- a/statemachine/contrib/diagram/__init__.py +++ b/statemachine/contrib/diagram/__init__.py @@ -5,6 +5,8 @@ from .extract import extract from .renderers.dot import DotRenderer from .renderers.dot import DotRendererConfig +from .renderers.mermaid import MermaidRenderer +from .renderers.mermaid import MermaidRendererConfig class DotGraphMachine: @@ -56,6 +58,32 @@ def __call__(self): return self.get_graph() +class MermaidGraphMachine: + """Facade for generating Mermaid stateDiagram-v2 source from a state machine.""" + + direction = "LR" + active_fill = "#40E0D0" + active_stroke = "#333" + + def __init__(self, machine): + self.machine = machine + + def _build_config(self) -> MermaidRendererConfig: + return MermaidRendererConfig( + direction=self.direction, + active_fill=self.active_fill, + active_stroke=self.active_stroke, + ) + + def get_mermaid(self) -> str: + ir = extract(self.machine) + renderer = MermaidRenderer(config=self._build_config()) + return renderer.render(ir) + + def __call__(self) -> str: + return self.get_mermaid() + + def quickchart_write_svg(sm, path: str): """ If the default dependency of GraphViz installed locally doesn't work for you. As an option, @@ -135,7 +163,7 @@ def import_sm(qualname): return smclass -def write_image(qualname, out, events=None): +def write_image(qualname, out, events=None, fmt=None): """ Given a `qualname`, that is the fully qualified dotted path to a StateMachine classes, imports the class and generates a dot graph using the `pydot` lib. @@ -146,7 +174,13 @@ def write_image(qualname, out, events=None): If `events` is provided, the machine is instantiated and each event is sent before rendering, so the diagram highlights the current active state. + + If `fmt` is provided, it overrides the output format: ``"mermaid"`` writes + Mermaid source text, ``"md"``/``"rst"`` write a transition table. + Use ``out="-"`` to write to stdout. """ + import sys + smclass = import_sm(qualname) if events: @@ -156,9 +190,27 @@ def write_image(qualname, out, events=None): else: machine = smclass - graph = DotGraphMachine(machine).get_graph() - out_extension = out.rsplit(".", 1)[1] - graph.write(out, format=out_extension) + if fmt in ("mermaid", "md", "rst"): + if fmt == "mermaid": + text = MermaidGraphMachine(machine).get_mermaid() + else: + from .renderers.table import TransitionTableRenderer + + ir = extract(machine) + text = TransitionTableRenderer().render(ir, fmt=fmt) + + if out == "-": + sys.stdout.write(text) + else: + with open(out, "w") as f: + f.write(text) + else: + graph = DotGraphMachine(machine).get_graph() + if out == "-": + sys.stdout.buffer.write(graph.create_svg()) # type: ignore[attr-defined] + else: + out_extension = out.rsplit(".", 1)[1] + graph.write(out, format=out_extension) def main(argv=None): @@ -180,6 +232,12 @@ def main(argv=None): nargs="+", help="Instantiate the machine and send these events before rendering.", ) + parser.add_argument( + "--format", + choices=["mermaid", "md", "rst"], + default=None, + help="Output format: mermaid source, markdown table, or RST table.", + ) args = parser.parse_args(argv) - write_image(qualname=args.class_path, out=args.out, events=args.events) + write_image(qualname=args.class_path, out=args.out, events=args.events, fmt=args.format) diff --git a/statemachine/contrib/diagram/renderers/mermaid.py b/statemachine/contrib/diagram/renderers/mermaid.py new file mode 100644 index 00000000..a2e649fb --- /dev/null +++ b/statemachine/contrib/diagram/renderers/mermaid.py @@ -0,0 +1,225 @@ +from dataclasses import dataclass +from typing import List +from typing import Optional +from typing import Set + +from ..model import ActionType +from ..model import DiagramAction +from ..model import DiagramGraph +from ..model import DiagramState +from ..model import DiagramTransition +from ..model import StateType + + +@dataclass +class MermaidRendererConfig: + """Configuration for the Mermaid renderer.""" + + direction: str = "LR" + active_fill: str = "#40E0D0" + active_stroke: str = "#333" + + +class MermaidRenderer: + """Renders a DiagramGraph into a Mermaid stateDiagram-v2 source string.""" + + def __init__(self, config: Optional[MermaidRendererConfig] = None): + self.config = config or MermaidRendererConfig() + self._active_ids: List[str] = [] + self._rendered_transitions: Set[tuple] = set() + + def render(self, graph: DiagramGraph) -> str: + """Render a DiagramGraph to a Mermaid stateDiagram-v2 string.""" + self._active_ids = [] + self._rendered_transitions = set() + + lines: List[str] = [] + lines.append("stateDiagram-v2") + lines.append(f" direction {self.config.direction}") + + top_ids = {s.id for s in graph.states} + self._render_states(graph.states, graph.transitions, lines, indent=1) + self._render_initial_and_final(graph.states, lines, indent=1) + self._render_scope_transitions(graph.transitions, top_ids, lines, indent=1) + + if self._active_ids: + cfg = self.config + lines.append("") + lines.append(f" classDef active fill:{cfg.active_fill},stroke:{cfg.active_stroke}") + for sid in self._active_ids: + lines.append(f" {sid}:::active") + + return "\n".join(lines) + "\n" + + def _render_states( + self, + states: List[DiagramState], + transitions: List[DiagramTransition], + lines: List[str], + indent: int, + ) -> None: + for state in states: + if state.type in (StateType.HISTORY_SHALLOW, StateType.HISTORY_DEEP): + label = "H*" if state.type == StateType.HISTORY_DEEP else "H" + pad = " " * indent + lines.append(f'{pad}state "{label}" as {state.id}') + continue + + if state.type == StateType.CHOICE: + pad = " " * indent + lines.append(f"{pad}state {state.id} <>") + continue + + if state.type == StateType.FORK: + pad = " " * indent + lines.append(f"{pad}state {state.id} <>") + continue + + if state.type == StateType.JOIN: + pad = " " * indent + lines.append(f"{pad}state {state.id} <>") + continue + + if state.children: + self._render_compound_state(state, transitions, lines, indent) + else: + self._render_atomic_state(state, lines, indent) + + def _render_atomic_state( + self, + state: DiagramState, + lines: List[str], + indent: int, + ) -> None: + pad = " " * indent + + if state.name != state.id: + lines.append(f'{pad}state "{state.name}" as {state.id}') + + actions = [a for a in state.actions if a.type != ActionType.INTERNAL or a.body] + if actions: + for action in actions: + lines.append(f"{pad}{state.id} : {self._format_action(action)}") + + if state.is_active: + self._active_ids.append(state.id) + + def _render_compound_state( + self, + state: DiagramState, + transitions: List[DiagramTransition], + lines: List[str], + indent: int, + ) -> None: + pad = " " * indent + + if state.type == StateType.PARALLEL: + lines.append(f'{pad}state "{state.name}" as {state.id} {{') + regions = [c for c in state.children if c.is_parallel_area or c.children] + for i, region in enumerate(regions): + if i > 0: + lines.append(f"{pad} --") + self._render_compound_state(region, transitions, lines, indent + 1) + lines.append(f"{pad}}}") + else: + label = state.name if state.name != state.id else "" + if label: + lines.append(f'{pad}state "{label}" as {state.id} {{') + else: + lines.append(f"{pad}state {state.id} {{") + + initial_child = next((c for c in state.children if c.is_initial), None) + if initial_child: + lines.append(f"{pad} [*] --> {initial_child.id}") + + self._render_states(state.children, transitions, lines, indent + 1) + + # Render transitions scoped to this compound + child_ids = self._collect_all_descendant_ids(state.children) + self._render_scope_transitions(transitions, child_ids, lines, indent + 1) + + # Final state transitions + for child in state.children: + if child.type == StateType.FINAL: + lines.append(f"{pad} {child.id} --> [*]") + + lines.append(f"{pad}}}") + + if state.is_active: + self._active_ids.append(state.id) + + def _collect_all_descendant_ids(self, states: List[DiagramState]) -> Set[str]: + """Collect all state IDs in a subtree (direct children only for scope).""" + ids: Set[str] = set() + for s in states: + ids.add(s.id) + return ids + + def _render_scope_transitions( + self, + transitions: List[DiagramTransition], + scope_ids: Set[str], + lines: List[str], + indent: int, + ) -> None: + """Render transitions where both source and all targets are in scope_ids.""" + for t in transitions: + if t.is_initial or t.is_internal: + continue + + targets = t.targets if t.targets else [t.source] + # Only render if source is in scope + if t.source not in scope_ids: + continue + # Only render if all targets are in scope + if not all(target in scope_ids for target in targets): + continue + + for target in targets: + key = (t.source, target, t.event) + if key in self._rendered_transitions: + continue + self._rendered_transitions.add(key) + self._render_single_transition(t, target, lines, indent) + + def _render_single_transition( + self, + transition: DiagramTransition, + target: str, + lines: List[str], + indent: int, + ) -> None: + pad = " " * indent + label_parts: List[str] = [] + if transition.event: + label_parts.append(transition.event) + if transition.guards: + label_parts.append(f"[{', '.join(transition.guards)}]") + + label = " ".join(label_parts) + if label: + lines.append(f"{pad}{transition.source} --> {target} : {label}") + else: + lines.append(f"{pad}{transition.source} --> {target}") + + @staticmethod + def _format_action(action: DiagramAction) -> str: + if action.type == ActionType.INTERNAL: + return action.body + return f"{action.type.value} / {action.body}" + + def _render_initial_and_final( + self, + states: List[DiagramState], + lines: List[str], + indent: int, + ) -> None: + """Render top-level [*] --> initial and final --> [*] arrows.""" + pad = " " * indent + initial = next((s for s in states if s.is_initial), None) + if initial: + lines.append(f"{pad}[*] --> {initial.id}") + + for state in states: + if state.type == StateType.FINAL: + lines.append(f"{pad}{state.id} --> [*]") diff --git a/statemachine/contrib/diagram/renderers/table.py b/statemachine/contrib/diagram/renderers/table.py new file mode 100644 index 00000000..eeaa18ec --- /dev/null +++ b/statemachine/contrib/diagram/renderers/table.py @@ -0,0 +1,105 @@ +from typing import List + +from ..model import DiagramGraph +from ..model import DiagramState +from ..model import DiagramTransition + + +class TransitionTableRenderer: + """Renders a DiagramGraph as a transition table in markdown or RST format.""" + + def render(self, graph: DiagramGraph, fmt: str = "md") -> str: + """Render the transition table. + + Args: + graph: The diagram IR to render. + fmt: Output format — ``"md"`` for markdown, ``"rst"`` for reStructuredText. + + Returns: + The formatted transition table as a string. + """ + rows = self._collect_rows(graph.states, graph.transitions) + + if fmt == "rst": + return self._render_rst(rows) + return self._render_md(rows) + + def _collect_rows( + self, + states: List[DiagramState], + transitions: List[DiagramTransition], + ) -> "List[tuple[str, str, str, str]]": + """Collect (State, Event, Guard, Target) tuples from the IR.""" + rows: List[tuple[str, str, str, str]] = [] + state_names = self._build_state_name_map(states) + + for t in transitions: + if t.is_initial or t.is_internal: + continue + + source_name = state_names.get(t.source, t.source) + guard = ", ".join(t.guards) if t.guards else "" + event = t.event or "" + + if t.targets: + for target_id in t.targets: + target_name = state_names.get(target_id, target_id) + rows.append((source_name, event, guard, target_name)) + else: + rows.append((source_name, event, guard, source_name)) + + return rows + + def _build_state_name_map(self, states: List[DiagramState]) -> dict: + """Build a mapping from state ID to display name, recursively.""" + result: dict = {} + for state in states: + result[state.id] = state.name + if state.children: + result.update(self._build_state_name_map(state.children)) + return result + + def _render_md(self, rows: "List[tuple[str, str, str, str]]") -> str: + """Render as a markdown table.""" + headers = ("State", "Event", "Guard", "Target") + col_widths = [len(h) for h in headers] + + for row in rows: + for i, cell in enumerate(row): + col_widths[i] = max(col_widths[i], len(cell)) + + def _fmt_row(cells: "tuple[str, ...]") -> str: + parts = [cell.ljust(col_widths[i]) for i, cell in enumerate(cells)] + return "| " + " | ".join(parts) + " |" + + lines = [_fmt_row(headers)] + lines.append("| " + " | ".join("-" * w for w in col_widths) + " |") + for row in rows: + lines.append(_fmt_row(row)) + + return "\n".join(lines) + "\n" + + def _render_rst(self, rows: "List[tuple[str, str, str, str]]") -> str: + """Render as an RST grid table.""" + headers = ("State", "Event", "Guard", "Target") + col_widths = [len(h) for h in headers] + + for row in rows: + for i, cell in enumerate(row): + col_widths[i] = max(col_widths[i], len(cell)) + + def _border(char: str = "-") -> str: + return "+" + "+".join(char * (w + 2) for w in col_widths) + "+" + + def _data_row(cells: "tuple[str, ...]") -> str: + parts = [f" {cell.ljust(col_widths[i])} " for i, cell in enumerate(cells)] + return "|" + "|".join(parts) + "|" + + lines = [_border("-")] + lines.append(_data_row(headers)) + lines.append(_border("=")) + for row in rows: + lines.append(_data_row(row)) + lines.append(_border("-")) + + return "\n".join(lines) + "\n" diff --git a/statemachine/contrib/diagram/sphinx_ext.py b/statemachine/contrib/diagram/sphinx_ext.py index bbc9a8ac..f512b5ac 100644 --- a/statemachine/contrib/diagram/sphinx_ext.py +++ b/statemachine/contrib/diagram/sphinx_ext.py @@ -61,6 +61,7 @@ class StateMachineDiagram(SphinxDirective): option_spec: ClassVar[dict[str, Any]] = { # State-machine options "events": directives.unchanged, + "format": directives.unchanged, # Standard image/figure options "caption": directives.unchanged, "alt": directives.unchanged, @@ -97,6 +98,11 @@ def run(self) -> list[nodes.Node]: else: machine = sm_class + output_format = self.options.get("format", "").strip().lower() + + if output_format == "mermaid": + return self._run_mermaid(machine, qualname) + try: graph = DotGraphMachine(machine).get_graph() svg_bytes: bytes = graph.create_svg() # type: ignore[attr-defined] @@ -143,6 +149,47 @@ def run(self) -> list[nodes.Node]: return [raw_node] + def _run_mermaid(self, machine: object, qualname: str) -> list[nodes.Node]: + """Render a Mermaid diagram using sphinxcontrib-mermaid's node type.""" + try: + from statemachine.contrib.diagram import MermaidGraphMachine + + mermaid_src = MermaidGraphMachine(machine).get_mermaid() + except Exception as exc: + return [ + self.state_machine.reporter.warning( + f"statemachine-diagram: failed to generate mermaid for {qualname!r}: {exc}", + line=self.lineno, + ) + ] + + try: + from sphinxcontrib.mermaid import ( # type: ignore[import-untyped] + mermaid as MermaidNode, + ) + except ImportError: + # Fallback: emit a raw code block if sphinxcontrib-mermaid is not installed + code_node = nodes.literal_block(mermaid_src, mermaid_src) + code_node["language"] = "mermaid" + return [code_node] + + node = MermaidNode() + node["code"] = mermaid_src + node["options"] = {} + + caption = self.options.get("caption") + if caption: + figure_node = nodes.figure() + figure_node += node + figure_node += nodes.caption(caption, caption) + if "name" in self.options: + self.add_name(figure_node) + return [figure_node] + + if "name" in self.options: + self.add_name(node) + return [node] + def _prepare_svg(self, svg_bytes: bytes) -> tuple[str, str, str]: """Extract the ```` element and its intrinsic dimensions.""" match = _SVG_TAG_RE.search(svg_bytes) diff --git a/statemachine/factory.py b/statemachine/factory.py index d470e3bd..3a947753 100644 --- a/statemachine/factory.py +++ b/statemachine/factory.py @@ -92,6 +92,30 @@ def __init__( cls._check() cls._setup() + def __format__(cls, fmt: str) -> str: + if fmt == "mermaid": + from .contrib.diagram import MermaidGraphMachine + + return MermaidGraphMachine(cls).get_mermaid() + elif fmt == "dot": + from .contrib.diagram import DotGraphMachine + + return DotGraphMachine(cls).get_graph().to_string() # type: ignore[no-any-return] + elif fmt in ("md", "markdown"): + from .contrib.diagram.extract import extract + from .contrib.diagram.renderers.table import TransitionTableRenderer + + return TransitionTableRenderer().render(extract(cls), fmt="md") # type: ignore[arg-type] + elif fmt == "rst": + from .contrib.diagram.extract import extract + from .contrib.diagram.renderers.table import TransitionTableRenderer + + return TransitionTableRenderer().render(extract(cls), fmt="rst") # type: ignore[arg-type] + elif fmt == "": + return repr(cls) + else: + raise ValueError(f"Unsupported format: {fmt!r}. Use 'dot', 'mermaid', 'md', or 'rst'.") + def _initials_by_document_order( # noqa: C901 cls, states: List[State], parent: "State | None" = None, order: int = 1 ): diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index c3143a84..0a4c64a7 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -239,6 +239,30 @@ def __repr__(self): f"configuration={configuration_ids!r})" ) + def __format__(self, fmt: str) -> str: + if fmt == "mermaid": + from .contrib.diagram import MermaidGraphMachine + + return MermaidGraphMachine(self).get_mermaid() + elif fmt == "dot": + from .contrib.diagram import DotGraphMachine + + return DotGraphMachine(self).get_graph().to_string() # type: ignore[no-any-return] + elif fmt in ("md", "markdown"): + from .contrib.diagram.extract import extract + from .contrib.diagram.renderers.table import TransitionTableRenderer + + return TransitionTableRenderer().render(extract(self), fmt="md") + elif fmt == "rst": + from .contrib.diagram.extract import extract + from .contrib.diagram.renderers.table import TransitionTableRenderer + + return TransitionTableRenderer().render(extract(self), fmt="rst") + elif fmt == "": + return repr(self) + else: + raise ValueError(f"Unsupported format: {fmt!r}. Use 'dot', 'mermaid', 'md', or 'rst'.") + def __getstate__(self): state = {k: v for k, v in self.__dict__.items() if not isinstance(v, InstanceState)} del state["_callbacks"] diff --git a/tests/test_contrib_diagram.py b/tests/test_contrib_diagram.py index 3d3a8152..23c0cabb 100644 --- a/tests/test_contrib_diagram.py +++ b/tests/test_contrib_diagram.py @@ -161,6 +161,109 @@ def test_generate_complain_about_module_without_sm(self, tmp_path): with pytest.raises(ValueError, match=expected_error): main(["tests.examples", str(out)]) + def test_format_mermaid(self, tmp_path): + out = tmp_path / "sm.mmd" + + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + str(out), + "--format", + "mermaid", + ] + ) + + content = out.read_text() + assert "stateDiagram-v2" in content + assert "green --> yellow : cycle" in content + + def test_format_md(self, tmp_path): + out = tmp_path / "sm.md" + + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + str(out), + "--format", + "md", + ] + ) + + content = out.read_text() + assert "| State" in content + assert "cycle" in content + + def test_format_rst(self, tmp_path): + out = tmp_path / "sm.rst" + + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + str(out), + "--format", + "rst", + ] + ) + + content = out.read_text() + assert "+---" in content + assert "cycle" in content + + def test_format_mermaid_stdout(self, capsys): + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + "-", + "--format", + "mermaid", + ] + ) + + captured = capsys.readouterr() + assert "stateDiagram-v2" in captured.out + + def test_format_md_stdout(self, capsys): + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + "-", + "--format", + "md", + ] + ) + + captured = capsys.readouterr() + assert "| State" in captured.out + + def test_stdout_default_svg(self, capsys): + """Default format to stdout writes SVG bytes.""" + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + "-", + ] + ) + + captured = capsys.readouterr() + assert " yellow : cycle" in result + + def test_format_md_instance(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = f"{sm:md}" + assert "| State" in result + assert "cycle" in result + + def test_format_md_class(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = f"{TrafficLightMachine:md}" + assert "| State" in result + + def test_format_markdown_alias(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = format(TrafficLightMachine, "markdown") + assert "| State" in result + + def test_format_rst_instance(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = f"{sm:rst}" + assert "+---" in result + + def test_format_rst_class(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = f"{TrafficLightMachine:rst}" + assert "+---" in result + + def test_format_dot_instance(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = f"{sm:dot}" + assert result.startswith("digraph TrafficLightMachine {") + assert "green" in result + + def test_format_dot_class(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = f"{TrafficLightMachine:dot}" + assert result.startswith("digraph TrafficLightMachine {") + + def test_format_empty_falls_back_to_repr(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = f"{sm:}" + assert "TrafficLightMachine(" in result + + def test_format_empty_class(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = f"{TrafficLightMachine:}" + assert "TrafficLightMachine" in result + + def test_format_invalid_raises(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + with pytest.raises(ValueError, match="Unsupported format"): + f"{sm:invalid}" + + def test_format_invalid_class_raises(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + with pytest.raises(ValueError, match="Unsupported format"): + f"{TrafficLightMachine:invalid}" + + +class TestDirectiveMermaidFormat: + """Tests for the :format: mermaid Sphinx directive option.""" + + _QUALNAME = "tests.examples.traffic_light_machine.TrafficLightMachine" + + def _make_directive(self, tmp_path, options=None): + from statemachine.contrib.diagram.sphinx_ext import StateMachineDiagram + + directive = StateMachineDiagram.__new__(StateMachineDiagram) + directive.options = options or {} + directive.lineno = 1 + directive.state_machine = mock.MagicMock() + directive.state = mock.MagicMock() + directive.state.document.settings.env.app.outdir = str(tmp_path) + directive.content_offset = 0 + return directive + + def _run(self, tmp_path, qualname=None, options=None): + directive = self._make_directive(tmp_path, options=options) + directive.arguments = [qualname or self._QUALNAME] + return directive, directive.run() + + def test_mermaid_format_with_sphinxcontrib(self, tmp_path): + """When sphinxcontrib-mermaid is available, emits a mermaid node.""" + from sphinxcontrib.mermaid import mermaid as MermaidNode # type: ignore[import-untyped] + + _, result = self._run(tmp_path, options={"format": "mermaid"}) + assert len(result) == 1 + node = result[0] + assert isinstance(node, MermaidNode) + assert "stateDiagram-v2" in node["code"] + + def test_mermaid_format_with_caption(self, tmp_path): + """Mermaid format with caption wraps in figure node.""" + from sphinxcontrib.mermaid import mermaid as MermaidNode # type: ignore[import-untyped] + + _, result = self._run(tmp_path, options={"format": "mermaid", "caption": "My Diagram"}) + assert len(result) == 1 + node = result[0] + assert isinstance(node, nodes.figure) + # Figure should contain a mermaid node and a caption + mermaid_children = [c for c in node.children if isinstance(c, MermaidNode)] + assert len(mermaid_children) == 1 + caption_children = [c for c in node.children if isinstance(c, nodes.caption)] + assert len(caption_children) == 1 + assert caption_children[0].astext() == "My Diagram" + + def test_mermaid_format_fallback_no_sphinxcontrib(self, tmp_path): + """When sphinxcontrib-mermaid is not available, falls back to code block.""" + import sys + + saved = sys.modules.get("sphinxcontrib.mermaid") + sys.modules["sphinxcontrib.mermaid"] = None # type: ignore[assignment] + try: + _, result = self._run(tmp_path, options={"format": "mermaid"}) + finally: + if saved is not None: + sys.modules["sphinxcontrib.mermaid"] = saved + else: + sys.modules.pop("sphinxcontrib.mermaid", None) + + assert len(result) == 1 + node = result[0] + assert isinstance(node, nodes.literal_block) + assert "stateDiagram-v2" in node.astext() + + def test_mermaid_render_failure_returns_warning(self, tmp_path): + """Mermaid generation failure returns a warning node.""" + with mock.patch( + "statemachine.contrib.diagram.MermaidGraphMachine", + side_effect=RuntimeError("render failed"), + ): + directive, result = self._run(tmp_path, options={"format": "mermaid"}) + + assert len(result) == 1 + directive.state_machine.reporter.warning.assert_called_once() + call_args = directive.state_machine.reporter.warning.call_args + assert "failed to generate mermaid" in call_args[0][0] diff --git a/tests/test_mermaid_renderer.py b/tests/test_mermaid_renderer.py new file mode 100644 index 00000000..19562dab --- /dev/null +++ b/tests/test_mermaid_renderer.py @@ -0,0 +1,620 @@ +from statemachine.contrib.diagram import MermaidGraphMachine +from statemachine.contrib.diagram.model import ActionType +from statemachine.contrib.diagram.model import DiagramAction +from statemachine.contrib.diagram.model import DiagramGraph +from statemachine.contrib.diagram.model import DiagramState +from statemachine.contrib.diagram.model import DiagramTransition +from statemachine.contrib.diagram.model import StateType +from statemachine.contrib.diagram.renderers.mermaid import MermaidRenderer +from statemachine.contrib.diagram.renderers.mermaid import MermaidRendererConfig + +from statemachine import State +from statemachine import StateChart + + +class TestMermaidRendererSimple: + """Basic MermaidRenderer tests with simple states.""" + + def test_simple_states(self): + graph = DiagramGraph( + name="Simple", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = MermaidRenderer().render(graph) + assert "stateDiagram-v2" in result + assert "direction LR" in result + assert "[*] --> s1" in result + assert "s1 --> s2 : go" in result + + def test_initial_and_final(self): + graph = DiagramGraph( + name="InitFinal", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.FINAL), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="finish"), + ], + ) + result = MermaidRenderer().render(graph) + assert "[*] --> s1" in result + assert "s2 --> [*]" in result + + def test_custom_direction(self): + config = MermaidRendererConfig(direction="TB") + graph = DiagramGraph( + name="TB", + states=[DiagramState(id="a", name="A", type=StateType.REGULAR, is_initial=True)], + ) + result = MermaidRenderer(config=config).render(graph) + assert "direction TB" in result + + def test_state_name_differs_from_id(self): + graph = DiagramGraph( + name="Named", + states=[ + DiagramState( + id="my_state", name="My State", type=StateType.REGULAR, is_initial=True + ), + ], + ) + result = MermaidRenderer().render(graph) + assert 'state "My State" as my_state' in result + + def test_state_name_equals_id_no_declaration(self): + """When name == id, no explicit state declaration is emitted.""" + graph = DiagramGraph( + name="NoDecl", + states=[ + DiagramState(id="s1", name="s1", type=StateType.REGULAR, is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + assert 'state "s1"' not in result + + +class TestMermaidRendererTransitions: + """Transition rendering tests.""" + + def test_transition_with_guards(self): + graph = DiagramGraph( + name="Guards", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go", guards=["is_ready"]), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s2 : go [is_ready]" in result + + def test_eventless_transition(self): + graph = DiagramGraph( + name="Eventless", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event=""), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s2\n" in result + + def test_self_transition(self): + graph = DiagramGraph( + name="SelfLoop", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s1"], event="tick"), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s1 : tick" in result + + def test_targetless_transition(self): + graph = DiagramGraph( + name="Targetless", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="s1", targets=[], event="tick"), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s1 : tick" in result + + def test_multi_target_transition(self): + graph = DiagramGraph( + name="Multi", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + DiagramState(id="s3", name="S3", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2", "s3"], event="split"), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s2 : split" in result + assert "s1 --> s3 : split" in result + + def test_internal_transitions_skipped(self): + graph = DiagramGraph( + name="Internal", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s1"], event="check", is_internal=True), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s1" not in result + + def test_initial_transitions_skipped(self): + graph = DiagramGraph( + name="InitTrans", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="", is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + # Implicit initial transitions are NOT rendered as edges + assert "s1 --> s2" not in result + + +class TestMermaidRendererActiveState: + """Active state highlighting tests.""" + + def test_active_state_class(self): + graph = DiagramGraph( + name="Active", + states=[ + DiagramState( + id="s1", name="S1", type=StateType.REGULAR, is_initial=True, is_active=True + ), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = MermaidRenderer().render(graph) + assert "classDef active" in result + assert "s1:::active" in result + assert "s2:::active" not in result + + def test_no_active_state_no_classdef(self): + graph = DiagramGraph( + name="NoActive", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + assert "classDef" not in result + + def test_active_fill_config(self): + config = MermaidRendererConfig(active_fill="#FF0000", active_stroke="#000") + graph = DiagramGraph( + name="CustomActive", + states=[ + DiagramState( + id="s1", name="S1", type=StateType.REGULAR, is_initial=True, is_active=True + ), + ], + ) + result = MermaidRenderer(config=config).render(graph) + assert "fill:#FF0000" in result + assert "stroke:#000" in result + + +class TestMermaidRendererCompound: + """Compound and parallel state tests.""" + + def test_compound_state(self): + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child1 = State(initial=True) + child2 = State(final=True) + go = child1.to(child2) + + start = State(initial=True) + end = State(final=True) + + enter = start.to(parent) + finish = parent.to(end) + + result = MermaidGraphMachine(SM).get_mermaid() + assert 'state "Parent" as parent {' in result + assert "[*] --> child1" in result + assert "child1 --> child2 : go" in result + assert "child2 --> [*]" in result + assert "start --> parent : enter" in result + assert "parent --> end : finish" in result + + def test_compound_no_duplicate_transitions(self): + """Transitions inside compound states must not also appear at top level.""" + + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child1 = State(initial=True) + child2 = State(final=True) + go = child1.to(child2) + + start = State(initial=True) + enter = start.to(parent) + + result = MermaidGraphMachine(SM).get_mermaid() + # "child1 --> child2 : go" should appear exactly once (inside compound) + assert result.count("child1 --> child2 : go") == 1 + + def test_parallel_state(self): + class SM(StateChart): + class p(State.Parallel, name="Parallel"): + class r1(State.Compound, name="Region1"): + a = State(initial=True) + a_done = State(final=True) + finish_a = a.to(a_done) + + class r2(State.Compound, name="Region2"): + b = State(initial=True) + b_done = State(final=True) + finish_b = b.to(b_done) + + start = State(initial=True) + begin = start.to(p) + + result = MermaidGraphMachine(SM).get_mermaid() + assert 'state "Parallel" as p {' in result + assert "--" in result # parallel separator + + def test_nested_compound(self): + class SM(StateChart): + class outer(State.Compound, name="Outer"): + class inner(State.Compound, name="Inner"): + deep = State(initial=True) + deep_final = State(final=True) + go_deep = deep.to(deep_final) + + start_inner = State(initial=True) + to_inner = start_inner.to(inner) + + begin = State(initial=True) + enter = begin.to(outer) + + result = MermaidGraphMachine(SM).get_mermaid() + assert 'state "Outer" as outer {' in result + assert 'state "Inner" as inner {' in result + + +class TestMermaidRendererPseudoStates: + """Pseudo-state rendering tests.""" + + def test_history_shallow(self): + graph = DiagramGraph( + name="History", + states=[ + DiagramState( + id="comp", + name="Comp", + type=StateType.REGULAR, + is_initial=True, + children=[ + DiagramState(id="h", name="H", type=StateType.HISTORY_SHALLOW), + DiagramState(id="c1", name="C1", type=StateType.REGULAR, is_initial=True), + ], + ), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + assert 'state "H" as h' in result + + def test_history_deep(self): + graph = DiagramGraph( + name="DeepHistory", + states=[ + DiagramState( + id="comp", + name="Comp", + type=StateType.REGULAR, + is_initial=True, + children=[ + DiagramState(id="h", name="H*", type=StateType.HISTORY_DEEP), + DiagramState(id="c1", name="C1", type=StateType.REGULAR, is_initial=True), + ], + ), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + assert 'state "H*" as h' in result + + def test_choice_state(self): + graph = DiagramGraph( + name="Choice", + states=[ + DiagramState(id="ch", name="ch", type=StateType.CHOICE, is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + assert "state ch <>" in result + + def test_fork_state(self): + graph = DiagramGraph( + name="Fork", + states=[ + DiagramState(id="fk", name="fk", type=StateType.FORK, is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + assert "state fk <>" in result + + def test_join_state(self): + graph = DiagramGraph( + name="Join", + states=[ + DiagramState(id="jn", name="jn", type=StateType.JOIN, is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + assert "state jn <>" in result + + +class TestMermaidRendererActions: + """State action rendering tests.""" + + def test_entry_exit_actions(self): + graph = DiagramGraph( + name="Actions", + states=[ + DiagramState( + id="s1", + name="S1", + type=StateType.REGULAR, + is_initial=True, + actions=[ + DiagramAction(type=ActionType.ENTRY, body="setup"), + DiagramAction(type=ActionType.EXIT, body="cleanup"), + ], + ), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 : entry / setup" in result + assert "s1 : exit / cleanup" in result + + def test_internal_action(self): + graph = DiagramGraph( + name="InternalAction", + states=[ + DiagramState( + id="s1", + name="S1", + type=StateType.REGULAR, + is_initial=True, + actions=[ + DiagramAction(type=ActionType.INTERNAL, body="tick / handle"), + ], + ), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 : tick / handle" in result + + def test_empty_internal_action_skipped(self): + graph = DiagramGraph( + name="EmptyInternal", + states=[ + DiagramState( + id="s1", + name="S1", + type=StateType.REGULAR, + is_initial=True, + actions=[ + DiagramAction(type=ActionType.INTERNAL, body=""), + ], + ), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 : " not in result + + +class TestMermaidGraphMachine: + """Tests for the MermaidGraphMachine facade.""" + + def test_facade_returns_string(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = MermaidGraphMachine(TrafficLightMachine).get_mermaid() + assert isinstance(result, str) + assert "stateDiagram-v2" in result + + def test_facade_callable(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + facade = MermaidGraphMachine(TrafficLightMachine) + assert facade() == facade.get_mermaid() + + def test_facade_with_instance(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = MermaidGraphMachine(sm).get_mermaid() + assert "green:::active" in result + + def test_facade_custom_config(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + class Custom(MermaidGraphMachine): + direction = "TB" + active_fill = "#FF0000" + + sm = TrafficLightMachine() + result = Custom(sm).get_mermaid() + assert "direction TB" in result + assert "fill:#FF0000" in result + + +class TestMermaidRendererEdgeCases: + """Edge case tests for coverage.""" + + def test_compound_state_name_equals_id(self): + """Compound state where name == id uses unquoted declaration.""" + graph = DiagramGraph( + name="NameId", + states=[ + DiagramState( + id="comp", + name="comp", + type=StateType.REGULAR, + is_initial=True, + children=[ + DiagramState(id="c1", name="C1", type=StateType.REGULAR, is_initial=True), + ], + ), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + assert "state comp {" in result + assert '"comp"' not in result + + def test_active_compound_state(self): + """Compound state that is active gets classDef.""" + graph = DiagramGraph( + name="ActiveComp", + states=[ + DiagramState( + id="comp", + name="Comp", + type=StateType.REGULAR, + is_initial=True, + is_active=True, + children=[ + DiagramState(id="c1", name="C1", type=StateType.REGULAR, is_initial=True), + ], + ), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + assert "comp:::active" in result + + def test_cross_scope_transition_not_in_compound(self): + """Transition crossing compound boundaries is not rendered inside the compound.""" + graph = DiagramGraph( + name="CrossScope", + states=[ + DiagramState( + id="comp", + name="Comp", + type=StateType.REGULAR, + is_initial=True, + children=[ + DiagramState(id="c1", name="C1", type=StateType.REGULAR, is_initial=True), + ], + ), + DiagramState(id="outside", name="Outside", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="c1", targets=["outside"], event="leave"), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + # c1 is inside comp, outside is at top level — the transition + # can't be rendered at either scope since source/target span scopes. + # This is expected: Mermaid doesn't support cross-scope transitions natively. + assert "c1 --> outside" not in result + + def test_no_initial_state(self): + """Graph with no initial state omits [*] arrow.""" + graph = DiagramGraph( + name="NoInitial", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR), + ], + ) + result = MermaidRenderer().render(graph) + assert "[*]" not in result + + def test_duplicate_transition_rendered_once(self): + """Duplicate transitions in the IR are rendered only once.""" + graph = DiagramGraph( + name="Dedup", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = MermaidRenderer().render(graph) + assert result.count("s1 --> s2 : go") == 1 + + def test_compound_no_initial_child(self): + """Compound state with no initial child omits internal [*] arrow.""" + graph = DiagramGraph( + name="NoInitChild", + states=[ + DiagramState( + id="comp", + name="Comp", + type=StateType.REGULAR, + is_initial=True, + children=[ + DiagramState(id="c1", name="C1", type=StateType.REGULAR), + ], + ), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + # No [*] --> c1 inside the compound + lines = result.strip().split("\n") + inner_initial = [ln for ln in lines if "[*] --> c1" in ln] + assert len(inner_initial) == 0 + + +class TestMermaidRendererIntegration: + """Integration tests with real state machines.""" + + def test_traffic_light(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = MermaidGraphMachine(TrafficLightMachine).get_mermaid() + assert "green --> yellow : cycle" in result + assert "yellow --> red : cycle" in result + assert "red --> green : cycle" in result + + def test_traffic_light_with_events(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + sm.send("cycle") + result = MermaidGraphMachine(sm).get_mermaid() + assert "yellow:::active" in result diff --git a/tests/test_transition_table.py b/tests/test_transition_table.py new file mode 100644 index 00000000..198cf495 --- /dev/null +++ b/tests/test_transition_table.py @@ -0,0 +1,201 @@ +from statemachine.contrib.diagram.extract import extract +from statemachine.contrib.diagram.model import DiagramGraph +from statemachine.contrib.diagram.model import DiagramState +from statemachine.contrib.diagram.model import DiagramTransition +from statemachine.contrib.diagram.model import StateType +from statemachine.contrib.diagram.renderers.table import TransitionTableRenderer + +from statemachine import State +from statemachine import StateChart + + +class TestTransitionTableMarkdown: + """Markdown transition table tests.""" + + def test_simple_table(self): + graph = DiagramGraph( + name="Simple", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + assert "| State" in result + assert "| Event" in result + assert "| Guard" in result + assert "| Target" in result + assert "| S1" in result + assert "go" in result + assert "| S2" in result + + def test_with_guards(self): + graph = DiagramGraph( + name="Guards", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go", guards=["is_ready"]), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + assert "is_ready" in result + + def test_multiple_targets(self): + graph = DiagramGraph( + name="Multi", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + DiagramState(id="s3", name="S3", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2", "s3"], event="split"), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + lines = result.strip().split("\n") + # Header + separator + 2 data rows + assert len(lines) == 4 + + def test_skips_initial_transitions(self): + graph = DiagramGraph( + name="SkipInit", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="", is_initial=True), + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + lines = result.strip().split("\n") + # Header + separator + 1 data row (initial skipped) + assert len(lines) == 3 + + def test_skips_internal_transitions(self): + graph = DiagramGraph( + name="SkipInternal", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s1"], event="check", is_internal=True), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + lines = result.strip().split("\n") + # Header + separator only (no data rows) + assert len(lines) == 2 + + def test_targetless_transition(self): + graph = DiagramGraph( + name="Targetless", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="s1", targets=[], event="tick"), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + assert "tick" in result + # Target falls back to source name + assert "S1" in result + + +class TestTransitionTableRST: + """RST grid table tests.""" + + def test_rst_format(self): + graph = DiagramGraph( + name="RST", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="rst") + assert "+---" in result + assert "|" in result + assert "====" in result # header separator + assert "go" in result + + def test_rst_with_guards(self): + graph = DiagramGraph( + name="RSTGuards", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go", guards=["is_ready"]), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="rst") + assert "is_ready" in result + + +class TestTransitionTableIntegration: + """Integration tests with real state machines.""" + + def test_traffic_light_md(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + ir = extract(TrafficLightMachine) + result = TransitionTableRenderer().render(ir, fmt="md") + assert "Green" in result + assert "Yellow" in result + assert "Red" in result + assert "cycle" in result + + def test_traffic_light_rst(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + ir = extract(TrafficLightMachine) + result = TransitionTableRenderer().render(ir, fmt="rst") + assert "Green" in result + assert "cycle" in result + assert "+---" in result + + def test_compound_state_names(self): + """Child state names are properly resolved.""" + + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child1 = State(initial=True) + child2 = State(final=True) + go = child1.to(child2) + + start = State(initial=True) + enter = start.to(parent) + + ir = extract(SM) + result = TransitionTableRenderer().render(ir, fmt="md") + assert "Child1" in result + assert "Child2" in result + + def test_default_format_is_md(self): + """render() without fmt defaults to markdown.""" + graph = DiagramGraph( + name="Default", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = TransitionTableRenderer().render(graph) + assert "| State" in result # markdown uses pipes diff --git a/uv.lock b/uv.lock index 4da05165..8ccdb666 100644 --- a/uv.lock +++ b/uv.lock @@ -1114,6 +1114,8 @@ dev = [ { name = "sphinx-autobuild" }, { name = "sphinx-copybutton" }, { name = "sphinx-gallery" }, + { name = "sphinxcontrib-mermaid", version = "1.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "sphinxcontrib-mermaid", version = "2.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, ] [package.metadata] @@ -1147,6 +1149,7 @@ dev = [ { name = "sphinx-autobuild", marker = "python_full_version >= '3.9'" }, { name = "sphinx-copybutton", marker = "python_full_version >= '3.9'", specifier = ">=0.5.2" }, { name = "sphinx-gallery", marker = "python_full_version >= '3.9'" }, + { name = "sphinxcontrib-mermaid", marker = "python_full_version >= '3.9'" }, ] [[package]] @@ -1389,6 +1392,39 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", size = 5071, upload-time = "2019-01-21T16:10:14.333Z" }, ] +[[package]] +name = "sphinxcontrib-mermaid" +version = "1.2.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] +dependencies = [ + { name = "pyyaml", marker = "python_full_version < '3.10'" }, + { name = "sphinx", marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/49/c6ddfe709a4ab76ac6e5a00e696f73626b2c189dc1e1965a361ec102e6cc/sphinxcontrib_mermaid-1.2.3.tar.gz", hash = "sha256:358699d0ec924ef679b41873d9edd97d0773446daf9760c75e18dc0adfd91371", size = 18885, upload-time = "2025-11-26T04:18:32.43Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/39/8b54299ffa00e597d3b0b4d042241a0a0b22cb429ad007ccfb9c1745b4d1/sphinxcontrib_mermaid-1.2.3-py3-none-any.whl", hash = "sha256:5be782b27026bef97bfb15ccb2f7868b674a1afc0982b54cb149702cfc25aa02", size = 13413, upload-time = "2025-11-26T04:18:31.269Z" }, +] + +[[package]] +name = "sphinxcontrib-mermaid" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.10'", +] +dependencies = [ + { name = "jinja2", marker = "python_full_version >= '3.10'" }, + { name = "pyyaml", marker = "python_full_version >= '3.10'" }, + { name = "sphinx", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2b/ae/999891de292919b66ea34f2c22fc22c9be90ab3536fbc0fca95716277351/sphinxcontrib_mermaid-2.0.1.tar.gz", hash = "sha256:a21a385a059a6cafd192aa3a586b14bf5c42721e229db67b459dc825d7f0a497", size = 19839, upload-time = "2026-03-05T14:10:41.901Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/46/25d64bcd7821c8d6f1080e1c43d5fcdfc442a18f759a230b5ccdc891093e/sphinxcontrib_mermaid-2.0.1-py3-none-any.whl", hash = "sha256:9dca7fbe827bad5e7e2b97c4047682cfd26e3e07398cfdc96c7a8842ae7f06e7", size = 14064, upload-time = "2026-03-05T14:10:40.533Z" }, +] + [[package]] name = "sphinxcontrib-qthelp" version = "2.0.0" From 7089a17e22049b9373423a08d3a0ff69fbf664f6 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Sun, 8 Mar 2026 12:47:13 -0300 Subject: [PATCH 02/10] refactor: introduce Formatter facade with decorator-based format registry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace duplicated if/elif chains in StateChart.__format__ and StateMachineMetaclass.__format__ with a Formatter class that uses a decorator-based registry following the Open/Closed Principle. Adding a new format now requires only a decorated function — no changes to __format__, factory.py, or statemachine.py. --- statemachine/contrib/diagram/__init__.py | 20 ++-- statemachine/contrib/diagram/formatter.py | 129 ++++++++++++++++++++++ statemachine/factory.py | 25 +---- statemachine/statemachine.py | 25 +---- tests/test_contrib_diagram.py | 115 +++++++++++++++++++ 5 files changed, 257 insertions(+), 57 deletions(-) create mode 100644 statemachine/contrib/diagram/formatter.py diff --git a/statemachine/contrib/diagram/__init__.py b/statemachine/contrib/diagram/__init__.py index f7ce12d6..c6317304 100644 --- a/statemachine/contrib/diagram/__init__.py +++ b/statemachine/contrib/diagram/__init__.py @@ -3,6 +3,7 @@ from urllib.request import urlopen from .extract import extract +from .formatter import formatter as formatter from .renderers.dot import DotRenderer from .renderers.dot import DotRendererConfig from .renderers.mermaid import MermaidRenderer @@ -175,8 +176,8 @@ def write_image(qualname, out, events=None, fmt=None): If `events` is provided, the machine is instantiated and each event is sent before rendering, so the diagram highlights the current active state. - If `fmt` is provided, it overrides the output format: ``"mermaid"`` writes - Mermaid source text, ``"md"``/``"rst"`` write a transition table. + If `fmt` is provided, it overrides the output format (any registered text + format such as ``"mermaid"``, ``"dot"``, ``"md"``, ``"rst"``). Use ``out="-"`` to write to stdout. """ import sys @@ -190,15 +191,8 @@ def write_image(qualname, out, events=None, fmt=None): else: machine = smclass - if fmt in ("mermaid", "md", "rst"): - if fmt == "mermaid": - text = MermaidGraphMachine(machine).get_mermaid() - else: - from .renderers.table import TransitionTableRenderer - - ir = extract(machine) - text = TransitionTableRenderer().render(ir, fmt=fmt) - + if fmt is not None: + text = formatter.render(machine, fmt) if out == "-": sys.stdout.write(text) else: @@ -234,9 +228,9 @@ def main(argv=None): ) parser.add_argument( "--format", - choices=["mermaid", "md", "rst"], + choices=formatter.supported_formats(), default=None, - help="Output format: mermaid source, markdown table, or RST table.", + help="Output as text format instead of Graphviz image.", ) args = parser.parse_args(argv) diff --git a/statemachine/contrib/diagram/formatter.py b/statemachine/contrib/diagram/formatter.py new file mode 100644 index 00000000..3daeaf58 --- /dev/null +++ b/statemachine/contrib/diagram/formatter.py @@ -0,0 +1,129 @@ +"""Unified facade for rendering state machines in multiple text formats. + +The :class:`Formatter` class provides a decorator-based registry where each +renderer declares the format names it handles. Adding a new format only +requires writing a renderer function and decorating it — no changes to +``__format__``, ``factory.py``, or ``statemachine.py``. + +A module-level :data:`formatter` instance is the single public entry point:: + + from statemachine.contrib.diagram import formatter + + print(formatter.render(sm, "mermaid")) + + @formatter.register_format("plantuml") + def _render_plantuml(machine): + ... +""" + +from typing import TYPE_CHECKING +from typing import Callable +from typing import Dict +from typing import List + +if TYPE_CHECKING: + from typing import Union + + from statemachine.statemachine import StateChart + + MachineRef = Union["StateChart", "type[StateChart]"] + + +class Formatter: + """Unified facade for rendering state machines in multiple text formats.""" + + def __init__(self) -> None: + self._formats: Dict[str, "Callable[[MachineRef], str]"] = {} + + def register_format( + self, *names: str + ) -> "Callable[[Callable[[MachineRef], str]], Callable[[MachineRef], str]]": + """Decorator factory that registers a renderer under one or more format names. + + Usage:: + + @formatter.register_format("md", "markdown") + def _render_md(machine_or_class): + ... + """ + + def decorator( + fn: "Callable[[MachineRef], str]", + ) -> "Callable[[MachineRef], str]": + for name in names: + self._formats[name] = fn + return fn + + return decorator + + def render(self, machine_or_class: "MachineRef", fmt: str) -> str: + """Render a state machine in the given text format. + + Args: + machine_or_class: A ``StateChart`` instance or class. + fmt: Format name (e.g., ``"mermaid"``, ``"dot"``, ``"md"``). + Empty string falls back to ``repr()``. + + Raises: + ValueError: If ``fmt`` is not registered. + """ + if fmt == "": + return repr(machine_or_class) + + renderer_fn = self._formats.get(fmt) + if renderer_fn is None: + primary = sorted({self._primary_name(fn) for fn in set(self._formats.values())}) + raise ValueError( + f"Unsupported format: {fmt!r}. Use {', '.join(repr(n) for n in primary)}." + ) + return renderer_fn(machine_or_class) + + def supported_formats(self) -> List[str]: + """Return sorted list of all registered format names (including aliases).""" + return sorted(self._formats) + + def _primary_name(self, fn: "Callable[[MachineRef], str]") -> str: + """Return the first registered name for a given renderer function.""" + for name, registered_fn in self._formats.items(): + if registered_fn is fn: + return name + return "?" # pragma: no cover + + +formatter = Formatter() +"""Module-level :class:`Formatter` instance — the single public entry point.""" + + +# --------------------------------------------------------------------------- +# Built-in format registrations +# --------------------------------------------------------------------------- + + +@formatter.register_format("dot") +def _render_dot(machine_or_class: "MachineRef") -> str: + from statemachine.contrib.diagram import DotGraphMachine + + return DotGraphMachine(machine_or_class).get_graph().to_string() # type: ignore[no-any-return] + + +@formatter.register_format("mermaid") +def _render_mermaid(machine_or_class: "MachineRef") -> str: + from statemachine.contrib.diagram import MermaidGraphMachine + + return MermaidGraphMachine(machine_or_class).get_mermaid() + + +@formatter.register_format("md", "markdown") +def _render_md(machine_or_class: "MachineRef") -> str: + from statemachine.contrib.diagram.extract import extract + from statemachine.contrib.diagram.renderers.table import TransitionTableRenderer + + return TransitionTableRenderer().render(extract(machine_or_class), fmt="md") + + +@formatter.register_format("rst") +def _render_rst(machine_or_class: "MachineRef") -> str: + from statemachine.contrib.diagram.extract import extract + from statemachine.contrib.diagram.renderers.table import TransitionTableRenderer + + return TransitionTableRenderer().render(extract(machine_or_class), fmt="rst") diff --git a/statemachine/factory.py b/statemachine/factory.py index 3a947753..55da2db7 100644 --- a/statemachine/factory.py +++ b/statemachine/factory.py @@ -93,28 +93,9 @@ def __init__( cls._setup() def __format__(cls, fmt: str) -> str: - if fmt == "mermaid": - from .contrib.diagram import MermaidGraphMachine - - return MermaidGraphMachine(cls).get_mermaid() - elif fmt == "dot": - from .contrib.diagram import DotGraphMachine - - return DotGraphMachine(cls).get_graph().to_string() # type: ignore[no-any-return] - elif fmt in ("md", "markdown"): - from .contrib.diagram.extract import extract - from .contrib.diagram.renderers.table import TransitionTableRenderer - - return TransitionTableRenderer().render(extract(cls), fmt="md") # type: ignore[arg-type] - elif fmt == "rst": - from .contrib.diagram.extract import extract - from .contrib.diagram.renderers.table import TransitionTableRenderer - - return TransitionTableRenderer().render(extract(cls), fmt="rst") # type: ignore[arg-type] - elif fmt == "": - return repr(cls) - else: - raise ValueError(f"Unsupported format: {fmt!r}. Use 'dot', 'mermaid', 'md', or 'rst'.") + from .contrib.diagram.formatter import formatter + + return formatter.render(cls, fmt) # type: ignore[arg-type] def _initials_by_document_order( # noqa: C901 cls, states: List[State], parent: "State | None" = None, order: int = 1 diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index 0a4c64a7..d33ea122 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -240,28 +240,9 @@ def __repr__(self): ) def __format__(self, fmt: str) -> str: - if fmt == "mermaid": - from .contrib.diagram import MermaidGraphMachine - - return MermaidGraphMachine(self).get_mermaid() - elif fmt == "dot": - from .contrib.diagram import DotGraphMachine - - return DotGraphMachine(self).get_graph().to_string() # type: ignore[no-any-return] - elif fmt in ("md", "markdown"): - from .contrib.diagram.extract import extract - from .contrib.diagram.renderers.table import TransitionTableRenderer - - return TransitionTableRenderer().render(extract(self), fmt="md") - elif fmt == "rst": - from .contrib.diagram.extract import extract - from .contrib.diagram.renderers.table import TransitionTableRenderer - - return TransitionTableRenderer().render(extract(self), fmt="rst") - elif fmt == "": - return repr(self) - else: - raise ValueError(f"Unsupported format: {fmt!r}. Use 'dot', 'mermaid', 'md', or 'rst'.") + from .contrib.diagram.formatter import formatter + + return formatter.render(self, fmt) def __getstate__(self): state = {k: v for k, v in self.__dict__.items() if not isinstance(v, InstanceState)} diff --git a/tests/test_contrib_diagram.py b/tests/test_contrib_diagram.py index 23c0cabb..d8e62b7e 100644 --- a/tests/test_contrib_diagram.py +++ b/tests/test_contrib_diagram.py @@ -1328,6 +1328,121 @@ def test_format_invalid_class_raises(self): f"{TrafficLightMachine:invalid}" +class TestFormatter: + """Tests for the Formatter facade (render, register_format, supported_formats).""" + + def test_render_mermaid(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = formatter.render(TrafficLightMachine, "mermaid") + assert "stateDiagram-v2" in result + + def test_render_dot(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = formatter.render(TrafficLightMachine, "dot") + assert result.startswith("digraph TrafficLightMachine {") + + def test_render_md(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = formatter.render(TrafficLightMachine, "md") + assert "| State" in result + + def test_render_markdown_alias(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + assert formatter.render(TrafficLightMachine, "markdown") == formatter.render( + TrafficLightMachine, "md" + ) + + def test_render_rst(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = formatter.render(TrafficLightMachine, "rst") + assert "+---" in result + + def test_render_empty_repr_instance(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + assert formatter.render(sm, "") == repr(sm) + + def test_render_empty_repr_class(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + assert formatter.render(TrafficLightMachine, "") == repr(TrafficLightMachine) + + def test_render_invalid_raises(self): + from statemachine.contrib.diagram import formatter + + with pytest.raises(ValueError, match="Unsupported format"): + formatter.render(object(), "invalid") + + def test_supported_formats(self): + from statemachine.contrib.diagram import formatter + + fmts = formatter.supported_formats() + assert "dot" in fmts + assert "mermaid" in fmts + assert "md" in fmts + assert "markdown" in fmts + assert "rst" in fmts + + def test_register_custom_format(self): + from statemachine.contrib.diagram import formatter + + @formatter.register_format("_test_custom") + def _render_custom(machine_or_class): + return "custom output" + + try: + assert formatter.render(object(), "_test_custom") == "custom output" + finally: + formatter._formats.pop("_test_custom", None) + + def test_register_format_with_aliases(self): + from statemachine.contrib.diagram import formatter + + @formatter.register_format("_test_alias", "_test_alias2") + def _render_alias_test(machine_or_class): + return "alias output" + + try: + assert formatter.render(object(), "_test_alias") == "alias output" + assert formatter.render(object(), "_test_alias2") == "alias output" + finally: + formatter._formats.pop("_test_alias", None) + formatter._formats.pop("_test_alias2", None) + + def test_error_message_lists_primary_formats(self): + from statemachine.contrib.diagram import formatter + + with pytest.raises(ValueError, match="'dot'") as exc_info: + formatter.render(object(), "nonexistent") + msg = str(exc_info.value) + # Should list primary names, not aliases + assert "'mermaid'" in msg + assert "'md'" in msg + assert "'rst'" in msg + # "markdown" is an alias, should not appear in error message + assert "'markdown'" not in msg + + class TestDirectiveMermaidFormat: """Tests for the :format: mermaid Sphinx directive option.""" From c4f8e98b7140daf74a070684d7b1e388ecc63544 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Sun, 8 Mar 2026 13:01:29 -0300 Subject: [PATCH 03/10] feat: add SVG text format and use formatter in Sphinx directive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Register "svg" format in Formatter (DOT → SVG decoded as str) - Refactor Sphinx directive to use formatter.render() for both SVG and Mermaid instead of calling DotGraphMachine/MermaidGraphMachine directly - Update _prepare_svg and _resolve_target to work with str (not bytes) --- statemachine/contrib/diagram/formatter.py | 8 +++ statemachine/contrib/diagram/sphinx_ext.py | 31 +++++----- tests/test_contrib_diagram.py | 68 +++++++++++++++------- 3 files changed, 70 insertions(+), 37 deletions(-) diff --git a/statemachine/contrib/diagram/formatter.py b/statemachine/contrib/diagram/formatter.py index 3daeaf58..0ce8a1b0 100644 --- a/statemachine/contrib/diagram/formatter.py +++ b/statemachine/contrib/diagram/formatter.py @@ -106,6 +106,14 @@ def _render_dot(machine_or_class: "MachineRef") -> str: return DotGraphMachine(machine_or_class).get_graph().to_string() # type: ignore[no-any-return] +@formatter.register_format("svg") +def _render_svg(machine_or_class: "MachineRef") -> str: + from statemachine.contrib.diagram import DotGraphMachine + + svg_bytes: bytes = DotGraphMachine(machine_or_class).get_graph().create_svg() # type: ignore[attr-defined] + return svg_bytes.decode("utf-8") + + @formatter.register_format("mermaid") def _render_mermaid(machine_or_class: "MachineRef") -> str: from statemachine.contrib.diagram import MermaidGraphMachine diff --git a/statemachine/contrib/diagram/sphinx_ext.py b/statemachine/contrib/diagram/sphinx_ext.py index f512b5ac..84ab50f3 100644 --- a/statemachine/contrib/diagram/sphinx_ext.py +++ b/statemachine/contrib/diagram/sphinx_ext.py @@ -39,7 +39,7 @@ def _parse_events(value: str) -> list[str]: # Match the outer ... element, stripping XML prologue/DOCTYPE. -_SVG_TAG_RE = re.compile(rb"()", re.DOTALL) +_SVG_TAG_RE = re.compile(r"()", re.DOTALL) # Match fixed width/height attributes (e.g. width="702pt" height="170pt"). _SVG_WIDTH_RE = re.compile(r'\bwidth="([^"]*(?:pt|px))"') @@ -79,7 +79,7 @@ def run(self) -> list[nodes.Node]: qualname = self.arguments[0] try: - from statemachine.contrib.diagram import DotGraphMachine + from statemachine.contrib.diagram import formatter from statemachine.contrib.diagram import import_sm sm_class = import_sm(qualname) @@ -101,11 +101,10 @@ def run(self) -> list[nodes.Node]: output_format = self.options.get("format", "").strip().lower() if output_format == "mermaid": - return self._run_mermaid(machine, qualname) + return self._run_mermaid(machine, formatter, qualname) try: - graph = DotGraphMachine(machine).get_graph() - svg_bytes: bytes = graph.create_svg() # type: ignore[attr-defined] + svg_text = formatter.render(machine, "svg") except Exception as exc: return [ self.state_machine.reporter.warning( @@ -114,12 +113,12 @@ def run(self) -> list[nodes.Node]: ) ] - svg_tag, intrinsic_width, intrinsic_height = self._prepare_svg(svg_bytes) + svg_tag, intrinsic_width, intrinsic_height = self._prepare_svg(svg_text) svg_styles = self._build_svg_styles(intrinsic_width, intrinsic_height) svg_tag = svg_tag.replace("{svg_tag}' if target: @@ -149,12 +148,10 @@ def run(self) -> list[nodes.Node]: return [raw_node] - def _run_mermaid(self, machine: object, qualname: str) -> list[nodes.Node]: + def _run_mermaid(self, machine: object, formatter: Any, qualname: str) -> list[nodes.Node]: """Render a Mermaid diagram using sphinxcontrib-mermaid's node type.""" try: - from statemachine.contrib.diagram import MermaidGraphMachine - - mermaid_src = MermaidGraphMachine(machine).get_mermaid() + mermaid_src = formatter.render(machine, "mermaid") except Exception as exc: return [ self.state_machine.reporter.warning( @@ -190,10 +187,10 @@ def _run_mermaid(self, machine: object, qualname: str) -> list[nodes.Node]: self.add_name(node) return [node] - def _prepare_svg(self, svg_bytes: bytes) -> tuple[str, str, str]: + def _prepare_svg(self, svg_text: str) -> tuple[str, str, str]: """Extract the ```` element and its intrinsic dimensions.""" - match = _SVG_TAG_RE.search(svg_bytes) - svg_tag = match.group(1).decode("utf-8") if match else svg_bytes.decode("utf-8") + match = _SVG_TAG_RE.search(svg_text) + svg_tag = match.group(1) if match else svg_text width_match = _SVG_WIDTH_RE.search(svg_tag) height_match = _SVG_HEIGHT_RE.search(svg_tag) @@ -235,7 +232,7 @@ def _build_svg_styles(self, intrinsic_width: str, intrinsic_height: str) -> str: return f'style="{"; ".join(parts)}"' - def _resolve_target(self, svg_bytes: bytes) -> str: + def _resolve_target(self, svg_text: str) -> str: """Return the href for the wrapper ```` tag, if any. When ``:target:`` is given without a value (or as empty string), the @@ -258,8 +255,8 @@ def _resolve_target(self, svg_bytes: bytes) -> str: outdir = os.path.join(self.env.app.outdir, "_images") os.makedirs(outdir, exist_ok=True) outpath = os.path.join(outdir, filename) - with open(outpath, "wb") as f: - f.write(svg_bytes) + with open(outpath, "w", encoding="utf-8") as f: + f.write(svg_text) return f"/_images/{filename}" diff --git a/tests/test_contrib_diagram.py b/tests/test_contrib_diagram.py index d8e62b7e..547a3141 100644 --- a/tests/test_contrib_diagram.py +++ b/tests/test_contrib_diagram.py @@ -889,47 +889,47 @@ def _make_directive(self, options=None): return directive def test_strips_xml_prologue(self): - svg_bytes = ( - b'\n\n' - b'' - b"" + svg_text = ( + '\n\n' + '' + "" ) directive = self._make_directive() - svg_tag, _, _ = directive._prepare_svg(svg_bytes) + svg_tag, _, _ = directive._prepare_svg(svg_text) assert not svg_tag.startswith("" in svg_tag def test_extracts_intrinsic_dimensions(self): - svg_bytes = b'' + svg_text = '' directive = self._make_directive() - _, w, h = directive._prepare_svg(svg_bytes) + _, w, h = directive._prepare_svg(svg_text) assert w == "702pt" assert h == "170pt" def test_removes_fixed_dimensions(self): - svg_bytes = b'' + svg_text = '' directive = self._make_directive() - svg_tag, _, _ = directive._prepare_svg(svg_bytes) + svg_tag, _, _ = directive._prepare_svg(svg_text) assert 'width="702pt"' not in svg_tag assert 'height="170pt"' not in svg_tag assert "viewBox" in svg_tag def test_handles_no_dimensions(self): - svg_bytes = b'' + svg_text = '' directive = self._make_directive() - _, w, h = directive._prepare_svg(svg_bytes) + _, w, h = directive._prepare_svg(svg_text) assert w == "" assert h == "" def test_handles_px_dimensions(self): - svg_bytes = b'' + svg_text = '' directive = self._make_directive() - _, w, h = directive._prepare_svg(svg_bytes) + _, w, h = directive._prepare_svg(svg_text) assert w == "200px" assert h == "100px" @@ -1030,15 +1030,15 @@ def _make_directive(self, options=None, tmp_path=None): def test_no_target_option(self): directive = self._make_directive() - assert directive._resolve_target(b"") == "" + assert directive._resolve_target("") == "" def test_explicit_target_url(self): directive = self._make_directive({"target": "https://example.com/diagram.svg"}) - assert directive._resolve_target(b"") == "https://example.com/diagram.svg" + assert directive._resolve_target("") == "https://example.com/diagram.svg" def test_empty_target_generates_file(self, tmp_path): directive = self._make_directive({"target": ""}, tmp_path=tmp_path) - svg_data = b"" + svg_data = "" result = directive._resolve_target(svg_data) assert result.startswith("/_images/statemachine-") @@ -1048,21 +1048,21 @@ def test_empty_target_generates_file(self, tmp_path): images_dir = tmp_path / "_images" svg_files = list(images_dir.glob("statemachine-*.svg")) assert len(svg_files) == 1 - assert svg_files[0].read_bytes() == svg_data + assert svg_files[0].read_text(encoding="utf-8") == svg_data def test_empty_target_deterministic_filename(self, tmp_path): """Same qualname + events produces the same filename.""" directive1 = self._make_directive({"target": "", "events": "go"}, tmp_path=tmp_path) directive2 = self._make_directive({"target": "", "events": "go"}, tmp_path=tmp_path) - result1 = directive1._resolve_target(b"1") - result2 = directive2._resolve_target(b"2") + result1 = directive1._resolve_target("1") + result2 = directive2._resolve_target("2") assert result1 == result2 def test_different_events_different_filename(self, tmp_path): """Different events produce different filenames.""" d1 = self._make_directive({"target": "", "events": "a"}, tmp_path=tmp_path) d2 = self._make_directive({"target": "", "events": "b"}, tmp_path=tmp_path) - assert d1._resolve_target(b"") != d2._resolve_target(b"") + assert d1._resolve_target("") != d2._resolve_target("") class TestDirectiveRun: @@ -1347,6 +1347,33 @@ def test_render_dot(self): result = formatter.render(TrafficLightMachine, "dot") assert result.startswith("digraph TrafficLightMachine {") + def test_render_svg(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = formatter.render(TrafficLightMachine, "svg") + assert isinstance(result, str) + assert " Date: Sun, 8 Mar 2026 14:01:12 -0300 Subject: [PATCH 04/10] feat: auto-expand {statechart:FORMAT} placeholders in class docstrings The metaclass now detects {statechart:FORMAT} placeholders in docstrings and replaces them at class definition time with the rendered output. The docstring always stays in sync with the actual states and transitions. Any registered format works: md, rst, mermaid, dot, etc. Indentation of the placeholder line is preserved in the output. --- docs/diagram.md | 67 ++++++++++++++++- docs/releases/3.1.0.md | 22 ++++++ statemachine/factory.py | 32 +++++++++ tests/machines/showcase_simple.py | 5 ++ tests/test_contrib_diagram.py | 116 ++++++++++++++++++++++++++++++ 5 files changed, 241 insertions(+), 1 deletion(-) diff --git a/docs/diagram.md b/docs/diagram.md index 9aa19ffc..83e1bbaf 100644 --- a/docs/diagram.md +++ b/docs/diagram.md @@ -181,7 +181,7 @@ stateDiagram-v2 ``` -Supported format specs: `dot`, `mermaid`, `md` (or `markdown`), `rst`. +Supported format specs: `dot`, `svg`, `mermaid`, `md` (or `markdown`), `rst`. An empty spec falls back to `repr()`. The `dot` format returns the Graphviz DOT language source (same output as @@ -196,6 +196,71 @@ digraph TrafficLightMachine { ``` +## Auto-expanding docstrings + +Use `{statechart:FORMAT}` placeholders in your class docstring to embed +a live representation of the state machine. The placeholder is replaced +at class definition time, so the docstring always reflects the actual +states and transitions: + +```py +>>> from statemachine.statemachine import StateChart +>>> from statemachine.state import State + +>>> class TrafficLight(StateChart): +... """A traffic light. +... +... {statechart:md} +... """ +... green = State(initial=True) +... yellow = State() +... red = State() +... cycle = green.to(yellow) | yellow.to(red) | red.to(green) + +>>> print(TrafficLight.__doc__) +A traffic light. + +| State | Event | Guard | Target | +| ------ | ----- | ----- | ------ | +| Green | cycle | | Yellow | +| Yellow | cycle | | Red | +| Red | cycle | | Green | + + + +``` + +Any registered format works: `{statechart:rst}`, `{statechart:mermaid}`, +`{statechart:dot}`, etc. + +### Choosing the right format + +| Context | Recommended format | +|---------|-------------------| +| Sphinx with RST (autodoc default) | `{statechart:rst}` | +| Sphinx with MyST Markdown | `{statechart:md}` | +| `help()` in terminal / IDE | Either works; `md` reads more cleanly | + +### Sphinx autodoc integration + +Since the placeholder is expanded at class definition time, Sphinx `autodoc` +sees the final rendered text — no extra configuration needed. + +For example, this class uses `{statechart:rst}` in its docstring: + +```{literalinclude} ../tests/machines/showcase_simple.py +:pyobject: SimpleSC +:language: python +``` + +And here is the rendered autodoc output: + +```{eval-rst} +.. autoclass:: tests.machines.showcase_simple.SimpleSC + :noindex: +``` + + ## Mermaid output The `MermaidGraphMachine` facade generates diff --git a/docs/releases/3.1.0.md b/docs/releases/3.1.0.md index 714450f4..5f08b6ed 100644 --- a/docs/releases/3.1.0.md +++ b/docs/releases/3.1.0.md @@ -75,6 +75,28 @@ from the same diagram IR. See {ref}`diagram:Text representations with format()` and {ref}`diagram:Mermaid output` for details. +### Auto-expanding docstrings + +Use `{statechart:FORMAT}` placeholders in your class docstring to embed a +live representation of the state machine. The placeholder is replaced at +class definition time, so the docstring always stays in sync with the code: + +```python +class TrafficLight(StateChart): + """A traffic light. + + {statechart:md} + """ + green = State(initial=True) + yellow = State() + red = State() + cycle = green.to(yellow) | yellow.to(red) | red.to(green) +``` + +Any registered format works: `md`, `rst`, `mermaid`, `dot`, etc. +See {ref}`diagram:Auto-expanding docstrings` for details. + + ### Bugfixes in 3.1.0 - Fixes silent misuse of `Event()` with multiple positional arguments. Passing more than one diff --git a/statemachine/factory.py b/statemachine/factory.py index 55da2db7..c29825f7 100644 --- a/statemachine/factory.py +++ b/statemachine/factory.py @@ -1,3 +1,4 @@ +import re from typing import Any from typing import Dict from typing import List @@ -91,6 +92,37 @@ def __init__( cls._check() cls._setup() + cls._expand_docstring() + + _STATECHART_RE = re.compile(r"\{statechart:(\w+)\}") + + def _expand_docstring(cls) -> None: + """Replace ``{statechart:FORMAT}`` placeholders in the class docstring.""" + doc = cls.__doc__ + if not doc: + return + + from .contrib.diagram.formatter import formatter + + def _replace(match: "re.Match[str]") -> str: + fmt = match.group(1) + rendered = formatter.render(cls, fmt) # type: ignore[arg-type] + + # Respect the indentation of the placeholder line. + line_start = doc.rfind("\n", 0, match.start()) + if line_start == -1: + indent = "" + else: + indent_match = re.match(r"[ \t]*", doc[line_start + 1 : match.start()]) + indent = indent_match.group() if indent_match else "" + + if indent: + lines = rendered.split("\n") + rendered = lines[0] + "\n" + "\n".join(indent + line for line in lines[1:]) + + return rendered + + cls.__doc__ = cls._STATECHART_RE.sub(_replace, doc) def __format__(cls, fmt: str) -> str: from .contrib.diagram.formatter import formatter diff --git a/tests/machines/showcase_simple.py b/tests/machines/showcase_simple.py index affc1ce1..ca99839d 100644 --- a/tests/machines/showcase_simple.py +++ b/tests/machines/showcase_simple.py @@ -3,6 +3,11 @@ class SimpleSC(StateChart): + """A simple three-state machine. + + {statechart:rst} + """ + idle = State(initial=True) running = State() done = State(final=True) diff --git a/tests/test_contrib_diagram.py b/tests/test_contrib_diagram.py index 547a3141..0505e013 100644 --- a/tests/test_contrib_diagram.py +++ b/tests/test_contrib_diagram.py @@ -1328,6 +1328,122 @@ def test_format_invalid_class_raises(self): f"{TrafficLightMachine:invalid}" +class TestDocstringExpansion: + """Tests for {statechart:FORMAT} placeholder expansion in docstrings.""" + + def test_md_placeholder(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + """Machine. + + {statechart:md} + """ + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert "| State" in MyMachine.__doc__ + assert "{statechart:md}" not in MyMachine.__doc__ + + def test_rst_placeholder(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + """Machine. + + {statechart:rst} + """ + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert "+---" in MyMachine.__doc__ + assert "{statechart:rst}" not in MyMachine.__doc__ + + def test_mermaid_placeholder(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + """{statechart:mermaid}""" + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert "stateDiagram-v2" in MyMachine.__doc__ + + def test_no_placeholder_unchanged(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + """Just a plain docstring.""" + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert MyMachine.__doc__ == "Just a plain docstring." + + def test_no_docstring(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert MyMachine.__doc__ is None + + def test_indentation_preserved(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + __doc__ = "Doc.\n\n Table:\n\n {statechart:md}\n\n End.\n" + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + lines = MyMachine.__doc__.split("\n") + table_lines = [line for line in lines if "|" in line] + for line in table_lines: + assert line.startswith(" |") + assert "End." in MyMachine.__doc__ + + def test_multiple_placeholders(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + """MD: {statechart:md} + + Mermaid: {statechart:mermaid} + """ + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert "| State" in MyMachine.__doc__ + assert "stateDiagram-v2" in MyMachine.__doc__ + + class TestFormatter: """Tests for the Formatter facade (render, register_format, supported_formats).""" From 1e4ba199aa2afac68189bf7485a0bdcb5fad102f Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Sun, 8 Mar 2026 14:23:44 -0300 Subject: [PATCH 05/10] docs: revise diagram.md for new features and narrative coherence - Reorganize into unified "Text representations" section with format table (name, aliases, description, dependencies) - Add formatter API section with render(), supported_formats(), and custom format registration example - Add live Mermaid directive example in Sphinx section - Add --format dot to CLI examples - Replace MermaidGraphMachine usage with formatter - Add autodoc integration example with SimpleSC - Add auto-expanding docstrings section with format recommendations - Update release notes --- docs/diagram.md | 209 ++++++++++++++++++++++++++++++------------------ 1 file changed, 131 insertions(+), 78 deletions(-) diff --git a/docs/diagram.md b/docs/diagram.md index 83e1bbaf..72cda42a 100644 --- a/docs/diagram.md +++ b/docs/diagram.md @@ -27,7 +27,6 @@ sudo apt install graphviz For other systems, see the [Graphviz downloads page](https://graphviz.org/download/). - ## Generating diagrams Every state machine instance exposes a `_graph()` method that returns a @@ -77,8 +76,7 @@ For higher resolution PNGs, set the DPI before exporting: ```python graph = sm._graph() -graph.set_dpi(300) -graph.write_png("order_control_300dpi.png") +graph.set_dpi(300).write_png("order_control_300dpi.png") ``` ```{note} @@ -89,52 +87,24 @@ complete list. ``` -## Command line - -You can generate diagrams without writing Python code: - -```bash -python -m statemachine.contrib.diagram -``` - -The output format is inferred from the file extension: - -```bash -python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine diagram.png -``` +## Text representations -To highlight the current state, use `--events` to instantiate the machine and -send events before rendering: +State machines support multiple text-based output formats, all accessible +through Python's built-in `format()` protocol, the `formatter` API, or +the command line. -```bash -python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine diagram.png --events cycle cycle cycle -``` +| Format | Aliases | Description | Dependencies | +|--------|---------|-------------|--------------| +| `mermaid` | | [Mermaid stateDiagram-v2](https://mermaid.js.org/syntax/stateDiagram.html) source | None | +| `md` | `markdown` | Transition table (pipe-delimited Markdown) | None | +| `rst` | | Transition table (RST grid table) | None | +| `dot` | | [Graphviz DOT](https://graphviz.org/doc/info/lang.html) language source | pydot | +| `svg` | | SVG markup (generated via DOT) | pydot, Graphviz | -Use `--format` to produce **Mermaid source** or a **transition table** instead -of a Graphviz image: - -```bash -# Mermaid stateDiagram-v2 -python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.mmd --format mermaid - -# Markdown transition table -python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.md --format md - -# RST transition table -python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.rst --format rst -``` - -Use `-` as the output file to write to stdout (handy for piping): - -```bash -python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine - --format mermaid -``` +### Using `format()` -## Text representations with `format()` - -State machines support Python's built-in `format()` protocol for quick text -output — no diagram imports needed: +Use f-strings or the built-in `format()` function — no diagram imports needed: ```py >>> from tests.examples.traffic_light_machine import TrafficLightMachine @@ -181,11 +151,7 @@ stateDiagram-v2 ``` -Supported format specs: `dot`, `svg`, `mermaid`, `md` (or `markdown`), `rst`. -An empty spec falls back to `repr()`. - -The `dot` format returns the Graphviz DOT language source (same output as -`sm._graph().to_string()`): +The `dot` format returns the Graphviz DOT language source: ```py >>> print(f"{sm:dot}") # doctest: +ELLIPSIS @@ -195,6 +161,101 @@ digraph TrafficLightMachine { ``` +An empty format spec (e.g., `f"{sm:}"`) falls back to `repr()`. + + +### Using the `formatter` API + +The `formatter` object is the programmatic entry point for rendering +state machines in any registered text format: + +```py +>>> from statemachine.contrib.diagram import formatter +>>> from tests.examples.traffic_light_machine import TrafficLightMachine + +>>> print(formatter.render(TrafficLightMachine, "mermaid")) +stateDiagram-v2 + direction LR + state "Green" as green + state "Yellow" as yellow + state "Red" as red + [*] --> green + green --> yellow : cycle + yellow --> red : cycle + red --> green : cycle + + +>>> formatter.supported_formats() +['dot', 'markdown', 'md', 'mermaid', 'rst', 'svg'] + +``` + +Both `format()` and the Sphinx directive delegate to this same `formatter` +under the hood. + + +#### Registering custom formats + +The `formatter` is extensible — register your own format with a +decorator and it becomes available everywhere (`format()`, CLI, +Sphinx directive): + +```python +from statemachine.contrib.diagram import formatter + +@formatter.register_format("plantuml", "puml") +def _render_plantuml(machine_or_class): + # your PlantUML renderer here + ... +``` + +After registration, `f"{sm:plantuml}"` and `--format plantuml` work +immediately. + + +### Command line + +You can generate diagrams without writing Python code: + +```bash +python -m statemachine.contrib.diagram +``` + +The output format is inferred from the file extension: + +```bash +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine diagram.png +``` + +To highlight the current state, use `--events` to instantiate the machine and +send events before rendering: + +```bash +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine diagram.png --events cycle cycle cycle +``` + +Use `--format` to produce a text format instead of a Graphviz image: + +```bash +# Mermaid stateDiagram-v2 +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.mmd --format mermaid + +# DOT source +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.dot --format dot + +# Markdown transition table +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.md --format md + +# RST transition table +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.rst --format rst +``` + +Use `-` as the output file to write to stdout (handy for piping): + +```bash +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine - --format mermaid +``` + ## Auto-expanding docstrings @@ -261,33 +322,6 @@ And here is the rendered autodoc output: ``` -## Mermaid output - -The `MermaidGraphMachine` facade generates -[Mermaid `stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) -source text from any state machine — no external dependencies required: - -```py ->>> from statemachine.contrib.diagram import MermaidGraphMachine ->>> from tests.examples.traffic_light_machine import TrafficLightMachine ->>> print(MermaidGraphMachine(TrafficLightMachine).get_mermaid()) -stateDiagram-v2 - direction LR - state "Green" as green - state "Yellow" as yellow - state "Red" as red - [*] --> green - green --> yellow : cycle - yellow --> red : cycle - red --> green : cycle - - -``` - -Compound states, parallel regions, history pseudo-states, guards, and -active-state highlighting are all supported. - - ## Sphinx directive If you use [Sphinx](https://www.sphinx-doc.org/) to build your documentation, the @@ -356,6 +390,26 @@ zoom and pan freely: :align: center ``` +### Mermaid format + +Use `:format: mermaid` to render via +[sphinxcontrib-mermaid](https://github.com/mgaitan/sphinxcontrib-mermaid) +instead of Graphviz SVG — useful when you don't want to install Graphviz +in your docs build environment: + +````markdown +```{statemachine-diagram} myproject.machines.TrafficLight +:format: mermaid +:caption: Rendered as Mermaid +``` +```` + +```{statemachine-diagram} tests.examples.traffic_light_machine.TrafficLightMachine +:format: mermaid +:caption: TrafficLightMachine (Mermaid) +:align: center +``` + ### Directive options The directive supports the same layout options as the standard `image` and @@ -368,8 +422,7 @@ The directive supports the same layout options as the standard `image` and each event is sent before rendering. `:format:` *(string)* -: Output format. Use `mermaid` to render via - [sphinxcontrib-mermaid](https://github.com/mgaitan/sphinxcontrib-mermaid) +: Output format. Use `mermaid` to render via sphinxcontrib-mermaid instead of Graphviz SVG. Default: DOT/SVG. **Image/figure options:** From 4edee562a52100a074ae40fea23c6333deae04b7 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Sun, 8 Mar 2026 14:38:54 -0300 Subject: [PATCH 06/10] docs: mention f-string/format() text representations in README and tutorial --- README.md | 14 ++++++++++++-- docs/tutorial.md | 45 ++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index f3af4ecb..43f141a6 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,17 @@ True ``` -Generate a diagram: +Generate a diagram or get a text representation with f-strings: + +```py +>>> print(f"{sm:md}") +| State | Event | Guard | Target | +| ------ | ----- | ----- | ------ | +| Green | cycle | | Yellow | +| Yellow | cycle | | Red | +| Red | cycle | | Green | + +``` ```python sm._graph().write_png("traffic_light.png") @@ -341,7 +351,7 @@ There's a lot more to explore: - **`prepare_event`** callback — inject custom data into all callbacks - **Observer pattern** — register external listeners to watch events and state changes - **Django integration** — auto-discover state machines in Django apps with `MachineMixin` -- **Diagram generation** — from the CLI, at runtime, or in Jupyter notebooks +- **Diagram generation** — via f-strings (`f"{sm:mermaid}"`), CLI, Sphinx directive, or Jupyter - **Dictionary-based definitions** — create state machines from data structures - **Internationalization** — error messages in multiple languages diff --git a/docs/tutorial.md b/docs/tutorial.md index e6b23d3b..d49526d7 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -364,16 +364,55 @@ Or from the command line: python -m statemachine.contrib.diagram my_module.CoffeeOrder order.png ``` +### Text representations with `format()` + +You can also get text representations of any state machine using Python's built-in +`format()` or f-strings — no Graphviz needed: + +```py +>>> from tests.machines.tutorial_coffee_order import CoffeeOrder + +>>> print(f"{CoffeeOrder:md}") +| State | Event | Guard | Target | +| --------- | ------- | ----- | --------- | +| Pending | start | | Preparing | +| Preparing | finish | | Ready | +| Ready | pick_up | | Picked up | + +``` + +Supported formats include `mermaid`, `md` (markdown table), `rst`, `dot`, and `svg`. +Works on both classes and instances: + +```py +>>> print(f"{CoffeeOrder:mermaid}") +stateDiagram-v2 + direction LR + state "Pending" as pending + state "Preparing" as preparing + state "Ready" as ready + state "Picked up" as picked_up + [*] --> pending + picked_up --> [*] + pending --> preparing : start + preparing --> ready : finish + ready --> picked_up : pick_up + + +``` + ```{tip} -Diagram generation requires [Graphviz](https://graphviz.org/) (`dot` command) +Graphviz diagram generation requires [Graphviz](https://graphviz.org/) (`dot` command) and the `diagrams` extra: pip install python-statemachine[diagrams] + +Text formats (`md`, `rst`, `mermaid`) work without any extra dependencies. ``` ```{seealso} -See [](diagram.md) for highlighting active states, Jupyter integration, -SVG output, DPI settings, Sphinx directive, and the `quickchart_write_svg` +See [](diagram.md) for all formats, highlighting active states, auto-expanding +docstrings, Jupyter integration, Sphinx directive, and the `quickchart_write_svg` alternative that doesn't require Graphviz. ``` From 87290903871af6d3bdfd26796ab7a250d029a8e9 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Sun, 8 Mar 2026 14:46:37 -0300 Subject: [PATCH 07/10] docs: revise 3.1.0 release notes and fix broken cross-references --- docs/diagram.md | 1 + docs/releases/3.1.0.md | 143 ++++++++++++++++++++++++++++------------- 2 files changed, 99 insertions(+), 45 deletions(-) diff --git a/docs/diagram.md b/docs/diagram.md index 72cda42a..7dbaa454 100644 --- a/docs/diagram.md +++ b/docs/diagram.md @@ -164,6 +164,7 @@ digraph TrafficLightMachine { An empty format spec (e.g., `f"{sm:}"`) falls back to `repr()`. +(formatter-api)= ### Using the `formatter` API The `formatter` object is the programmatic entry point for rendering diff --git a/docs/releases/3.1.0.md b/docs/releases/3.1.0.md index 5f08b6ed..e566ec62 100644 --- a/docs/releases/3.1.0.md +++ b/docs/releases/3.1.0.md @@ -4,75 +4,65 @@ ## What's new in 3.1.0 -### Sphinx directive for inline diagrams - -A new Sphinx extension renders state machine diagrams directly in your -documentation from an importable class path — no manual image generation -needed. +### Text representations with `format()` -Add `"statemachine.contrib.diagram.sphinx_ext"` to your `conf.py` -extensions, then use the directive in any MyST Markdown page: +State machines now support Python's built-in `format()` protocol. Use f-strings +or `format()` to get text representations — on both classes and instances: -````markdown -```{statemachine-diagram} myproject.machines.OrderControl -:events: receive_payment -:caption: After payment -:target: +```python +f"{TrafficLightMachine:md}" +f"{sm:mermaid}" +format(sm, "rst") ``` -```` -The directive supports the same options as the standard `image`/`figure` -directives (`:width:`, `:height:`, `:scale:`, `:align:`, `:target:`, -`:class:`, `:name:`), plus `:events:` to instantiate the machine and send -events before rendering (highlighting the current state). +Supported formats: -Using `:target:` without a value makes the diagram clickable, opening the -full SVG in a new browser tab for zooming — useful for large statecharts. +| Format | Output | Requires | +|-----------|---------------------------|-----------------------| +| `dot` | Graphviz DOT source | `pydot` | +| `svg` | SVG markup (via Graphviz) | `pydot` + `graphviz` | +| `mermaid` | Mermaid stateDiagram-v2 | — | +| `md` | Markdown transition table | — | +| `rst` | RST transition table | — | -See {ref}`diagram:Sphinx directive` for full documentation. -[#589](https://github.com/fgmacedo/python-statemachine/pull/589). +See {ref}`diagram:Text representations` for details. -### Performance: 5x–7x faster event processing - -The engine's hot paths have been systematically profiled and optimized, resulting in -**4.7x–7.7x faster event throughput** and **1.9x–2.6x faster setup** across all -machine types. All optimizations are internal — no public API changes. -See [#592](https://github.com/fgmacedo/python-statemachine/pull/592) for details. +### Formatter facade +A new `Formatter` facade with decorator-based registration unifies all text +format rendering behind a single API. Adding a new format requires only +registering a render function — no changes to `__format__`, the CLI, or the +Sphinx directive: -### Thread safety documentation - -The sync engine is thread-safe: multiple threads can send events to the same state -machine instance concurrently. This is now documented in the -{ref}`processing model ` and verified by stress tests. -[#592](https://github.com/fgmacedo/python-statemachine/pull/592). +```python +from statemachine.contrib.diagram import formatter +formatter.render(sm, "mermaid") +formatter.supported_formats() -### Diagram CLI `--events` option +@formatter.register_format("custom") +def _render_custom(machine_or_class): + ... +``` -The `python -m statemachine.contrib.diagram` command now accepts `--events` to -instantiate the machine and send events before rendering, highlighting the -current active state — matching the Sphinx directive's `:events:` option. -See {ref}`diagram:Command line` for details. +See {ref}`formatter-api` for details. ### Mermaid diagram support State machines can now be rendered as [Mermaid `stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) -source text — no Graphviz installation required. +source text — no Graphviz installation required. Supports compound states, +parallel regions, history states, guards, and active-state highlighting. Three ways to use it: -- **`format()` / f-strings:** `f"{sm:mermaid}"`, `f"{sm:md}"`, `f"{sm:rst}"` — - works on both instances and classes. +- **f-strings:** `f"{sm:mermaid}"` - **CLI:** `python -m statemachine.contrib.diagram MyMachine - --format mermaid` - **Sphinx directive:** `:format: mermaid` renders via `sphinxcontrib-mermaid`. -A new `TransitionTableRenderer` produces markdown or RST transition tables -from the same diagram IR. See {ref}`diagram:Text representations with format()` -and {ref}`diagram:Mermaid output` for details. +See {ref}`diagram:Mermaid format` for details. ### Auto-expanding docstrings @@ -94,13 +84,76 @@ class TrafficLight(StateChart): ``` Any registered format works: `md`, `rst`, `mermaid`, `dot`, etc. +Works with Sphinx autodoc — the expanded docstring is what gets rendered. See {ref}`diagram:Auto-expanding docstrings` for details. +### Sphinx directive for inline diagrams + +A new Sphinx extension renders state machine diagrams directly in your +documentation from an importable class path — no manual image generation +needed. + +Add `"statemachine.contrib.diagram.sphinx_ext"` to your `conf.py` +extensions, then use the directive in any MyST Markdown page: + +````markdown +```{statemachine-diagram} myproject.machines.OrderControl +:events: receive_payment +:caption: After payment +:target: +``` +```` + +The directive supports the same options as the standard `image`/`figure` +directives (`:width:`, `:height:`, `:scale:`, `:align:`, `:target:`, +`:class:`, `:name:`), plus `:events:` to instantiate the machine and send +events before rendering (highlighting the current state). + +Using `:target:` without a value makes the diagram clickable, opening the +full SVG in a new browser tab for zooming — useful for large statecharts. + +The `:format: mermaid` option renders via `sphinxcontrib-mermaid` instead of +Graphviz. + +See {ref}`diagram:Sphinx directive` for full documentation. +[#589](https://github.com/fgmacedo/python-statemachine/pull/589). + + +### Diagram CLI `--events` and `--format` options + +The `python -m statemachine.contrib.diagram` command now accepts: + +- `--events` to instantiate the machine and send events before rendering, + highlighting the current active state. +- `--format` to choose the output format (`mermaid`, `md`, `rst`, `dot`, `svg`, + or image formats via Graphviz). Use `-` as the output path to write text + formats to stdout. + +See {ref}`diagram:Command line` for details. +[#593](https://github.com/fgmacedo/python-statemachine/pull/593). + + +### Performance: 5x–7x faster event processing + +The engine's hot paths have been systematically profiled and optimized, resulting in +**4.7x–7.7x faster event throughput** and **1.9x–2.6x faster setup** across all +machine types. All optimizations are internal — no public API changes. +See [#592](https://github.com/fgmacedo/python-statemachine/pull/592) for details. + + +### Thread safety documentation + +The sync engine is thread-safe: multiple threads can send events to the same state +machine instance concurrently. This is now documented in the +{ref}`processing model ` and verified by stress tests. +[#592](https://github.com/fgmacedo/python-statemachine/pull/592). + + ### Bugfixes in 3.1.0 - Fixes silent misuse of `Event()` with multiple positional arguments. Passing more than one - transition to `Event()` (e.g., `Event(t1, t2)`) now raises {ref}`InvalidDefinition` with a + transition to `Event()` (e.g., `Event(t1, t2)`) now raises `InvalidDefinition` with a clear message suggesting the `|` operator. Previously, the second argument was silently interpreted as the event `id`, leaving the extra transitions eventless (auto-firing). [#588](https://github.com/fgmacedo/python-statemachine/pull/588). From 462a0f1e1c382c609a648f954f0dbfbfaf59907a Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Sun, 8 Mar 2026 15:41:38 -0300 Subject: [PATCH 08/10] fix: work around Mermaid crash for transitions inside parallel regions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mermaid's stateDiagram-v2 crashes when a transition targets or originates from a compound state inside a parallel region (mermaid-js/mermaid#4052). The MermaidRenderer now redirects such endpoints to the compound's initial child state. Also filter dot-form event aliases (e.g. done.invoke.X) from diagram output — the fix lives in the extractor so all renderers benefit. Closes #594 --- docs/diagram.md | 12 +- statemachine/contrib/diagram/extract.py | 30 ++++- .../contrib/diagram/renderers/mermaid.py | 66 +++++++++-- tests/test_contrib_diagram.py | 103 ++++++++++++++++++ tests/test_mermaid_renderer.py | 5 +- 5 files changed, 203 insertions(+), 13 deletions(-) diff --git a/docs/diagram.md b/docs/diagram.md index 7dbaa454..17243490 100644 --- a/docs/diagram.md +++ b/docs/diagram.md @@ -95,12 +95,22 @@ the command line. | Format | Aliases | Description | Dependencies | |--------|---------|-------------|--------------| -| `mermaid` | | [Mermaid stateDiagram-v2](https://mermaid.js.org/syntax/stateDiagram.html) source | None | +| `mermaid` | | [Mermaid stateDiagram-v2](https://mermaid.js.org/syntax/stateDiagram.html) source | None [^mermaid] | | `md` | `markdown` | Transition table (pipe-delimited Markdown) | None | | `rst` | | Transition table (RST grid table) | None | | `dot` | | [Graphviz DOT](https://graphviz.org/doc/info/lang.html) language source | pydot | | `svg` | | SVG markup (generated via DOT) | pydot, Graphviz | +[^mermaid]: Mermaid has a known rendering bug + ([mermaid-js/mermaid#4052](https://github.com/mermaid-js/mermaid/issues/4052)) + where transitions targeting or originating from a compound state inside a + parallel region crash the renderer. As a workaround, the `MermaidRenderer` + redirects such transitions to the compound's initial child state. The + visual result is equivalent — Mermaid draws the arrow crossing into the + compound boundary — but the arrow points to the child rather than the + compound border. This workaround will be revisited when the upstream bug + is resolved. + ### Using `format()` diff --git a/statemachine/contrib/diagram/extract.py b/statemachine/contrib/diagram/extract.py index 37f4fc88..29002e16 100644 --- a/statemachine/contrib/diagram/extract.py +++ b/statemachine/contrib/diagram/extract.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from statemachine.state import State from statemachine.statemachine import StateChart + from statemachine.transition import Transition # A StateChart class or instance — both expose the same structural metadata. MachineRef = Union["StateChart", "type[StateChart]"] @@ -101,6 +102,33 @@ def _extract_state( ) +def _format_event_names(transition: "Transition") -> str: + """Build a display string for the events that trigger a transition. + + ``_expand_event_id`` registers both the Python attribute name + (``done_invoke_X``) and the SCXML dot form (``done.invoke.X``) under the + same transition. For diagram display we only want unique *semantic* events, + keeping the Python attribute name when an alias pair exists. + """ + events = list(transition.events) + if not events: + return "" + + all_ids = {str(e) for e in events} + + display: List[str] = [] + for event in events: + eid = str(event) + # Skip dot-form aliases (e.g. "done.invoke.X") when the underscore + # form ("done_invoke_X") is also registered on this transition. + if "." in eid and eid.replace(".", "_") in all_ids: + continue + if eid not in display: + display.append(eid) + + return " ".join(display) + + def _extract_transitions_from_state(state: "State") -> List[DiagramTransition]: """Extract transitions from a single state (non-recursive).""" result: List[DiagramTransition] = [] @@ -114,7 +142,7 @@ def _extract_transitions_from_state(state: "State") -> List[DiagramTransition]: DiagramTransition( source=transition.source.id, targets=target_ids, - event=transition.event, + event=_format_event_names(transition), guards=cond_strs, is_internal=transition.internal, ) diff --git a/statemachine/contrib/diagram/renderers/mermaid.py b/statemachine/contrib/diagram/renderers/mermaid.py index a2e649fb..b455b935 100644 --- a/statemachine/contrib/diagram/renderers/mermaid.py +++ b/statemachine/contrib/diagram/renderers/mermaid.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Dict from typing import List from typing import Optional from typing import Set @@ -21,17 +22,34 @@ class MermaidRendererConfig: class MermaidRenderer: - """Renders a DiagramGraph into a Mermaid stateDiagram-v2 source string.""" + """Renders a DiagramGraph into a Mermaid stateDiagram-v2 source string. + + Mermaid's stateDiagram-v2 has a rendering bug where transitions whose source + or target is a compound state (``state X { ... }``) inside a parallel region + crash with ``Cannot set properties of undefined (setting 'rank')``. To work + around this, the renderer rewrites compound-state endpoints to cross the + boundary: + + - Transition **to** a compound → redirected to its initial child. + - Transition **from** a compound → redirected from its initial child. + + This is applied universally (not only inside parallel regions) for simplicity + and consistency — the visual effect is equivalent. + """ def __init__(self, config: Optional[MermaidRendererConfig] = None): self.config = config or MermaidRendererConfig() self._active_ids: List[str] = [] self._rendered_transitions: Set[tuple] = set() + self._compound_ids: Set[str] = set() + self._initial_child_map: Dict[str, str] = {} def render(self, graph: DiagramGraph) -> str: """Render a DiagramGraph to a Mermaid stateDiagram-v2 string.""" self._active_ids = [] self._rendered_transitions = set() + self._compound_ids = graph.compound_state_ids + self._initial_child_map = self._build_initial_child_map(graph.states) lines: List[str] = [] lines.append("stateDiagram-v2") @@ -51,6 +69,23 @@ def render(self, graph: DiagramGraph) -> str: return "\n".join(lines) + "\n" + def _build_initial_child_map(self, states: List[DiagramState]) -> Dict[str, str]: + """Build a map from compound state ID to its initial child ID (recursive).""" + result: Dict[str, str] = {} + for state in states: + if state.children: + initial = next((c for c in state.children if c.is_initial), None) + if initial: + result[state.id] = initial.id + result.update(self._build_initial_child_map(state.children)) + return result + + def _resolve_endpoint(self, state_id: str) -> str: + """Resolve a transition endpoint, redirecting compound states to their initial child.""" + if state_id in self._compound_ids and state_id in self._initial_child_map: + return self._initial_child_map[state_id] + return state_id + def _render_states( self, states: List[DiagramState], @@ -162,29 +197,42 @@ def _render_scope_transitions( lines: List[str], indent: int, ) -> None: - """Render transitions where both source and all targets are in scope_ids.""" + """Render transitions where both source and all targets are in scope_ids. + + Mermaid does not support transitions where the source or target is a + compound state rendered with ``state X { ... }`` inside a parallel region. + To work around this, endpoints that reference compound states are + redirected to the compound's initial child. Scope membership is checked + on the **original** IDs (which belong to this scope level), while the + rendered arrow uses the **resolved** (possibly redirected) IDs. + """ for t in transitions: if t.is_initial or t.is_internal: continue targets = t.targets if t.targets else [t.source] - # Only render if source is in scope + + # Check scope membership with original IDs if t.source not in scope_ids: continue - # Only render if all targets are in scope if not all(target in scope_ids for target in targets): continue - for target in targets: - key = (t.source, target, t.event) + # Resolve endpoints for rendering (redirect compound → initial child) + source = self._resolve_endpoint(t.source) + resolved_targets = [self._resolve_endpoint(tid) for tid in targets] + + for target in resolved_targets: + key = (source, target, t.event) if key in self._rendered_transitions: continue self._rendered_transitions.add(key) - self._render_single_transition(t, target, lines, indent) + self._render_single_transition(t, source, target, lines, indent) def _render_single_transition( self, transition: DiagramTransition, + source: str, target: str, lines: List[str], indent: int, @@ -198,9 +246,9 @@ def _render_single_transition( label = " ".join(label_parts) if label: - lines.append(f"{pad}{transition.source} --> {target} : {label}") + lines.append(f"{pad}{source} --> {target} : {label}") else: - lines.append(f"{pad}{transition.source} --> {target}") + lines.append(f"{pad}{source} --> {target}") @staticmethod def _format_action(action: DiagramAction) -> str: diff --git a/tests/test_contrib_diagram.py b/tests/test_contrib_diagram.py index 0505e013..14575faa 100644 --- a/tests/test_contrib_diagram.py +++ b/tests/test_contrib_diagram.py @@ -8,6 +8,7 @@ from statemachine.contrib.diagram import DotGraphMachine from statemachine.contrib.diagram import main from statemachine.contrib.diagram import quickchart_write_svg +from statemachine.contrib.diagram.extract import _format_event_names from statemachine.contrib.diagram.model import ActionType from statemachine.contrib.diagram.model import StateType from statemachine.contrib.diagram.renderers.dot import DotRenderer @@ -698,6 +699,108 @@ def test_resolve_initial_fallback(self): assert states[0].is_initial is True +class TestFormatEventNames: + """Tests for _format_event_names — alias filtering for diagram display.""" + + def test_simple_event_unchanged(self): + """A plain event with no aliases is returned as-is.""" + + class SM(StateChart): + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + t = SM.s1.transitions[0] + assert _format_event_names(t) == "go" + + def test_done_state_alias_filtered(self): + """done_state_X registers both underscore and dot forms; only underscore is shown.""" + + class SM(StateChart): + class parent(State.Compound): + child = State(initial=True) + done = State(final=True) + finish = child.to(done) + + end = State(final=True) + done_state_parent = parent.to(end) + + t = next(t for t in SM.parent.transitions if t.event and "done_state" in t.event) + result = _format_event_names(t) + assert result == "done_state_parent" + assert "done.state" not in result + + def test_done_invoke_alias_filtered(self): + """done_invoke_X alias filtering works the same as done_state_X.""" + + class SM(StateChart): + s1 = State(initial=True) + s2 = State(final=True) + done_invoke_child = s1.to(s2) + + t = SM.s1.transitions[0] + result = _format_event_names(t) + assert result == "done_invoke_child" + assert "done.invoke" not in result + + def test_error_alias_filtered(self): + """error_X registers both error_X and error.X; only underscore is shown.""" + + class SM(StateChart): + s1 = State(initial=True) + s2 = State(final=True) + error_execution = s1.to(s2) + + t = SM.s1.transitions[0] + result = _format_event_names(t) + assert result == "error_execution" + assert "error.execution" not in result + + def test_multiple_distinct_events_preserved(self): + """Multiple distinct events on one transition are all preserved.""" + + class SM(StateChart): + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + also = s1.to(s2) + + # Add a second event to the first transition + t = SM.s1.transitions[0] + t.add_event("also") + result = _format_event_names(t) + assert "go" in result + assert "also" in result + + def test_eventless_transition_returns_empty(self): + """A transition with no events returns an empty string.""" + + class SM(StateChart): + s1 = State(initial=True) + s2 = State(final=True) + s1.to(s2, cond="always_true") + + def always_true(self): + return True + + # Find the eventless transition + t = next(t for t in SM.s1.transitions if not list(t.events)) + assert _format_event_names(t) == "" + + def test_dot_only_event_preserved(self): + """An event whose ID contains dots but has no underscore alias is preserved.""" + + class SM(StateChart): + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + from statemachine.transition import Transition + + t = Transition(source=SM.s1, target=SM.s2, event="custom.event") + assert _format_event_names(t) == "custom.event" + + class TestDotRendererEdgeCases: """Tests for dot.py edge cases.""" diff --git a/tests/test_mermaid_renderer.py b/tests/test_mermaid_renderer.py index 19562dab..d4042cbc 100644 --- a/tests/test_mermaid_renderer.py +++ b/tests/test_mermaid_renderer.py @@ -249,8 +249,9 @@ class parent(State.Compound, name="Parent"): assert "[*] --> child1" in result assert "child1 --> child2 : go" in result assert "child2 --> [*]" in result - assert "start --> parent : enter" in result - assert "parent --> end : finish" in result + # Compound endpoints are redirected to the initial child (Mermaid workaround) + assert "start --> child1 : enter" in result + assert "child1 --> end : finish" in result def test_compound_no_duplicate_transitions(self): """Transitions inside compound states must not also appear at top level.""" From 9dca47b6b608552d21306bcb15525ebd8dc9d553 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Sun, 8 Mar 2026 16:59:25 -0300 Subject: [PATCH 09/10] fix: render cross-boundary transitions and restrict Mermaid compound workaround to parallel regions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Mermaid renderer had two issues: 1. Cross-scope transitions (e.g., an outer state targeting a history pseudo-state inside a compound) were silently dropped because `_render_scope_transitions` only rendered transitions where both endpoints were direct members of the same scope. Now the scope check expands to include descendants of compound states, while skipping transitions fully internal to a single compound (handled by the inner scope). 2. The compound→initial-child redirect (workaround for mermaid-js/mermaid#4052) was applied universally, but the bug only affects compound states inside parallel regions. Now the redirect is restricted to parallel descendants, leaving compound states outside parallel regions unchanged. Adds a ParallelCompoundSC showcase that exercises the Mermaid bug pattern (transition targeting a compound inside a parallel region), with Graphviz vs Mermaid comparison in the visual showcase docs. --- docs/diagram.md | 104 +++++++++++++-- .../contrib/diagram/renderers/mermaid.py | 123 ++++++++++++++---- tests/machines/showcase_parallel_compound.py | 34 +++++ tests/test_mermaid_renderer.py | 97 ++++++++++++-- 4 files changed, 314 insertions(+), 44 deletions(-) create mode 100644 tests/machines/showcase_parallel_compound.py diff --git a/docs/diagram.md b/docs/diagram.md index 17243490..ff3962df 100644 --- a/docs/diagram.md +++ b/docs/diagram.md @@ -545,9 +545,9 @@ dot.write_png("order_control_class.png") ## Visual showcase This section shows how each state machine feature is rendered in diagrams. -Each example includes the class definition, the **class** diagram (no -active state), and **instance** diagrams (with the current state -highlighted after sending events). +Each example includes the class definition, diagrams in both **Graphviz** +and **Mermaid** formats, and **instance** diagrams with the current state +highlighted after sending events. ### Simple states @@ -560,7 +560,12 @@ A minimal state machine with three atomic states and linear transitions. ``` ```{statemachine-diagram} tests.machines.showcase_simple.SimpleSC -:caption: Class +:caption: Class (Graphviz) +``` + +```{statemachine-diagram} tests.machines.showcase_simple.SimpleSC +:format: mermaid +:caption: Class (Mermaid) ``` ```{statemachine-diagram} tests.machines.showcase_simple.SimpleSC @@ -589,7 +594,12 @@ States can declare `entry` / `exit` callbacks, shown in the state label. ``` ```{statemachine-diagram} tests.machines.showcase_actions.ActionsSC -:caption: Class +:caption: Class (Graphviz) +``` + +```{statemachine-diagram} tests.machines.showcase_actions.ActionsSC +:format: mermaid +:caption: Class (Mermaid) ``` ```{statemachine-diagram} tests.machines.showcase_actions.ActionsSC @@ -608,7 +618,12 @@ Transitions can have `cond` guards, shown in brackets on the edge label. ``` ```{statemachine-diagram} tests.machines.showcase_guards.GuardSC -:caption: Class +:caption: Class (Graphviz) +``` + +```{statemachine-diagram} tests.machines.showcase_guards.GuardSC +:format: mermaid +:caption: Class (Mermaid) ``` ```{statemachine-diagram} tests.machines.showcase_guards.GuardSC @@ -627,7 +642,12 @@ A transition from a state back to itself. ``` ```{statemachine-diagram} tests.machines.showcase_self_transition.SelfTransitionSC -:caption: Class +:caption: Class (Graphviz) +``` + +```{statemachine-diagram} tests.machines.showcase_self_transition.SelfTransitionSC +:format: mermaid +:caption: Class (Mermaid) ``` ```{statemachine-diagram} tests.machines.showcase_self_transition.SelfTransitionSC @@ -646,7 +666,12 @@ Internal transitions execute actions without exiting/entering the state. ``` ```{statemachine-diagram} tests.machines.showcase_internal.InternalSC -:caption: Class +:caption: Class (Graphviz) +``` + +```{statemachine-diagram} tests.machines.showcase_internal.InternalSC +:format: mermaid +:caption: Class (Mermaid) ``` ```{statemachine-diagram} tests.machines.showcase_internal.InternalSC @@ -666,10 +691,15 @@ its initial child. ``` ```{statemachine-diagram} tests.machines.showcase_compound.CompoundSC -:caption: Class +:caption: Class (Graphviz) :target: ``` +```{statemachine-diagram} tests.machines.showcase_compound.CompoundSC +:format: mermaid +:caption: Class (Mermaid) +``` + ```{statemachine-diagram} tests.machines.showcase_compound.CompoundSC :events: :caption: Off @@ -699,10 +729,15 @@ A parallel state activates all its regions simultaneously. ``` ```{statemachine-diagram} tests.machines.showcase_parallel.ParallelSC -:caption: Class +:caption: Class (Graphviz) :target: ``` +```{statemachine-diagram} tests.machines.showcase_parallel.ParallelSC +:format: mermaid +:caption: Class (Mermaid) +``` + ```{statemachine-diagram} tests.machines.showcase_parallel.ParallelSC :events: enter :caption: Both active @@ -716,6 +751,41 @@ A parallel state activates all its regions simultaneously. ``` +### Parallel with cross-boundary transitions + +A transition targeting a compound state **inside** a parallel region triggers a +rendering bug in Mermaid (`mermaid-js/mermaid#4052`). The Mermaid renderer works +around this by redirecting the arrow to the compound's initial child — compare the +``rebuild`` arrow in both diagrams below. + +```{literalinclude} ../tests/machines/showcase_parallel_compound.py +:pyobject: ParallelCompoundSC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_parallel_compound.ParallelCompoundSC +:caption: Class (Graphviz) — ``rebuild`` points to the Build compound border +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_parallel_compound.ParallelCompoundSC +:format: mermaid +:caption: Class (Mermaid) — ``rebuild`` is redirected to Compile (initial child of Build) +``` + +```{statemachine-diagram} tests.machines.showcase_parallel_compound.ParallelCompoundSC +:events: start, do_build +:caption: Build done +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_parallel_compound.ParallelCompoundSC +:events: start, do_build, do_test +:caption: Pipeline done → Review +:target: +``` + + ### History states (shallow) A history pseudo-state remembers the last active child of a compound state. @@ -726,10 +796,15 @@ A history pseudo-state remembers the last active child of a compound state. ``` ```{statemachine-diagram} tests.machines.showcase_history.HistorySC -:caption: Class +:caption: Class (Graphviz) :target: ``` +```{statemachine-diagram} tests.machines.showcase_history.HistorySC +:format: mermaid +:caption: Class (Mermaid) +``` + ```{statemachine-diagram} tests.machines.showcase_history.HistorySC :events: begin, advance :caption: Step2 @@ -759,10 +834,15 @@ Deep history remembers the exact leaf state across nested compounds. ``` ```{statemachine-diagram} tests.machines.showcase_deep_history.DeepHistorySC -:caption: Class +:caption: Class (Graphviz) :target: ``` +```{statemachine-diagram} tests.machines.showcase_deep_history.DeepHistorySC +:format: mermaid +:caption: Class (Mermaid) +``` + ```{statemachine-diagram} tests.machines.showcase_deep_history.DeepHistorySC :events: dive, enter_inner, go :caption: Inner/B diff --git a/statemachine/contrib/diagram/renderers/mermaid.py b/statemachine/contrib/diagram/renderers/mermaid.py index b455b935..15ba61d5 100644 --- a/statemachine/contrib/diagram/renderers/mermaid.py +++ b/statemachine/contrib/diagram/renderers/mermaid.py @@ -24,17 +24,14 @@ class MermaidRendererConfig: class MermaidRenderer: """Renders a DiagramGraph into a Mermaid stateDiagram-v2 source string. - Mermaid's stateDiagram-v2 has a rendering bug where transitions whose source - or target is a compound state (``state X { ... }``) inside a parallel region - crash with ``Cannot set properties of undefined (setting 'rank')``. To work - around this, the renderer rewrites compound-state endpoints to cross the - boundary: - - - Transition **to** a compound → redirected to its initial child. - - Transition **from** a compound → redirected from its initial child. - - This is applied universally (not only inside parallel regions) for simplicity - and consistency — the visual effect is equivalent. + Mermaid's stateDiagram-v2 has a rendering bug + (`mermaid-js/mermaid#4052 `_) + where transitions whose source or target is a compound state + (``state X { ... }``) **inside a parallel region** crash with + ``Cannot set properties of undefined (setting 'rank')``. To work around + this, the renderer rewrites compound-state endpoints that are descendants + of a parallel state, redirecting them to the compound's initial child. + Compound states outside parallel regions are left unchanged. """ def __init__(self, config: Optional[MermaidRendererConfig] = None): @@ -43,6 +40,8 @@ def __init__(self, config: Optional[MermaidRendererConfig] = None): self._rendered_transitions: Set[tuple] = set() self._compound_ids: Set[str] = set() self._initial_child_map: Dict[str, str] = {} + self._parallel_descendant_ids: Set[str] = set() + self._all_descendants_map: Dict[str, Set[str]] = {} def render(self, graph: DiagramGraph) -> str: """Render a DiagramGraph to a Mermaid stateDiagram-v2 string.""" @@ -50,6 +49,8 @@ def render(self, graph: DiagramGraph) -> str: self._rendered_transitions = set() self._compound_ids = graph.compound_state_ids self._initial_child_map = self._build_initial_child_map(graph.states) + self._parallel_descendant_ids = self._collect_parallel_descendants(graph.states) + self._all_descendants_map = self._build_all_descendants_map(graph.states) lines: List[str] = [] lines.append("stateDiagram-v2") @@ -80,9 +81,52 @@ def _build_initial_child_map(self, states: List[DiagramState]) -> Dict[str, str] result.update(self._build_initial_child_map(state.children)) return result + @staticmethod + def _collect_parallel_descendants( + states: List[DiagramState], + inside_parallel: bool = False, + ) -> Set[str]: + """Collect IDs of all states that are descendants of a parallel state.""" + result: Set[str] = set() + for state in states: + if inside_parallel: + result.add(state.id) + child_inside = inside_parallel or state.type == StateType.PARALLEL + result.update( + MermaidRenderer._collect_parallel_descendants(state.children, child_inside) + ) + return result + + def _build_all_descendants_map(self, states: List[DiagramState]) -> Dict[str, Set[str]]: + """Map each compound state ID to the set of all its descendant IDs.""" + result: Dict[str, Set[str]] = {} + for state in states: + if state.children: + result[state.id] = self._collect_recursive_descendants(state.children) + result.update(self._build_all_descendants_map(state.children)) + return result + + @staticmethod + def _collect_recursive_descendants(states: List[DiagramState]) -> Set[str]: + """Collect all state IDs in a subtree recursively.""" + ids: Set[str] = set() + for s in states: + ids.add(s.id) + ids.update(MermaidRenderer._collect_recursive_descendants(s.children)) + return ids + def _resolve_endpoint(self, state_id: str) -> str: - """Resolve a transition endpoint, redirecting compound states to their initial child.""" - if state_id in self._compound_ids and state_id in self._initial_child_map: + """Resolve a transition endpoint for Mermaid compatibility. + + Only redirects compound states that are inside a parallel region — + this is where Mermaid's rendering bug (mermaid-js/mermaid#4052) occurs. + Compound states outside parallel regions are left unchanged. + """ + if ( + state_id in self._compound_ids + and state_id in self._parallel_descendant_ids + and state_id in self._initial_child_map + ): return self._initial_child_map[state_id] return state_id @@ -197,25 +241,44 @@ def _render_scope_transitions( lines: List[str], indent: int, ) -> None: - """Render transitions where both source and all targets are in scope_ids. - - Mermaid does not support transitions where the source or target is a - compound state rendered with ``state X { ... }`` inside a parallel region. - To work around this, endpoints that reference compound states are - redirected to the compound's initial child. Scope membership is checked - on the **original** IDs (which belong to this scope level), while the - rendered arrow uses the **resolved** (possibly redirected) IDs. + """Render transitions that belong to this scope level. + + A transition belongs to scope S if all its endpoints are *reachable* + from S (either directly in S or descendants of a compound in S) **and** + the transition is not fully internal to a single compound in S (those + are rendered by the compound's inner scope). + + This allows cross-boundary transitions (e.g., an outer state targeting + a history pseudo-state inside a compound) to be rendered at the correct + scope level — Mermaid draws the arrow crossing the compound border. + + Mermaid crashes when the source or target is a compound state inside a + parallel region (mermaid-js/mermaid#4052). For those cases, endpoints + are redirected to the compound's initial child via ``_resolve_endpoint``. """ + # Build the descendant sets for compounds in this scope + compound_descendants: Dict[str, Set[str]] = {} + expanded: Set[str] = set(scope_ids) + for sid in scope_ids: + if sid in self._all_descendants_map: + compound_descendants[sid] = self._all_descendants_map[sid] + expanded |= self._all_descendants_map[sid] + for t in transitions: if t.is_initial or t.is_internal: continue targets = t.targets if t.targets else [t.source] - # Check scope membership with original IDs - if t.source not in scope_ids: + # All endpoints must be reachable from this scope + if t.source not in expanded: continue - if not all(target in scope_ids for target in targets): + if not all(target in expanded for target in targets): + continue + + # Skip transitions fully internal to a single compound — + # those will be rendered by the compound's inner scope. + if self._is_fully_internal(t.source, targets, compound_descendants): continue # Resolve endpoints for rendering (redirect compound → initial child) @@ -229,6 +292,18 @@ def _render_scope_transitions( self._rendered_transitions.add(key) self._render_single_transition(t, source, target, lines, indent) + @staticmethod + def _is_fully_internal( + source: str, + targets: List[str], + compound_descendants: Dict[str, Set[str]], + ) -> bool: + """Check if all endpoints belong to the same compound's descendants.""" + for descendants in compound_descendants.values(): + if source in descendants and all(tgt in descendants for tgt in targets): + return True + return False + def _render_single_transition( self, transition: DiagramTransition, diff --git a/tests/machines/showcase_parallel_compound.py b/tests/machines/showcase_parallel_compound.py new file mode 100644 index 00000000..049def76 --- /dev/null +++ b/tests/machines/showcase_parallel_compound.py @@ -0,0 +1,34 @@ +from statemachine import State +from statemachine import StateChart + + +class ParallelCompoundSC(StateChart): + """Parallel regions with a cross-boundary transition into an inner compound. + + The ``rebuild`` transition targets ``pipeline.build`` — a compound state + inside a parallel region. This is the exact pattern that triggers + `mermaid-js/mermaid#4052 `_; + the Mermaid renderer works around it by redirecting the arrow to the + compound's initial child. + + {statechart:rst} + """ + + class pipeline(State.Parallel, name="Pipeline"): + class build(State.Compound, name="Build"): + compile = State(initial=True) + link = State(final=True) + do_build = compile.to(link) + + class test(State.Compound, name="Test"): + unit = State(initial=True) + e2e = State(final=True) + do_test = unit.to(e2e) + + idle = State(initial=True) + review = State() + + start = idle.to(pipeline) + done_state_pipeline = pipeline.to(review) + rebuild = review.to(pipeline.build) + accept = review.to(idle) diff --git a/tests/test_mermaid_renderer.py b/tests/test_mermaid_renderer.py index d4042cbc..32cf518b 100644 --- a/tests/test_mermaid_renderer.py +++ b/tests/test_mermaid_renderer.py @@ -249,9 +249,8 @@ class parent(State.Compound, name="Parent"): assert "[*] --> child1" in result assert "child1 --> child2 : go" in result assert "child2 --> [*]" in result - # Compound endpoints are redirected to the initial child (Mermaid workaround) - assert "start --> child1 : enter" in result - assert "child1 --> end : finish" in result + assert "start --> parent : enter" in result + assert "parent --> end : finish" in result def test_compound_no_duplicate_transitions(self): """Transitions inside compound states must not also appear at top level.""" @@ -289,6 +288,47 @@ class r2(State.Compound, name="Region2"): assert 'state "Parallel" as p {' in result assert "--" in result # parallel separator + def test_parallel_redirects_compound_endpoints(self): + """Transitions to/from compound states inside parallel regions are redirected + to the initial child (Mermaid workaround for mermaid-js/mermaid#4052).""" + + class SM(StateChart): + class p(State.Parallel, name="Parallel"): + class region1(State.Compound, name="Region1"): + idle = State(initial=True) + + class inner(State.Compound, name="Inner"): + working = State(initial=True) + + start = idle.to(inner) + + class region2(State.Compound, name="Region2"): + x = State(initial=True) + + begin = State(initial=True) + enter = begin.to(p) + + result = MermaidGraphMachine(SM).get_mermaid() + # Inside parallel: compound endpoint redirected to initial child + assert "idle --> working : start" in result + assert "idle --> inner" not in result + + def test_compound_outside_parallel_not_redirected(self): + """Compound states outside parallel regions keep direct transitions.""" + + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child = State(initial=True) + + start = State(initial=True) + end = State(final=True) + enter = start.to(parent) + leave = parent.to(end) + + result = MermaidGraphMachine(SM).get_mermaid() + assert "start --> parent : enter" in result + assert "parent --> end : leave" in result + def test_nested_compound(self): class SM(StateChart): class outer(State.Compound, name="Outer"): @@ -523,8 +563,8 @@ def test_active_compound_state(self): result = MermaidRenderer().render(graph) assert "comp:::active" in result - def test_cross_scope_transition_not_in_compound(self): - """Transition crossing compound boundaries is not rendered inside the compound.""" + def test_cross_scope_transition_rendered_at_parent(self): + """Transition crossing compound boundaries is rendered at the parent scope.""" graph = DiagramGraph( name="CrossScope", states=[ @@ -546,9 +586,50 @@ def test_cross_scope_transition_not_in_compound(self): ) result = MermaidRenderer().render(graph) # c1 is inside comp, outside is at top level — the transition - # can't be rendered at either scope since source/target span scopes. - # This is expected: Mermaid doesn't support cross-scope transitions natively. - assert "c1 --> outside" not in result + # crosses the compound boundary and is rendered at the top scope. + assert "c1 --> outside : leave" in result + # It should NOT appear inside the compound block + lines = result.split("\n") + for line in lines: + if "c1 --> outside" in line: + # Should be at indent level 1 (top scope), not deeper + assert line.startswith(" c1"), f"Expected top-level indent, got: {line!r}" + + def test_cross_scope_to_history_state(self): + """Transition from outside a compound to a history state inside it is rendered.""" + graph = DiagramGraph( + name="HistoryCross", + states=[ + DiagramState( + id="process", + name="Process", + type=StateType.REGULAR, + children=[ + DiagramState( + id="step1", name="Step1", type=StateType.REGULAR, is_initial=True + ), + DiagramState(id="step2", name="Step2", type=StateType.REGULAR), + DiagramState(id="h", name="H", type=StateType.HISTORY_SHALLOW), + ], + ), + DiagramState(id="paused", name="Paused", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="step1", targets=["step2"], event="advance"), + DiagramTransition(source="process", targets=["paused"], event="pause"), + DiagramTransition(source="paused", targets=["h"], event="resume"), + DiagramTransition(source="paused", targets=["process"], event="begin"), + ], + compound_state_ids={"process"}, + ) + result = MermaidRenderer().render(graph) + # The resume transition crosses the compound boundary + assert "paused --> h : resume" in result + # advance stays inside the compound + assert "step1 --> step2 : advance" in result + # pause and begin are at top level (both endpoints are top-level) + assert "process --> paused : pause" in result + assert "paused --> process : begin" in result def test_no_initial_state(self): """Graph with no initial state omits [*] arrow.""" From fc082f2abaf131801f43a8f46fa75ce6bee13bb2 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Sun, 8 Mar 2026 17:13:48 -0300 Subject: [PATCH 10/10] test: cover missing branches in sphinx_ext.py and extract.py Add tests for the :name: directive option on Mermaid format (with and without caption). Mark the defensive dedup guard in _format_event_names as pragma: no branch since Events already deduplicates at the container level. --- statemachine/contrib/diagram/extract.py | 2 +- tests/test_contrib_diagram.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/statemachine/contrib/diagram/extract.py b/statemachine/contrib/diagram/extract.py index 29002e16..15a1f2da 100644 --- a/statemachine/contrib/diagram/extract.py +++ b/statemachine/contrib/diagram/extract.py @@ -123,7 +123,7 @@ def _format_event_names(transition: "Transition") -> str: # form ("done_invoke_X") is also registered on this transition. if "." in eid and eid.replace(".", "_") in all_ids: continue - if eid not in display: + if eid not in display: # pragma: no branch display.append(eid) return " ".join(display) diff --git a/tests/test_contrib_diagram.py b/tests/test_contrib_diagram.py index 14575faa..2da3cdd8 100644 --- a/tests/test_contrib_diagram.py +++ b/tests/test_contrib_diagram.py @@ -1737,6 +1737,22 @@ def test_mermaid_format_with_caption(self, tmp_path): assert len(caption_children) == 1 assert caption_children[0].astext() == "My Diagram" + def test_mermaid_format_with_caption_and_name(self, tmp_path): + """Mermaid format with caption and :name: calls add_name on the figure.""" + _, result = self._run( + tmp_path, options={"format": "mermaid", "caption": "My Diagram", "name": "fig-sm"} + ) + assert len(result) == 1 + assert isinstance(result[0], nodes.figure) + + def test_mermaid_format_with_name_no_caption(self, tmp_path): + """Mermaid format with :name: but no caption calls add_name on the mermaid node.""" + from sphinxcontrib.mermaid import mermaid as MermaidNode # type: ignore[import-untyped] + + _, result = self._run(tmp_path, options={"format": "mermaid", "name": "fig-sm"}) + assert len(result) == 1 + assert isinstance(result[0], MermaidNode) + def test_mermaid_format_fallback_no_sphinxcontrib(self, tmp_path): """When sphinxcontrib-mermaid is not available, falls back to code block.""" import sys