diff --git a/src/strands_compose/__init__.py b/src/strands_compose/__init__.py index 43c835c..2dcbae7 100644 --- a/src/strands_compose/__init__.py +++ b/src/strands_compose/__init__.py @@ -26,6 +26,7 @@ from .tools import ( node_as_async_tool, node_as_tool, + serialize_multiagent_result, ) from .types import EventType, StreamEvent from .utils import cli_errors @@ -61,4 +62,5 @@ "node_as_async_tool", "node_as_tool", "resolve_infra", + "serialize_multiagent_result", ] diff --git a/src/strands_compose/cli.py b/src/strands_compose/cli.py index 1d60a62..1a6a45e 100644 --- a/src/strands_compose/cli.py +++ b/src/strands_compose/cli.py @@ -118,7 +118,7 @@ def _render_check_success_ansi(app_config: AppConfig) -> None: # Collect rows as (label, value) pairs, then align on the colon. rows: list[tuple[str, str]] = [ - ("entry", str(app_config.entry)), + ("entry", app_config.entry), ("agents", agent_str), ] if app_config.models: diff --git a/src/strands_compose/tools/__init__.py b/src/strands_compose/tools/__init__.py index da3d51e..33bd0ee 100644 --- a/src/strands_compose/tools/__init__.py +++ b/src/strands_compose/tools/__init__.py @@ -4,10 +4,12 @@ - Loading ``@tool``-decorated functions from files, modules, and directories. - Wrapping ``Agent`` / ``MultiAgentBase`` nodes as ``AgentTool`` instances (``node_as_tool``, ``node_as_async_tool``) for delegation. +- Serializing multi-agent results with full execution metadata. """ from __future__ import annotations +from .extractors import serialize_multiagent_result from .loaders import ( load_tool_function, load_tools_from_directory, @@ -30,4 +32,5 @@ "node_as_tool", "resolve_tool_spec", "resolve_tool_specs", + "serialize_multiagent_result", ] diff --git a/src/strands_compose/tools/extractors.py b/src/strands_compose/tools/extractors.py index 341c51d..d1a7e68 100644 --- a/src/strands_compose/tools/extractors.py +++ b/src/strands_compose/tools/extractors.py @@ -1,11 +1,4 @@ -"""Message extraction utilities for agent and multi-agent results. - -Key Features: - - Extract the last message from strands Agent and MultiAgent results - - Extract text from messages when a string-only fallback is needed - - Support for SwarmResult and GraphResult node resolution - - Recursive extraction through nested orchestration results -""" +"""Message extraction and serialization utilities for agent and multi-agent results.""" from __future__ import annotations @@ -24,73 +17,43 @@ def _message_from_text(text: str) -> Message: return {"role": "assistant", "content": [{"text": text}]} -def _extract_last_message_from_multi_agent_result(result: MultiAgentResult) -> Message: - """Extract the final message from a ``MultiAgentResult``.""" - last_node_id = resolve_last_node_id(result) - - if last_node_id and last_node_id in result.results: - message = extract_last_message(result.results[last_node_id]) - if message is not None: - return message - - for node_result in reversed(list(result.results.values())): - message = extract_last_message(node_result) - if message is not None: - return message - - logger.warning("status=<%s> | no message extracted from MultiAgentResult", result.status) - return _message_from_text( - f"[orchestration completed with status {result.status.value} but produced no message output]" - ) - - -def extract_text_from_message(message: Message | None) -> str | None: - """Extract the last text block from a message. - - Strands ``ContentBlock`` uses ``{"text": "..."}`` for text blocks (no - ``"type"`` wrapper). This helper scans content blocks in reverse and - returns the last text block. Use it only when a caller explicitly needs - plain text; ``extract_last_message`` preserves the complete message. - - Args: - message: A strands ``Message`` dict (e.g. ``AgentResult.message``). - - Returns: - The last text string, or ``None`` if no text blocks exist. - """ +def extract_text(message: Message | None) -> str: + """Return the last text block from a message, or an empty string.""" if not message: - return None - content = message.get("content", []) - for block in reversed(content): + return "" + for block in reversed(message.get("content", [])): if isinstance(block, dict) and "text" in block: return block["text"] - return None + return "" def extract_last_message(result: Any) -> Message: """Extract the final message from an agent, orchestration, or node result. - Dispatches to the appropriate extractor based on the result type: - - ``AgentResult`` returns ``result.message`` directly. - - ``MultiAgentResult`` drills into the last executing node's message. - - ``NodeResult`` unwraps the inner payload and dispatches recursively. - - Unknown types fall back to an assistant text message containing - ``str(result)``. - Args: result: An ``AgentResult``, ``MultiAgentResult``, ``NodeResult``, or any object. Returns: - The extracted ``Message``. This can be wrapped in a one-item list and - passed to ``Agent.invoke_async`` as ``Messages`` when richer content - such as images or documents must be preserved. + The extracted ``Message``. """ if isinstance(result, AgentResult): return result.message if isinstance(result, MultiAgentResult): - return _extract_last_message_from_multi_agent_result(result) + last_node_id = resolve_last_node_id(result) + if last_node_id and last_node_id in result.results: + message = extract_last_message(result.results[last_node_id]) + if message is not None: + return message + for node_result in reversed(list(result.results.values())): + message = extract_last_message(node_result) + if message is not None: + return message + logger.warning("status=<%s> | no message extracted from MultiAgentResult", result.status) + return _message_from_text( + f"[orchestration completed with status {result.status.value} but produced no message output]" + ) if isinstance(result, NodeResult): inner = result.result @@ -125,3 +88,67 @@ def resolve_last_node_id(result: MultiAgentResult) -> str | None: return str(execution_order[-1].node_id) return None + + +def serialize_multiagent_result(result: MultiAgentResult) -> dict[str, Any]: + """Serialize a ``MultiAgentResult`` with execution metadata omitted by ``to_dict()``. + + Extends ``result.to_dict()`` with fields only available on the live object: + + - ``last_node_id`` — id of the truly last executing node, derived from + ``node_history`` / ``execution_order`` (not dict insertion order). + - ``response`` — plain-text answer from that node, ready to use + directly without any further extraction. + - ``swarm.node_history`` — full ordered execution trace including repeated + visits (``SwarmResult`` only). + - ``graph.execution_order``, ``graph.edges``, ``graph.entry_points``, and + node counts (``GraphResult`` only). + + Args: + result: A live ``MultiAgentResult``, ``SwarmResult``, or ``GraphResult`` + returned directly by ``invoke_async``. + + Returns: + A JSON-serializable dict extending ``result.to_dict()``. + """ + data = result.to_dict() + + last_node_id = resolve_last_node_id(result) + data["last_node_id"] = last_node_id + + final_message = extract_last_message(result) + data["response"] = extract_text(final_message) + + # SwarmResult extras — node_history is a list[SwarmNode] + node_history: list[Any] | None = getattr(result, "node_history", None) + if node_history is not None: + data["swarm"] = { + "node_history": [str(n.node_id) for n in node_history], + } + + # GraphResult extras — execution_order, edges, node counts + execution_order: list[Any] | None = getattr(result, "execution_order", None) + if execution_order is not None: + edges_raw: list[Any] = getattr(result, "edges", []) or [] + entry_points_raw: list[Any] = getattr(result, "entry_points", []) or [] + + edges: list[list[str]] = [] + for edge in edges_raw: + if isinstance(edge, tuple) and len(edge) == 2: + edges.append([str(edge[0].node_id), str(edge[1].node_id)]) + else: + # GraphEdge dataclass with from_node / to_node attributes + from_id = str(getattr(getattr(edge, "from_node", None), "node_id", edge)) + to_id = str(getattr(getattr(edge, "to_node", None), "node_id", edge)) + edges.append([from_id, to_id]) + + data["graph"] = { + "execution_order": [str(n.node_id) for n in execution_order], + "edges": edges, + "entry_points": [str(getattr(ep, "node_id", ep)) for ep in entry_points_raw], + "completed_nodes": getattr(result, "completed_nodes", 0), + "failed_nodes": getattr(result, "failed_nodes", 0), + "interrupted_nodes": getattr(result, "interrupted_nodes", 0), + } + + return data diff --git a/src/strands_compose/tools/wrappers.py b/src/strands_compose/tools/wrappers.py index dbee52b..4a2e6bd 100644 --- a/src/strands_compose/tools/wrappers.py +++ b/src/strands_compose/tools/wrappers.py @@ -17,7 +17,7 @@ from strands.tools.decorator import DecoratedFunctionTool, tool from strands.types.content import Message -from .extractors import extract_last_message, extract_text_from_message +from .extractors import extract_last_message, extract_text if TYPE_CHECKING: from ..types import Node @@ -75,7 +75,7 @@ def _message_to_tool_result(message: Message) -> dict[str, Any]: if content: return {"status": "success", "content": content} - return {"status": "success", "content": [{"text": extract_text_from_message(message) or ""}]} + return {"status": "success", "content": [{"text": extract_text(message)}]} def node_as_tool( diff --git a/tests/unit/config/resolvers/orchestrations/test_tools.py b/tests/unit/config/resolvers/orchestrations/test_tools.py index 276ba00..7a11c3c 100644 --- a/tests/unit/config/resolvers/orchestrations/test_tools.py +++ b/tests/unit/config/resolvers/orchestrations/test_tools.py @@ -17,10 +17,11 @@ from strands_compose.tools import ( node_as_async_tool, node_as_tool, + serialize_multiagent_result, ) from strands_compose.tools.extractors import ( extract_last_message, - extract_text_from_message, + extract_text, resolve_last_node_id, ) @@ -112,35 +113,35 @@ def _fake_graph_nodes(*names: str) -> list[Any]: # =========================================================================== -# extract_text_from_message +# extract_text # =========================================================================== -class TestExtractTextFromMessage: - """Unit tests for extract_text_from_message.""" +class TestExtractText: + """Unit tests for extract_text.""" def test_returns_last_text_block(self) -> None: """Multiple text blocks in content returns the last one.""" msg = _msg([_text_block("first"), _text_block("second")]) - assert extract_text_from_message(msg) == "second" + assert extract_text(msg) == "second" - def test_returns_none_for_no_text_blocks(self) -> None: - """Content with only toolUse blocks returns None.""" + def test_returns_empty_string_for_no_text_blocks(self) -> None: + """Content with only toolUse blocks returns an empty string.""" msg = _msg([_tool_use_block()]) - assert extract_text_from_message(msg) is None + assert extract_text(msg) == "" - def test_returns_none_for_empty_content(self) -> None: - """Empty content list returns None.""" - assert extract_text_from_message(_msg([])) is None + def test_returns_empty_string_for_empty_content(self) -> None: + """Empty content list returns an empty string.""" + assert extract_text(_msg([])) == "" - def test_returns_none_for_none_message(self) -> None: - """None message returns None.""" - assert extract_text_from_message(None) is None + def test_returns_empty_string_for_none_message(self) -> None: + """None message returns an empty string.""" + assert extract_text(None) == "" def test_skips_non_dict_blocks(self) -> None: """Non-dict items in content are safely skipped.""" msg = cast(Message, {"role": "assistant", "content": ["raw string", _text_block("ok")]}) - assert extract_text_from_message(msg) == "ok" + assert extract_text(msg) == "ok" # =========================================================================== @@ -244,7 +245,7 @@ def test_empty_node_history_falls_back_to_reverse_scan(self) -> None: def test_empty_results_returns_descriptive_fallback(self) -> None: """SwarmResult with no node results returns a descriptive text message.""" swarm_result = SwarmResult(status=Status.COMPLETED, results={}, node_history=[]) - text = extract_text_from_message(extract_last_message(swarm_result)) + text = extract_text(extract_last_message(swarm_result)) assert text is not None and "no message output" in text def test_last_node_not_in_results_falls_back_to_reverse_scan(self) -> None: @@ -318,7 +319,7 @@ def test_exception_result_returns_error_message(self) -> None: """NodeResult wrapping an Exception returns a descriptive error string.""" node_result = _node_result(RuntimeError("something broke")) message = extract_last_message(node_result) - text = extract_text_from_message(message) + text = extract_text(message) assert text is not None assert "something broke" in text @@ -631,3 +632,201 @@ async def fake_invoke_async(query: str) -> GraphResult: tool = node_as_async_tool(multi, name="graph_orch", description="Graph") assert await tool("q") == _tool_result([_text_block("graph async final")]) + + +# =========================================================================== +# serialize_multiagent_result +# =========================================================================== + + +@dataclass +class _FakeGraphEdgeObj: + """Edge represented as a GraphEdge-like object with from_node / to_node attrs.""" + + from_node: _FakeGraphNode + to_node: _FakeGraphNode + + +class TestSerializeMultiagentResult: + """Unit tests for serialize_multiagent_result.""" + + # -- SwarmResult --------------------------------------------------------- + + def test_swarm_includes_last_node_id(self) -> None: + """last_node_id is the final entry in node_history.""" + result = SwarmResult( + status=Status.COMPLETED, + results={"a": _node_result(_agent_result_with_text("a text"))}, + node_history=_fake_swarm_nodes("a"), + ) + data = serialize_multiagent_result(result) + assert data["last_node_id"] == "a" + + def test_swarm_includes_response_text(self) -> None: + """response is the plain-text answer from the last node.""" + result = SwarmResult( + status=Status.COMPLETED, + results={"lead": _node_result(_agent_result_with_text("approved"))}, + node_history=_fake_swarm_nodes("lead"), + ) + data = serialize_multiagent_result(result) + assert data["response"] == "approved" + + def test_swarm_node_history_preserves_order_and_repeats(self) -> None: + """swarm.node_history captures execution order including repeated visits.""" + result = SwarmResult( + status=Status.COMPLETED, + results={ + "drafter": _node_result(_agent_result_with_text("draft")), + "reviewer": _node_result(_agent_result_with_text("review")), + "lead": _node_result(_agent_result_with_text("final")), + }, + node_history=_fake_swarm_nodes("drafter", "lead", "reviewer", "lead", "lead"), + ) + data = serialize_multiagent_result(result) + assert data["swarm"]["node_history"] == ["drafter", "lead", "reviewer", "lead", "lead"] + + def test_swarm_last_node_id_from_history_not_dict_order(self) -> None: + """last_node_id uses node_history, not results dict insertion order.""" + # results dict insertion order ends with "reviewer", but last in history is "lead" + result = SwarmResult( + status=Status.COMPLETED, + results={ + "drafter": _node_result(_agent_result_with_text("draft")), + "lead": _node_result(_agent_result_with_text("APPROVED")), + "reviewer": _node_result(_agent_result_with_text("looks good")), + }, + node_history=_fake_swarm_nodes("drafter", "lead", "reviewer", "lead"), + ) + data = serialize_multiagent_result(result) + assert data["last_node_id"] == "lead" + assert data["response"] == "APPROVED" + + def test_swarm_no_graph_section(self) -> None: + """SwarmResult serialization does not produce a graph section.""" + result = SwarmResult( + status=Status.COMPLETED, + results={"a": _node_result(_agent_result_with_text("x"))}, + node_history=_fake_swarm_nodes("a"), + ) + data = serialize_multiagent_result(result) + assert "graph" not in data + + def test_swarm_includes_base_to_dict_fields(self) -> None: + """Output includes all standard MultiAgentResult.to_dict() fields.""" + result = SwarmResult( + status=Status.COMPLETED, + results={"a": _node_result(_agent_result_with_text("x"))}, + node_history=_fake_swarm_nodes("a"), + ) + data = serialize_multiagent_result(result) + for key in ("type", "status", "results", "execution_count", "execution_time"): + assert key in data + + # -- GraphResult --------------------------------------------------------- + + def test_graph_includes_last_node_id(self) -> None: + """last_node_id is the final entry in execution_order.""" + result = GraphResult( + status=Status.COMPLETED, + results={"writer": _node_result(_agent_result_with_text("written"))}, + execution_order=_fake_graph_nodes("fetcher", "writer"), + ) + data = serialize_multiagent_result(result) + assert data["last_node_id"] == "writer" + + def test_graph_includes_response_text(self) -> None: + """response is the plain-text answer from the last execution_order node.""" + result = GraphResult( + status=Status.COMPLETED, + results={"writer": _node_result(_agent_result_with_text("final output"))}, + execution_order=_fake_graph_nodes("fetcher", "writer"), + ) + data = serialize_multiagent_result(result) + assert data["response"] == "final output" + + def test_graph_execution_order_preserved(self) -> None: + """graph.execution_order lists node ids in execution sequence.""" + result = GraphResult( + status=Status.COMPLETED, + results={"c": _node_result(_agent_result_with_text("c"))}, + execution_order=_fake_graph_nodes("a", "b", "c"), + ) + data = serialize_multiagent_result(result) + assert data["graph"]["execution_order"] == ["a", "b", "c"] + + def test_graph_edges_as_tuples(self) -> None: + """graph.edges serializes tuple-based edges as [from, to] pairs.""" + n1, n2 = _FakeGraphNode("n1"), _FakeGraphNode("n2") + result = GraphResult( + status=Status.COMPLETED, + results={"n2": _node_result(_agent_result_with_text("out"))}, + execution_order=cast(Any, [n1, n2]), + edges=cast(Any, [(n1, n2)]), + ) + data = serialize_multiagent_result(result) + assert data["graph"]["edges"] == [["n1", "n2"]] + + def test_graph_edges_as_objects(self) -> None: + """graph.edges serializes GraphEdge-like objects via from_node/to_node.""" + n1, n2 = _FakeGraphNode("src"), _FakeGraphNode("dst") + edge = _FakeGraphEdgeObj(from_node=n1, to_node=n2) + result = GraphResult( + status=Status.COMPLETED, + results={"dst": _node_result(_agent_result_with_text("done"))}, + execution_order=cast(Any, [n1, n2]), + edges=cast(Any, [edge]), + ) + data = serialize_multiagent_result(result) + assert data["graph"]["edges"] == [["src", "dst"]] + + def test_graph_entry_points(self) -> None: + """graph.entry_points lists entry node ids.""" + entry = _FakeGraphNode("start") + result = GraphResult( + status=Status.COMPLETED, + results={"start": _node_result(_agent_result_with_text("go"))}, + execution_order=cast(Any, [entry]), + entry_points=cast(Any, [entry]), + ) + data = serialize_multiagent_result(result) + assert data["graph"]["entry_points"] == ["start"] + + def test_graph_node_counts(self) -> None: + """graph section includes completed, failed, and interrupted node counts.""" + result = GraphResult( + status=Status.COMPLETED, + results={"a": _node_result(_agent_result_with_text("x"))}, + execution_order=_fake_graph_nodes("a"), + completed_nodes=3, + failed_nodes=1, + interrupted_nodes=0, + ) + data = serialize_multiagent_result(result) + assert data["graph"]["completed_nodes"] == 3 + assert data["graph"]["failed_nodes"] == 1 + assert data["graph"]["interrupted_nodes"] == 0 + + def test_graph_no_swarm_section(self) -> None: + """GraphResult serialization does not produce a swarm section.""" + result = GraphResult( + status=Status.COMPLETED, + results={"a": _node_result(_agent_result_with_text("x"))}, + execution_order=_fake_graph_nodes("a"), + ) + data = serialize_multiagent_result(result) + assert "swarm" not in data + + # -- Base MultiAgentResult ----------------------------------------------- + + def test_base_result_no_swarm_or_graph_section(self) -> None: + """Plain MultiAgentResult produces neither swarm nor graph section.""" + result = MultiAgentResult( + status=Status.COMPLETED, + results={"a": _node_result(_agent_result_with_text("plain"))}, + ) + data = serialize_multiagent_result(result) + assert "swarm" not in data + assert "graph" not in data + assert data["last_node_id"] is None + assert data["response"] == "plain" diff --git a/uv.lock b/uv.lock index 5fb4db7..25476fa 100644 --- a/uv.lock +++ b/uv.lock @@ -1802,7 +1802,7 @@ openai = [ [[package]] name = "strands-compose" -version = "0.7.0" +version = "0.8.0" source = { editable = "." } dependencies = [ { name = "mcp" },