From 75cce10001401fae4ddb81207871b2597a35c45d Mon Sep 17 00:00:00 2001 From: Changyi Yang Date: Thu, 18 Jun 2026 07:50:10 +0000 Subject: [PATCH 1/5] feat(gateway): prefix trie data structures + MessageKey (M1) Add a per-session prefix trie to replace the single linear active trajectory (issue #51): every committed assistant turn becomes a node that may carry a checkpoint, and an incoming request longest-prefix matches any path and continues from the nearest checkpoint. - session/trie.py: MessageKey (+canonicalize_content/make_message_key), BranchCheckpoint, TrieNode, PrefixTrie (prepare/commit/finalize-traversal). MessageKey excludes reasoning_content (M1) and hashes multimodal content via a stable image/video digest (data:/long strings are hashed for compactness). - session/types.py: home TrajectoryBuffer here so trie stays free of the verl-importing codec; re-exported via session/__init__. - CPU tests cover the RFC appendix fork types (sequential / system-split / condensation / best-of-N + idempotent retry / warm-start) plus pending-node, clone-isolation, export, multimodal-digest, and regression guards that fail on the pre-review behavior. Co-Authored-By: Claude Opus 4.8 --- tests/uni_agent/gateway/test_trie_on_cpu.py | 395 +++++++++++++++ uni_agent/gateway/session/__init__.py | 4 +- uni_agent/gateway/session/trie.py | 501 ++++++++++++++++++++ uni_agent/gateway/session/types.py | 22 + 4 files changed, 920 insertions(+), 2 deletions(-) create mode 100755 tests/uni_agent/gateway/test_trie_on_cpu.py create mode 100755 uni_agent/gateway/session/trie.py diff --git a/tests/uni_agent/gateway/test_trie_on_cpu.py b/tests/uni_agent/gateway/test_trie_on_cpu.py new file mode 100755 index 00000000..9e48a9ed --- /dev/null +++ b/tests/uni_agent/gateway/test_trie_on_cpu.py @@ -0,0 +1,395 @@ +"""CPU-only unit tests for the gateway prefix trie (issue #51, M1). + +These tests exercise the trie's structural behavior directly with synthetic +token buffers — no tokenizer, model, or Ray actor — so they run on CPU. They +cover the fork types from the RFC appendix: + +- A. sequential extension (single chain, prefix reuse) +- B. system-keyed split (parallel sub-agents) +- C. context condensation (sibling branch off a fork point) +- D. best-of-N / idempotent retry +- E. warm-start (seeded no-checkpoint nodes) + +plus the core mechanics: pending nodes, checkpoint cloning, MessageKey +canonicalization, and multimodal digest keying. +""" + +from __future__ import annotations + +import pytest + +from uni_agent.gateway.session.trie import ( + PrefixTrie, + canonicalize_content, + make_message_key, +) +from uni_agent.gateway.session.types import TrajectoryBuffer + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def sys_msg(text="you are a coding agent"): + return {"role": "system", "content": text} + + +def user_msg(text): + return {"role": "user", "content": text} + + +def asst_msg(text, tool_calls=None): + msg = {"role": "assistant", "content": text} + if tool_calls is not None: + msg["tool_calls"] = tool_calls + return msg + + +def tool_msg(text, tool_call_id="call0"): + return {"role": "tool", "content": text, "tool_call_id": tool_call_id} + + +def make_buffer(prompt_len, response_tokens, mask_value=1): + """A synthetic buffer standing in for a real tokenization.""" + return TrajectoryBuffer( + prompt_ids=list(range(prompt_len)), + response_ids=list(response_tokens), + response_mask=[mask_value] * len(response_tokens), + response_logprobs=[0.0] * len(response_tokens), + ) + + +def commit_turn(trie, messages, assistant_msg, buffer, **kw): + """Run one prepare -> commit cycle and return (prepare_result, node).""" + pr = trie.prepare(messages) + node = trie.commit(pr.branch_handle, buffer, assistant_msg, messages=messages, **kw) + return pr, node + + +# --------------------------------------------------------------------------- +# MessageKey canonicalization +# --------------------------------------------------------------------------- + + +def test_message_key_is_hashable_and_stable(): + k1 = make_message_key(user_msg("hi")) + k2 = make_message_key(user_msg("hi")) + assert k1 == k2 + assert hash(k1) == hash(k2) + assert k1 != make_message_key(user_msg("bye")) + + +def test_tool_call_argument_canonicalization_matches(): + # Same logical tool call, different JSON string formatting -> same key. + a = asst_msg("", tool_calls=[{"id": "c1", "function": {"name": "f", "arguments": '{"a": 1, "b": 2}'}}]) + b = asst_msg("", tool_calls=[{"id": "c1", "function": {"name": "f", "arguments": '{"b":2,"a":1}'}}]) + assert make_message_key(a) == make_message_key(b) + + +def test_reasoning_content_excluded_from_key(): + a = {"role": "assistant", "content": "answer", "reasoning_content": "think A"} + b = {"role": "assistant", "content": "answer", "reasoning_content": "think B"} + assert make_message_key(a) == make_message_key(b) + + +def test_multimodal_content_digest_keying(): + img_a = user_msg([{"type": "image_url", "image_url": {"url": "http://x/a.png"}}, {"type": "text", "text": "what"}]) + img_a2 = user_msg([{"type": "image_url", "image_url": {"url": "http://x/a.png"}}, {"type": "text", "text": "what"}]) + img_b = user_msg([{"type": "image_url", "image_url": {"url": "http://x/b.png"}}, {"type": "text", "text": "what"}]) + assert make_message_key(img_a) == make_message_key(img_a2) + assert make_message_key(img_a) != make_message_key(img_b) + + +def test_canonicalize_content_shapes(): + assert canonicalize_content("plain") == "plain" + assert canonicalize_content(None) is None + parts = canonicalize_content([{"type": "text", "text": "hi"}]) + assert parts == (("text", "hi"),) + assert isinstance(make_message_key(user_msg([{"type": "text", "text": "hi"}])).content, tuple) + + +# --------------------------------------------------------------------------- +# A. sequential extension — single chain, prefix reuse +# --------------------------------------------------------------------------- + + +def test_sequential_extension_single_chain(): + trie = PrefixTrie() + msgs = [sys_msg(), user_msg("fix bug")] + + # turn 1: first call, no checkpoint -> full encode path + pr1 = trie.prepare(msgs) + assert pr1.trajectory_buffer is None + assert pr1.checkpoint_messages == [] + a1 = asst_msg("looking", tool_calls=[{"id": "c1", "function": {"name": "cat", "arguments": "{}"}}]) + trie.commit(pr1.branch_handle, make_buffer(100, [200, 201, 202]), a1, messages=msgs) + + # turn 2: append tool result; must match prefix and clone a1's checkpoint + msgs2 = msgs + [a1, tool_msg("def login(): ...")] + pr2 = trie.prepare(msgs2) + assert pr2.trajectory_buffer is not None, "turn 2 should reuse turn 1 checkpoint" + # checkpoint covers up to and including a1 + assert pr2.checkpoint_messages == [sys_msg(), user_msg("fix bug"), a1] + a2 = asst_msg("done") + trie.commit(pr2.branch_handle, make_buffer(100, [200, 201, 202, 50, 51, 300]), a2, messages=msgs2) + + # single linear chain, one exportable branch + assert trie.num_branches() == 1 + assert trie.num_inflight() == 0 + + +def test_prepare_returns_independent_clone(): + trie = PrefixTrie() + msgs = [sys_msg(), user_msg("q")] + a1 = asst_msg("a") + trie.commit(trie.prepare(msgs).branch_handle, make_buffer(10, [1, 2, 3]), a1, messages=msgs) + + pr = trie.prepare(msgs + [a1, tool_msg("obs")]) + pr.trajectory_buffer.response_ids.append(999) # mutate the clone + # stored checkpoint must be unaffected + node, _ = trie.match(msgs + [a1]) + assert 999 not in node.checkpoint.trajectory_buffer.response_ids + + +# --------------------------------------------------------------------------- +# B. system-keyed split — parallel sub-agents +# --------------------------------------------------------------------------- + + +def test_system_keyed_split_parallel_subagents(): + trie = PrefixTrie() + planner = [sys_msg("you are the planner"), user_msg("task")] + worker = [sys_msg("you are the worker"), user_msg("task")] + + commit_turn(trie, planner, asst_msg("plan"), make_buffer(80, [1, 2])) + commit_turn(trie, worker, asst_msg("work"), make_buffer(80, [3, 4])) + + # different system prompts diverge right at the root + assert len(trie.root.children) == 2 + assert trie.num_branches() == 2 + + +# --------------------------------------------------------------------------- +# C. context condensation — sibling branch off a fork point +# --------------------------------------------------------------------------- + + +def test_context_condensation_creates_sibling(): + trie = PrefixTrie() + base = [sys_msg(), user_msg("long task")] + a1 = asst_msg("step 1") + commit_turn(trie, base, a1, make_buffer(120, [10, 11, 12])) + + # original continuation + cont = base + [a1, user_msg("continue")] + commit_turn(trie, cont, asst_msg("step 2"), make_buffer(120, [10, 11, 12, 20, 21])) + + # condensed continuation: a *different* user message at the same fork point + condensed = base + [a1, user_msg("[recap] do step 2")] + pr = trie.prepare(condensed) + # nearest checkpoint is a1 (the shared prefix), cloned at the splice + assert pr.trajectory_buffer is not None + assert pr.checkpoint_messages == base + [a1] + commit_turn(trie, condensed, asst_msg("step 2 alt"), make_buffer(120, [10, 11, 12, 30, 31])) + + # a1 now has two user children (original + condensed) -> two branches + a1_node, _ = trie.match(base + [a1]) + assert len(a1_node.children) == 2 + assert trie.num_branches() == 2 + + +# --------------------------------------------------------------------------- +# D. best-of-N and idempotent retry +# --------------------------------------------------------------------------- + + +def test_best_of_n_siblings_share_pending_parent(): + trie = PrefixTrie() + msgs = [sys_msg(), user_msg("solve")] + + # three samples of the *same* request -> three prepare calls share the + # pending user node, then commit three different assistant children. + handles = [trie.prepare(msgs).branch_handle for _ in range(3)] + for i, h in enumerate(handles): + trie.commit(h, make_buffer(60, [i, i + 1]), asst_msg(f"answer {i}"), messages=msgs) + + user_node, _ = trie.match(msgs) + assert len(user_node.children) == 3, "three distinct assistant siblings" + assert trie.num_branches() == 3 + assert trie.num_inflight() == 0 + + +def test_idempotent_retry_reuses_node(): + trie = PrefixTrie() + msgs = [sys_msg(), user_msg("solve")] + + trie.commit(trie.prepare(msgs).branch_handle, make_buffer(60, [1, 2]), asst_msg("same"), messages=msgs) + trie.commit(trie.prepare(msgs).branch_handle, make_buffer(60, [1, 2]), asst_msg("same"), messages=msgs) + + user_node, _ = trie.match(msgs) + assert len(user_node.children) == 1, "identical output reuses the node" + assert trie.num_branches() == 1 + + +# --------------------------------------------------------------------------- +# E. warm-start — seeded no-checkpoint nodes +# --------------------------------------------------------------------------- + + +def test_warm_start_seeded_nodes_have_no_checkpoint(): + trie = PrefixTrie() + history = [sys_msg(), user_msg("imported"), asst_msg("imported reply"), user_msg("now continue")] + # seed the imported transcript as structural nodes (no checkpoints) + seeded = trie.materialize_prompt_suffix(trie.root, history, 0) + _, ckpt = trie.nearest_ckpt(seeded) + assert ckpt is None, "warm-start nodes carry no checkpoint" + + # first live generation from the seeded tail must full-encode (no clone) + pr = trie.prepare(history) + assert pr.trajectory_buffer is None + assert pr.checkpoint_messages == [] + trie.commit(pr.branch_handle, make_buffer(200, [1, 2, 3]), asst_msg("live reply"), messages=history) + + # after the first commit there is a checkpoint to reuse + assert trie.num_branches() == 1 + + +# --------------------------------------------------------------------------- +# pending nodes / cleanup / export +# --------------------------------------------------------------------------- + + +def test_pending_node_not_exported_until_commit(): + trie = PrefixTrie() + msgs = [sys_msg(), user_msg("q")] + pr = trie.prepare(msgs) + # a pending user node now exists but no assistant committed beneath it + assert trie.num_branches() == 0 + assert trie.num_inflight() == 1 + + # generation failed -> abandon; node stays but is never exported + trie.abandon(pr.branch_handle) + assert trie.num_inflight() == 0 + assert trie.num_branches() == 0 + assert list(trie.iter_export_nodes()) == [] + + +def test_export_emits_terminal_checkpoints_only(): + trie = PrefixTrie() + msgs = [sys_msg(), user_msg("q")] + a1 = asst_msg("a1", tool_calls=[{"id": "c1", "function": {"name": "f", "arguments": "{}"}}]) + commit_turn(trie, msgs, a1, make_buffer(50, [1])) + # continue the branch + msgs2 = msgs + [a1, tool_msg("obs")] + commit_turn(trie, msgs2, asst_msg("a2"), make_buffer(50, [1, 2, 3])) + + terminal = list(trie.iter_export_nodes()) + all_nodes = list(trie.iter_export_nodes(export_all=True)) + # default: only the deepest (a2). export_all: both a1 and a2. + assert len(terminal) == 1 + assert len(all_nodes) == 2 + + +def test_commit_with_unknown_handle_raises(): + trie = PrefixTrie() + pr = trie.prepare([sys_msg(), user_msg("q")]) + trie.commit(pr.branch_handle, make_buffer(10, [1]), asst_msg("a"), messages=[]) + with pytest.raises(KeyError): + trie.commit(pr.branch_handle, make_buffer(10, [1]), asst_msg("a"), messages=[]) + + +# --------------------------------------------------------------------------- +# multimodal per-branch reconstruction +# --------------------------------------------------------------------------- + + +def test_multimodal_collected_per_branch(): + trie = PrefixTrie() + msgs = [sys_msg(), user_msg([{"type": "image_url", "image_url": {"url": "http://x/a.png"}}])] + pr = trie.prepare(msgs) + trie.commit( + pr.branch_handle, + make_buffer(300, [1, 2]), + asst_msg("i see a cat"), + messages=msgs, + image_data=[""], + ) + node, _ = trie.match(msgs) + assert node.children, "assistant child committed" + asst_node = next(iter(node.children.values())) + images, videos = trie.collect_multi_modal(asst_node) + assert images == [""] + assert videos is None + + +# --------------------------------------------------------------------------- +# regression guards (these fail on the pre-review buggy behavior) +# --------------------------------------------------------------------------- + + +def test_rebuild_messages_isolates_caller_mutation(): + """rebuild_messages must not hand out the trie's own message dicts; mutating + the result must not corrupt the stored node (fails if references leak).""" + trie = PrefixTrie() + msgs = [sys_msg(), user_msg("orig")] + a1 = asst_msg("a") + commit_turn(trie, msgs, a1, make_buffer(10, [1])) + user_node, _ = trie.match(msgs) + + rebuilt = trie.rebuild_messages(user_node) + rebuilt[0]["role"] = "HACKED" # caller mutates the returned transcript + assert user_node.parent.message["role"] == "system", "trie node must be unaffected" + + +def test_media_digest_hashes_large_data_uri(): + """A long data: URI must be hashed (compact, fixed-width), not stored + verbatim in the MessageKey (fails if the raw string is kept).""" + big = "data:image/png;base64," + "A" * 400 + key = make_message_key(user_msg([{"type": "image_url", "image_url": {"url": big}}])) + _, digest = key.content[0] + assert digest != big + assert len(digest) == 16 # sha256[:16] + # identical URI -> identical key; different URI -> different key + big2 = "data:image/png;base64," + "B" * 400 + same = make_message_key(user_msg([{"type": "image_url", "image_url": {"url": big}}])) + other = make_message_key(user_msg([{"type": "image_url", "image_url": {"url": big2}}])) + assert key == same + assert key != other + + +def test_refresh_clears_stale_multimodal(): + """An idempotent refresh whose checkpoint has no multimodal data must clear + the node's stale image/video (fails if old data lingers).""" + trie = PrefixTrie() + msgs = [sys_msg(), user_msg("q")] + a = asst_msg("same answer") + + trie.commit(trie.prepare(msgs).branch_handle, make_buffer(10, [1]), a, messages=msgs, image_data=["imgA"]) + node, _ = trie.match(msgs) + asst_node = next(iter(node.children.values())) + assert asst_node.image_data == ["imgA"] + + # idempotent retry, same output, no images this time + trie.commit(trie.prepare(msgs).branch_handle, make_buffer(10, [1]), a, messages=msgs, image_data=None) + assert asst_node.image_data is None, "stale multimodal must be cleared on refresh" + + +def test_export_terminals_on_deep_branchy_trie(): + """Terminal-checkpoint selection on a deeper, branchy trie — regression + guard for the O(N) post-order rewrite of iter_export_nodes.""" + trie = PrefixTrie() + base = [sys_msg(), user_msg("root")] + a1 = asst_msg("a1", tool_calls=[{"id": "c1", "function": {"name": "f", "arguments": "{}"}}]) + commit_turn(trie, base, a1, make_buffer(10, [1])) + # continue the a1 branch one more turn (a1 becomes non-terminal) + cont = base + [a1, tool_msg("obs")] + commit_turn(trie, cont, asst_msg("a2"), make_buffer(10, [1, 2])) + # a second sibling under the same user (best-of-N), left as a leaf + commit_turn(trie, base, asst_msg("a1-alt"), make_buffer(10, [9])) + + terminals = list(trie.iter_export_nodes()) + all_ckpts = list(trie.iter_export_nodes(export_all=True)) + # terminals: a2 (deepest of the a1 branch) + a1-alt leaf -> 2 + assert len(terminals) == 2 + # all checkpoints: a1, a2, a1-alt -> 3 + assert len(all_ckpts) == 3 diff --git a/uni_agent/gateway/session/__init__.py b/uni_agent/gateway/session/__init__.py index f9c9e2b0..d7fb4703 100644 --- a/uni_agent/gateway/session/__init__.py +++ b/uni_agent/gateway/session/__init__.py @@ -8,8 +8,8 @@ from .codec import MalformedRequestError, MessageCodec from .protocol import ChatCompletionRequest, ChatCompletionResponse -from .session import GatewaySession, TrajectoryBuffer -from .types import SessionHandle, Trajectory +from .session import GatewaySession +from .types import SessionHandle, Trajectory, TrajectoryBuffer __all__ = [ "ChatCompletionRequest", diff --git a/uni_agent/gateway/session/trie.py b/uni_agent/gateway/session/trie.py new file mode 100755 index 00000000..5d9cd822 --- /dev/null +++ b/uni_agent/gateway/session/trie.py @@ -0,0 +1,501 @@ +"""Prefix trie for multi-trajectory session storage (issue #51, M1). + +The gateway historically kept a single linear ``active_trajectory`` per session, +so only the latest branch could be re-attached. This module replaces that with a +per-session **prefix trie**: every committed assistant turn becomes a node that +may carry a :class:`BranchCheckpoint`, and an incoming request longest-prefix +matches against any path and continues from the nearest checkpoint. + +Design (see ``docs/trie_m1_implementation_plan.md``): + +- Each node stores exactly one message; the full transcript is rebuilt by + walking ``parent`` pointers. Children are keyed by :class:`MessageKey`. +- A generation is ``prepare -> tokenize -> commit``. ``prepare`` walks the trie + and clones the nearest checkpoint into a request-local buffer; ``commit`` + attaches the assistant child and writes its checkpoint. +- ``prepare`` materializes the incoming prompt-side messages as *pending* + (structural) nodes that carry no checkpoint and are never exported until an + assistant child commits beneath them. + +This module is intentionally free of the verl-importing codec so its unit tests +run standalone. Tokenization stays in the codec (M1 reuses the existing +``remove_system_prompt`` path); the trie only stores token state, never encodes. +""" + +from __future__ import annotations + +import hashlib +import uuid +from dataclasses import dataclass, field +from typing import Any, Iterator + +from uni_agent.gateway.session.types import TrajectoryBuffer + +# --------------------------------------------------------------------------- +# Message canonicalization -> hashable MessageKey +# --------------------------------------------------------------------------- + + +def canonicalize_tool_arguments(arguments: Any) -> tuple[str, Any]: + """Normalize a tool call's ``arguments`` so semantically-equal arguments that + differ only in JSON string formatting (whitespace, key order) compare equal. + + Mirrors ``codec._canonicalize_tool_arguments_for_comparison`` but is kept + here so the trie has no dependency on the verl-importing codec module. + """ + if isinstance(arguments, (dict, list)): + return ("json", _freeze(arguments)) + if isinstance(arguments, str): + import json + + try: + return ("json", _freeze(json.loads(arguments))) + except json.JSONDecodeError: + return ("raw", arguments) + return ("raw", arguments) + + +def _freeze(value: Any) -> Any: + """Recursively turn dicts/lists into hashable, order-canonical tuples.""" + if isinstance(value, dict): + return tuple(sorted((k, _freeze(v)) for k, v in value.items())) + if isinstance(value, (list, tuple)): + return tuple(_freeze(v) for v in value) + return value + + +def _media_digest(value: Any) -> str: + """Stable identity digest for an image/video payload. + + Uses the URL/string verbatim when available, otherwise a short sha256 of the + bytes/representation. Only an identity fingerprint for branch routing — the + actual pixels are stored on the node, never in the key. + """ + if isinstance(value, dict): + # OpenAI-style {"url": ...} / nested {"image_url": {"url": ...}} blocks. + for key in ("url", "image_url", "video_url", "image", "video", "bytes", "data"): + if key in value: + return _media_digest(value[key]) + return hashlib.sha256(repr(_freeze(value)).encode("utf-8")).hexdigest()[:16] + if isinstance(value, str): + # Inline ``data:`` URIs (and other very long strings) are hashed so the + # MessageKey stays compact; short URLs/ids are kept verbatim. + if value.startswith("data:") or len(value) > 256: + return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16] + return value + if isinstance(value, (bytes, bytearray)): + return hashlib.sha256(bytes(value)).hexdigest()[:16] + return hashlib.sha256(repr(value).encode("utf-8")).hexdigest()[:16] + + +def canonicalize_content(content: Any) -> str | tuple | None: + """Canonicalize a message ``content`` field into a hashable form. + + Text content passes through as a ``str``; multimodal content (a list of + parts) becomes a tuple of ``("text", str)`` / ``("image", digest)`` / + ``("video", digest)`` parts. Image/video payloads are reduced to a stable + digest so identical media routes to the same node. + """ + if content is None or isinstance(content, str): + return content + if not isinstance(content, list): + # Unknown scalar shape: freeze for hashability. + return ("opaque", _freeze(content)) + + parts: list[tuple[str, Any]] = [] + for item in content: + if not isinstance(item, dict): + parts.append(("text", str(item))) + continue + item_type = item.get("type") + if item_type in ("image", "image_url"): + parts.append(("image", _media_digest(item.get(item_type, item)))) + elif item_type in ("video", "video_url"): + parts.append(("video", _media_digest(item.get(item_type, item)))) + elif "text" in item: + parts.append(("text", item["text"])) + else: + parts.append(("opaque", _freeze(item))) + return tuple(parts) + + +def _canonicalize_tool_calls(tool_calls: Any) -> tuple | None: + if not isinstance(tool_calls, list): + return None + frozen: list[Any] = [] + for call in tool_calls: + if not isinstance(call, dict): + frozen.append(("raw", _freeze(call))) + continue + call_id = call.get("id") + function = call.get("function") + if isinstance(function, dict): + name = function.get("name") + args = canonicalize_tool_arguments(function.get("arguments")) + else: + name = None + args = None + frozen.append((call_id, name, args)) + return tuple(frozen) + + +@dataclass(frozen=True) +class MessageKey: + """Hashable identity of a single chat message, used as a trie child key. + + ``reasoning_content`` is intentionally excluded in M1: think blocks are + echoed back inconsistently by harnesses/templates and would cause spurious + forks (TODO: revisit for thinking-model replay). Request-level ``tools`` are + likewise not part of the key — they are gated separately at prepare time. + """ + + role: str + content: str | tuple | None + name: str | None = None + tool_calls: tuple | None = None + tool_call_id: str | None = None + + +def make_message_key(message: dict[str, Any]) -> MessageKey: + """Build a :class:`MessageKey` from an OpenAI-shaped message dict.""" + return MessageKey( + role=message.get("role", ""), + content=canonicalize_content(message.get("content")), + name=message.get("name"), + tool_calls=_canonicalize_tool_calls(message.get("tool_calls")), + tool_call_id=message.get("tool_call_id"), + ) + + +# --------------------------------------------------------------------------- +# Trie data structures +# --------------------------------------------------------------------------- + + +def clone_trajectory_buffer(buffer: TrajectoryBuffer) -> TrajectoryBuffer: + """Deep-copy the mutable list fields so each generation owns its buffer.""" + return TrajectoryBuffer( + prompt_ids=list(buffer.prompt_ids), + response_ids=list(buffer.response_ids), + response_mask=list(buffer.response_mask), + response_logprobs=list(buffer.response_logprobs), + ) + + +@dataclass +class BranchCheckpoint: + """Token-level state captured at a committed assistant prefix. + + Cloned by ``prepare`` so a new request can continue from this point without + re-encoding the shared prefix. ``image_data``/``video_data`` hold the + multimodal payload introduced by *this* node's message (per-branch storage; + decision 4B). The full branch multimodal sequence is rebuilt by walking + ancestors. + """ + + trajectory_buffer: TrajectoryBuffer + request_tools: list[dict[str, Any]] | None = None + chat_template_kwargs_key: tuple | None = None + messages: list[dict[str, Any]] | None = None + image_data: list[Any] | None = None + video_data: list[Any] | None = None + extra_fields: dict[str, Any] | None = None + + +@dataclass +class TrieNode: + """A single message in the session trie. + + Only committed assistant nodes carry a ``checkpoint``. Prompt-side nodes + materialized during ``prepare`` are structural/pending: ``checkpoint`` is + ``None`` and they are never exported until an assistant child commits. + """ + + key: MessageKey | None = None + message: dict[str, Any] | None = None + parent: TrieNode | None = None + checkpoint: BranchCheckpoint | None = None + children: dict[MessageKey, TrieNode] = field(default_factory=dict) + # Per-node multimodal introduced by this message (mirrors checkpoint copy for + # prompt-side nodes that do not yet have a checkpoint). + image_data: list[Any] | None = None + video_data: list[Any] | None = None + # In-flight generations attached beneath this node (failure-cleanup hook). + inflight: int = 0 + + @property + def is_root(self) -> bool: + return self.parent is None + + +@dataclass(frozen=True) +class BranchHandle: + """Opaque token returned by ``prepare`` and passed back to ``commit``. + + The gateway never inspects it; internally it points at the pending attach + node for this generation. + """ + + generation_id: str + + +@dataclass +class PrepareResult: + """Public contract between the trie and the gateway for one generation. + + ``trajectory_buffer`` is a request-local clone of the nearest checkpoint, or + ``None`` when no checkpoint covers the prefix (full-encode path). + ``checkpoint_messages`` is the message prefix that buffer already covers, so + the gateway only encodes ``messages[len(checkpoint_messages):]``. + """ + + trajectory_buffer: TrajectoryBuffer | None + checkpoint_messages: list[dict[str, Any]] + branch_handle: BranchHandle + image_data: list[Any] | None = None + video_data: list[Any] | None = None + request_tools: list[dict[str, Any]] | None = None + + +class PrefixTrie: + """Per-session prefix trie of chat messages with token checkpoints.""" + + def __init__(self) -> None: + self.root = TrieNode() + # Maps a live BranchHandle.generation_id to its pending attach node. + self._pending: dict[str, TrieNode] = {} + + # -- matching ----------------------------------------------------------- + + def match(self, messages: list[dict[str, Any]]) -> tuple[TrieNode, int]: + """Longest-prefix walk from the root. + + Returns the deepest matched node and the number of messages consumed. + """ + node = self.root + depth = 0 + for message in messages: + child = node.children.get(make_message_key(message)) + if child is None: + break + node = child + depth += 1 + return node, depth + + def materialize_prompt_suffix( + self, node: TrieNode, messages: list[dict[str, Any]], matched_depth: int + ) -> TrieNode: + """Find-or-create the prompt-side nodes for ``messages[matched_depth:]``. + + Newly created nodes are structural/pending (no checkpoint). Returns the + deepest node, which becomes the attach parent for the generation. + """ + attach = node + for message in messages[matched_depth:]: + key = make_message_key(message) + child = attach.children.get(key) + if child is None: + child = TrieNode(key=key, message=message, parent=attach) + attach.children[key] = child + attach = child + return attach + + @staticmethod + def nearest_ckpt(node: TrieNode) -> tuple[TrieNode | None, BranchCheckpoint | None]: + """Walk up from ``node`` (inclusive) to the nearest node with a checkpoint.""" + current: TrieNode | None = node + while current is not None: + if current.checkpoint is not None: + return current, current.checkpoint + current = current.parent + return None, None + + # -- prefix reconstruction --------------------------------------------- + + @staticmethod + def _path_to_root(node: TrieNode) -> list[TrieNode]: + chain: list[TrieNode] = [] + current: TrieNode | None = node + while current is not None and not current.is_root: + chain.append(current) + current = current.parent + chain.reverse() + return chain + + def rebuild_messages(self, node: TrieNode) -> list[dict[str, Any]]: + """Rebuild the message transcript from the root to ``node`` (inclusive). + + Returns shallow copies of the stored message dicts so a caller that + re-keys top-level fields (e.g. during normalization) cannot mutate the + trie's nodes in place. Shallow (not deep) keeps multimodal payloads from + being duplicated — nested content is never mutated by current callers. + """ + return [dict(n.message) for n in self._path_to_root(node) if n.message is not None] + + def collect_multi_modal(self, node: TrieNode) -> tuple[list[Any] | None, list[Any] | None]: + """Collect per-branch image/video data along the path to ``node``.""" + images: list[Any] = [] + videos: list[Any] = [] + for n in self._path_to_root(node): + if n.image_data: + images.extend(n.image_data) + if n.video_data: + videos.extend(n.video_data) + return (images or None, videos or None) + + # -- lifecycle ---------------------------------------------------------- + + def prepare( + self, + messages: list[dict[str, Any]], + *, + generation_id: str | None = None, + ) -> PrepareResult: + """Match the incoming messages, materialize pending nodes, and clone the + nearest checkpoint for a request-local buffer.""" + node, depth = self.match(messages) + attach_node = self.materialize_prompt_suffix(node, messages, depth) + ckpt_node, ckpt = self.nearest_ckpt(attach_node) + + generation_id = generation_id or uuid.uuid4().hex + handle = BranchHandle(generation_id=generation_id) + self._pending[generation_id] = attach_node + attach_node.inflight += 1 + + if ckpt is None: + return PrepareResult( + trajectory_buffer=None, + checkpoint_messages=[], + branch_handle=handle, + image_data=None, + video_data=None, + ) + + checkpoint_messages = ckpt.messages or self.rebuild_messages(ckpt_node) + images, videos = self.collect_multi_modal(ckpt_node) + return PrepareResult( + trajectory_buffer=clone_trajectory_buffer(ckpt.trajectory_buffer), + checkpoint_messages=list(checkpoint_messages), + branch_handle=handle, + image_data=images, + video_data=videos, + request_tools=ckpt.request_tools, + ) + + def upsert_assistant( + self, + parent: TrieNode, + assistant_msg: dict[str, Any], + checkpoint: BranchCheckpoint, + ) -> TrieNode: + """Attach (or refresh) the assistant child under ``parent``. + + Identical assistant output reuses the existing node (idempotent retry); + differing output creates a new sibling (best-of-N). + """ + key = make_message_key(assistant_msg) + child = parent.children.get(key) + if child is None: + child = TrieNode(key=key, message=assistant_msg, parent=parent) + parent.children[key] = child + else: + child.message = assistant_msg + child.checkpoint = checkpoint + # Mirror the checkpoint's multimodal payload onto the node, clearing any + # stale data so a refresh (idempotent retry) stays consistent. + child.image_data = list(checkpoint.image_data) if checkpoint.image_data is not None else None + child.video_data = list(checkpoint.video_data) if checkpoint.video_data is not None else None + return child + + def commit( + self, + branch_handle: BranchHandle, + trajectory_buffer: TrajectoryBuffer, + assistant_msg: dict[str, Any], + *, + request_tools: list[dict[str, Any]] | None = None, + chat_template_kwargs_key: tuple | None = None, + messages: list[dict[str, Any]] | None = None, + image_data: list[Any] | None = None, + video_data: list[Any] | None = None, + extra_fields: dict[str, Any] | None = None, + ) -> TrieNode: + """Resolve the pending node and attach the generated assistant child.""" + attach_node = self._pending.pop(branch_handle.generation_id, None) + if attach_node is None: + raise KeyError(f"Unknown or already-committed branch_handle: {branch_handle.generation_id}") + attach_node.inflight = max(0, attach_node.inflight - 1) + # The checkpoint lives on the assistant node and its buffer covers the + # request prompt PLUS the generated assistant turn, so the stored prefix + # must include ``assistant_msg`` (this is what the next turn's + # ``checkpoint_messages`` slices against). + covered_messages = list(messages) + [assistant_msg] if messages is not None else None + checkpoint = BranchCheckpoint( + trajectory_buffer=trajectory_buffer, + request_tools=request_tools, + chat_template_kwargs_key=chat_template_kwargs_key, + messages=covered_messages, + image_data=image_data, + video_data=video_data, + extra_fields=extra_fields, + ) + return self.upsert_assistant(attach_node, assistant_msg, checkpoint) + + def abandon(self, branch_handle: BranchHandle) -> None: + """Release a pending generation that failed before commit. + + Decrements the attach node's in-flight count. A childless structural + node is left in place (harmless; skipped by export) — M1 relies on the + finalize sweep rather than eager detach (TODO: refcount detach). + """ + attach_node = self._pending.pop(branch_handle.generation_id, None) + if attach_node is not None: + attach_node.inflight = max(0, attach_node.inflight - 1) + + # -- export ------------------------------------------------------------- + + def iter_export_nodes(self, *, export_all: bool = False) -> Iterator[TrieNode]: + """Yield assistant checkpoint nodes for finalize. + + By default emits only *terminal* assistant checkpoints (no committed + assistant descendant): rejected best-of-N siblings stay short leaves and + a continued branch is represented by its deepest checkpoint. With + ``export_all`` every committed checkpoint is emitted. + + A single post-order pass (O(N) over the trie) collects the terminals; + each node reports whether its subtree already carries a checkpoint so an + ancestor knows whether it is the deepest one. + """ + if export_all: + yield from self._iter_checkpoint_nodes(self.root) + return + _, terminals = self._collect_terminals(self.root) + yield from terminals + + def _collect_terminals(self, node: TrieNode) -> tuple[bool, list[TrieNode]]: + subtree_has_ckpt = False + terminals: list[TrieNode] = [] + for child in node.children.values(): + child_has_ckpt, child_terminals = self._collect_terminals(child) + subtree_has_ckpt = subtree_has_ckpt or child_has_ckpt + terminals.extend(child_terminals) + if node.checkpoint is not None: + if not subtree_has_ckpt: + terminals.append(node) + subtree_has_ckpt = True + return subtree_has_ckpt, terminals + + def _iter_checkpoint_nodes(self, node: TrieNode) -> Iterator[TrieNode]: + for child in node.children.values(): + if child.checkpoint is not None: + yield child + yield from self._iter_checkpoint_nodes(child) + + # -- introspection ------------------------------------------------------ + + def num_branches(self) -> int: + """Number of terminal checkpoint leaves (distinct exportable branches).""" + return sum(1 for _ in self.iter_export_nodes()) + + def num_inflight(self) -> int: + return len(self._pending) diff --git a/uni_agent/gateway/session/types.py b/uni_agent/gateway/session/types.py index 9d43d7bc..e011e6a4 100644 --- a/uni_agent/gateway/session/types.py +++ b/uni_agent/gateway/session/types.py @@ -28,6 +28,28 @@ class SessionHandle: reward_info_url: str | None = None +@dataclass +class TrajectoryBuffer: + """Mutable token buffer for a trajectory under construction. + + Lives in ``types`` so both ``session`` and ``trie`` can reference it without + a circular import (``trie`` must stay free of the verl-importing codec). + + Attributes: + prompt_ids: Prompt token IDs for the current trajectory. + response_ids: Accumulated response-side token IDs. + response_mask: Labels aligned with ``response_ids``; ``1`` for model + output and ``0`` for continuation context tokens. + response_logprobs: Log probabilities aligned with ``response_ids`` when + present; continuation context tokens use ``0.0``. + """ + + prompt_ids: list[int] + response_ids: list[int] = field(default_factory=list) + response_mask: list[int] = field(default_factory=list) + response_logprobs: list[float] = field(default_factory=list) + + @dataclass class Trajectory: """Token-level training trajectory produced when a gateway session finalizes. From 7d40dc9f693254039905c23a7f3cf842589bfafb Mon Sep 17 00:00:00 2001 From: Changyi Yang Date: Thu, 18 Jun 2026 07:50:10 +0000 Subject: [PATCH 2/5] feat(gateway): integrate prefix trie into session behind flag (M1) Wire the trie into GatewaySession end-to-end, gated by GatewayActorConfig.gateway_trie_enabled (default False; flag-off reproduces the legacy single-active behavior exactly, covered by a parity test). - session.py: trie-mode prepare routes through trie.prepare (longest-prefix match + nearest-checkpoint clone) and reuses the existing encode_full/encode_incremental paths; commit attaches the assistant child; finalize traverses the trie -> one branch-aware trajectory per terminal checkpoint. Per-branch multimodal stored on the node; failed generations abandon the pending node; snapshot_state reports num_branches/num_inflight. tools-change gate mirrors the legacy active_tool_schemas gate. - config.py/gateway.py: thread gateway_trie_enabled into sessions. - session-integration tests: flag parity, best-of-N fan-out, multi-turn reattach. Tokenization stays on remove_system_prompt (TITO -> M2); generation_lock kept (concurrency -> M3); no GC (trie dies with the session). Co-Authored-By: Claude Opus 4.8 --- .../gateway/test_session_trie_on_cpu.py | 115 ++++++++ uni_agent/gateway/config.py | 5 + uni_agent/gateway/gateway.py | 2 + uni_agent/gateway/session/session.py | 263 +++++++++++++++--- 4 files changed, 346 insertions(+), 39 deletions(-) create mode 100755 tests/uni_agent/gateway/test_session_trie_on_cpu.py diff --git a/tests/uni_agent/gateway/test_session_trie_on_cpu.py b/tests/uni_agent/gateway/test_session_trie_on_cpu.py new file mode 100755 index 00000000..f93788b7 --- /dev/null +++ b/tests/uni_agent/gateway/test_session_trie_on_cpu.py @@ -0,0 +1,115 @@ +"""CPU integration tests for the trie-backed gateway session (issue #51, M1). + +Drives a real ``GatewaySession`` (real ``MessageCodec`` + ``FakeTokenizer`` + +``SequencedBackend``) end-to-end through ``run_generation``/``complete``/ +``finalize`` to check: + +- **flag parity**: a linear multi-turn conversation produces identical backend + prompts and identical finalized trajectories whether the trie is on or off + (the M1 compatibility gate); +- **best-of-N**: repeated requests with the same messages fan out into N sibling + trajectories under the trie; +- **multi-turn reuse**: a tool/assistant continuation reattaches to its branch. + +These need the codec's runtime deps (verl tokenizer/template utils), so they +live alongside the other ``*_on_cpu`` gateway tests rather than in the +dependency-light pure-trie unit tests. +""" + +from __future__ import annotations + +import asyncio + +from tests.uni_agent.support import FakeTokenizer, SequencedBackend +from uni_agent.gateway.session.codec import MessageCodec +from uni_agent.gateway.session.session import GatewaySession +from uni_agent.gateway.session.types import SessionHandle + + +def _run(coro): + return asyncio.new_event_loop().run_until_complete(coro) + + +def _session(trie_enabled, response_length=None): + return GatewaySession( + SessionHandle(session_id="s"), + MessageCodec(FakeTokenizer()), + response_length=response_length, + trie_enabled=trie_enabled, + ) + + +SYS = {"role": "system", "content": "sys"} +USER = {"role": "user", "content": "fix bug"} + + +async def _linear_conversation(trie_enabled): + """Two-turn linear conversation; returns (backend prompts, trajectories).""" + backend = SequencedBackend(["A1", "A2"]) + session = _session(trie_enabled) + + out1 = await session.run_generation({"messages": [SYS, USER]}, backend) + a1 = {"role": "assistant", "content": out1.assistant_msg["content"]} + messages2 = [SYS, USER, a1, {"role": "user", "content": "more"}] + await session.run_generation({"messages": messages2}, backend) + + await session.set_reward_info({"score": 1.0}) + trajectories = await session.finalize() + prompts = [call["prompt_ids"] for call in backend.calls] + return prompts, trajectories + + +def test_trie_flag_parity_linear_conversation(): + off_prompts, off_trajs = _run(_linear_conversation(trie_enabled=False)) + on_prompts, on_trajs = _run(_linear_conversation(trie_enabled=True)) + + # The backend must see identical prompts on every turn. + assert on_prompts == off_prompts + + # And the finalized linear trajectory must match token-for-token. + assert len(on_trajs) == len(off_trajs) == 1 + a, b = on_trajs[0], off_trajs[0] + assert a.prompt_ids == b.prompt_ids + assert a.response_ids == b.response_ids + assert a.response_mask == b.response_mask + assert a.response_logprobs == b.response_logprobs + assert a.reward_info == b.reward_info == {"score": 1.0} + + +def test_trie_best_of_n_fans_out_to_sibling_trajectories(): + async def scenario(): + backend = SequencedBackend(["ans-A", "ans-B", "ans-C"]) + session = _session(trie_enabled=True) + for _ in range(3): + await session.run_generation({"messages": [SYS, USER]}, backend) + await session.set_reward_info({"score": 0.5}) + return await session.finalize() + + trajectories = _run(scenario()) + assert len(trajectories) == 3 + contents = sorted("".join(chr(t) for t in traj.response_ids) for traj in trajectories) + assert contents == ["ans-A", "ans-B", "ans-C"] + assert all(traj.reward_info == {"score": 0.5} for traj in trajectories) + + +def test_trie_multi_turn_reattaches_and_finalizes_single_branch(): + async def scenario(): + backend = SequencedBackend(["step1", "step2", "step3"]) + session = _session(trie_enabled=True) + msgs = [SYS, USER] + for i in range(3): + out = await session.run_generation({"messages": msgs}, backend) + msgs = msgs + [ + {"role": "assistant", "content": out.assistant_msg["content"]}, + {"role": "user", "content": f"turn {i}"}, + ] + return await session.finalize() + + trajectories = _run(scenario()) + # One continuous branch -> one terminal trajectory whose response covers all + # three generated turns plus the interstitial context tokens. + assert len(trajectories) == 1 + traj = trajectories[0] + assert len(traj.response_ids) == len(traj.response_mask) == len(traj.response_logprobs) + # mask has both generated (1) and continuation (0) tokens. + assert set(traj.response_mask) == {0, 1} diff --git a/uni_agent/gateway/config.py b/uni_agent/gateway/config.py index ccf18f5d..3454e960 100644 --- a/uni_agent/gateway/config.py +++ b/uni_agent/gateway/config.py @@ -29,6 +29,10 @@ class GatewayActorConfig: vision_info_extractor_kwargs: Static kwargs forwarded to the extractor. prompt_length: Optional prompt-token budget stored on gateway sessions. response_length: Optional response-token budget stored on gateway sessions. + gateway_trie_enabled: When True, sessions store turns in a prefix trie + (multi-branch reattachment, best-of-N, sub-agent, warm-start) instead + of a single linear active trajectory. Off by default; flag-off must + reproduce the legacy single-active behavior exactly. """ tokenizer: Any @@ -41,3 +45,4 @@ class GatewayActorConfig: vision_info_extractor_kwargs: dict[str, Any] | None = None prompt_length: int | None = None response_length: int | None = None + gateway_trie_enabled: bool = False diff --git a/uni_agent/gateway/gateway.py b/uni_agent/gateway/gateway.py index e2966e0b..8c3e0c0e 100644 --- a/uni_agent/gateway/gateway.py +++ b/uni_agent/gateway/gateway.py @@ -51,6 +51,7 @@ def __init__(self, config: GatewayActorConfig, backend): ) self._prompt_length = config.prompt_length self._response_length = config.response_length + self._trie_enabled = config.gateway_trie_enabled self._sessions: dict[str, GatewaySession] = {} self._app = FastAPI() self._server_port: int | None = None @@ -198,6 +199,7 @@ async def create_session(self, session_id: str, metadata: dict[str, Any] | None codec=self._codec, prompt_length=self._prompt_length, response_length=self._response_length, + trie_enabled=self._trie_enabled, ) return handle diff --git a/uni_agent/gateway/session/session.py b/uni_agent/gateway/session/session.py index c0a5f282..9f722673 100644 --- a/uni_agent/gateway/session/session.py +++ b/uni_agent/gateway/session/session.py @@ -4,14 +4,15 @@ import asyncio import time -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, replace from enum import Enum from typing import Any from fastapi import HTTPException from uni_agent.gateway.session.codec import MalformedRequestError, MessageCodec -from uni_agent.gateway.session.types import SessionHandle, Trajectory +from uni_agent.gateway.session.trie import BranchHandle, PrefixTrie +from uni_agent.gateway.session.types import SessionHandle, Trajectory, TrajectoryBuffer class SessionPhase(str, Enum): @@ -28,25 +29,6 @@ class SessionPhase(str, Enum): ABORTED = "ABORTED" -@dataclass -class TrajectoryBuffer: - """Mutable token buffer for the active trajectory under construction. - - Attributes: - prompt_ids: Prompt token IDs for the current trajectory. - response_ids: Accumulated response-side token IDs. - response_mask: Labels aligned with ``response_ids``; ``1`` for model - output and ``0`` for continuation context tokens. - response_logprobs: Log probabilities aligned with ``response_ids`` when - present; continuation context tokens use ``0.0``. - """ - - prompt_ids: list[int] - response_ids: list[int] = field(default_factory=list) - response_mask: list[int] = field(default_factory=list) - response_logprobs: list[float] = field(default_factory=list) - - @dataclass class EncodedData: """Session-private data prepared before backend generation. @@ -79,6 +61,12 @@ class EncodedData: video_data: list[Any] | None materialized_trajectory: Trajectory | None length_exhausted_trajectory: Trajectory | None + # Trie-mode only: opaque handle returned by ``trie.prepare`` and passed back + # to ``trie.commit``; ``new_*_data`` are the multimodal inputs introduced by + # this turn (stored on the committed node for per-branch reconstruction). + branch_handle: BranchHandle | None = None + new_image_data: list[Any] | None = None + new_video_data: list[Any] | None = None @dataclass @@ -118,10 +106,13 @@ def __init__( *, prompt_length: int | None = None, response_length: int | None = None, + trie_enabled: bool = False, ): """Create an active session bound to a handle and model codec.""" self.handle = handle self._codec = codec + self._trie_enabled = trie_enabled + self._trie: PrefixTrie | None = PrefixTrie() if trie_enabled else None self._prompt_length = prompt_length self._response_length = response_length self.active_tool_schemas: list[dict[str, Any]] | None = None @@ -162,12 +153,24 @@ async def run_generation(self, payload: dict[str, Any], backend) -> GenerationOu encoded = await self._prepare_generation_inputs(payload, request_context) if encoded.length_exhausted_trajectory is not None: empty_msg = {"role": "assistant", "content": ""} - self.trajectories.append(encoded.length_exhausted_trajectory) - self.active_trajectory = None - self.message_history = list(encoded.messages) + [empty_msg] - self.image_data = list(encoded.image_data) if encoded.image_data is not None else None - self.video_data = list(encoded.video_data) if encoded.video_data is not None else None - self.active_tool_schemas = encoded.tools + if self._trie_enabled: + self._trie.commit( + encoded.branch_handle, + encoded.buffer, + empty_msg, + request_tools=encoded.tools, + messages=encoded.messages, + image_data=encoded.new_image_data, + video_data=encoded.new_video_data, + extra_fields={"finish_reason": "length"}, + ) + else: + self.trajectories.append(encoded.length_exhausted_trajectory) + self.active_trajectory = None + self.message_history = list(encoded.messages) + [empty_msg] + self.image_data = list(encoded.image_data) if encoded.image_data is not None else None + self.video_data = list(encoded.video_data) if encoded.video_data is not None else None + self.active_tool_schemas = encoded.tools self._touch() return GenerationOutcome( assistant_msg=empty_msg, @@ -185,8 +188,10 @@ async def run_generation(self, payload: dict[str, Any], backend) -> GenerationOu video_data=encoded.video_data, ) except ValueError as e: + self._abandon_pending(encoded) raise HTTPException(status_code=400, detail=str(e)) from e except Exception as e: + self._abandon_pending(encoded) raise HTTPException(status_code=500, detail=f"{e.__class__.__name__}: {e}") from e response_ids = list(output.token_ids) @@ -207,13 +212,24 @@ async def run_generation(self, payload: dict[str, Any], backend) -> GenerationOu status_code=409, detail=f"Session {self.handle.session_id} is {self.phase.value.lower()}", ) - if encoded.materialized_trajectory is not None: - self.trajectories.append(encoded.materialized_trajectory) - self.active_trajectory = encoded.buffer - self.message_history = list(encoded.messages) + [assistant_msg] - self.image_data = list(encoded.image_data) if encoded.image_data is not None else None - self.video_data = list(encoded.video_data) if encoded.video_data is not None else None - self.active_tool_schemas = encoded.tools + if self._trie_enabled: + self._trie.commit( + encoded.branch_handle, + encoded.buffer, + assistant_msg, + request_tools=encoded.tools, + messages=encoded.messages, + image_data=encoded.new_image_data, + video_data=encoded.new_video_data, + ) + else: + if encoded.materialized_trajectory is not None: + self.trajectories.append(encoded.materialized_trajectory) + self.active_trajectory = encoded.buffer + self.message_history = list(encoded.messages) + [assistant_msg] + self.image_data = list(encoded.image_data) if encoded.image_data is not None else None + self.video_data = list(encoded.video_data) if encoded.video_data is not None else None + self.active_tool_schemas = encoded.tools self._touch() return GenerationOutcome( assistant_msg=assistant_msg, @@ -223,6 +239,8 @@ async def run_generation(self, payload: dict[str, Any], backend) -> GenerationOu ) async def _prepare_generation_inputs(self, payload: dict[str, Any], request_context: dict[str, Any]) -> EncodedData: + if self._trie_enabled: + return await self._prepare_generation_inputs_trie(payload, request_context) messages = request_context["messages"] tools = request_context["tools"] request_chat_template_kwargs = request_context["chat_template_kwargs"] @@ -315,6 +333,112 @@ async def _prepare_generation_inputs(self, payload: dict[str, Any], request_cont length_exhausted_trajectory=None, ) + async def _prepare_generation_inputs_trie( + self, payload: dict[str, Any], request_context: dict[str, Any] + ) -> EncodedData: + """Trie-backed variant of ``_prepare_generation_inputs``. + + Routes the request against the session trie (longest-prefix match across + all branches), clones the nearest checkpoint, and reuses the same codec + encode paths as the legacy flow — only the state model differs. + """ + messages = request_context["messages"] + tools = request_context["tools"] + request_chat_template_kwargs = request_context["chat_template_kwargs"] + + prepared = self._trie.prepare(messages) + # Reuse the cloned checkpoint only when its tools match this request; + # a tools change forces a full re-encode (mirrors the legacy + # ``active_tool_schemas != tools`` gate). + use_incremental = prepared.trajectory_buffer is not None and prepared.request_tools == tools + + # Multimodal introduced by *this* turn (messages beyond the checkpoint), + # stored on the committed node for per-branch reconstruction. + delta_messages = messages[len(prepared.checkpoint_messages) :] + new_image_data, new_video_data = (None, None) + if delta_messages: + new_image_data, new_video_data = await self._codec.extract_multi_modal_data(delta_messages) + + if not use_incremental: + image_data, video_data = await self._codec.extract_multi_modal_data(messages) + prompt_ids = self._codec.encode_full( + messages, + tools=tools, + image_data=image_data, + video_data=video_data, + request_chat_template_kwargs=request_chat_template_kwargs, + ) + buffer = TrajectoryBuffer(prompt_ids=prompt_ids) + # Whole prompt is freshly encoded; the node owns all of its images. + new_image_data, new_video_data = image_data, video_data + else: + buffer = prepared.trajectory_buffer + image_data = list(prepared.image_data) if prepared.image_data is not None else None + video_data = list(prepared.video_data) if prepared.video_data is not None else None + if delta_messages: + if new_image_data: + image_data = (image_data or []) + new_image_data + if new_video_data: + video_data = (video_data or []) + new_video_data + incremental_ids = self._codec.encode_incremental( + delta_messages, + image_data=new_image_data, + video_data=new_video_data, + request_chat_template_kwargs=request_chat_template_kwargs, + ) + if ( + self._response_length is not None + and len(buffer.response_mask) + len(incremental_ids) >= self._response_length + ): + context_ids = buffer.prompt_ids + buffer.response_ids + return EncodedData( + buffer=buffer, + context_ids=context_ids, + sampling_params={}, + messages=list(messages), + tools=tools, + image_data=image_data, + video_data=video_data, + materialized_trajectory=None, + length_exhausted_trajectory=self._trajectory_from_buffer( + buffer, messages, extra_fields={"finish_reason": "length"} + ), + branch_handle=prepared.branch_handle, + new_image_data=new_image_data, + new_video_data=new_video_data, + ) + buffer.response_ids.extend(incremental_ids) + buffer.response_mask.extend([0] * len(incremental_ids)) + if buffer.response_logprobs: + buffer.response_logprobs.extend([0.0] * len(incremental_ids)) + + context_ids = buffer.prompt_ids + buffer.response_ids + sampling_params = self._codec.build_sampling_params(payload) + remaining_response_budget = ( + self._response_length - len(buffer.response_mask) if self._response_length is not None else None + ) + if remaining_response_budget is not None and "max_tokens" in sampling_params: + sampling_params["max_tokens"] = min(sampling_params["max_tokens"], remaining_response_budget) + return EncodedData( + buffer=buffer, + context_ids=context_ids, + sampling_params=sampling_params, + messages=list(messages), + tools=tools, + image_data=image_data, + video_data=video_data, + materialized_trajectory=None, + length_exhausted_trajectory=None, + branch_handle=prepared.branch_handle, + new_image_data=new_image_data, + new_video_data=new_video_data, + ) + + def _abandon_pending(self, encoded: EncodedData) -> None: + """Release the trie pending node when a generation fails before commit.""" + if self._trie_enabled and encoded.branch_handle is not None: + self._trie.abandon(encoded.branch_handle) + async def set_reward_info(self, reward_info: dict[str, Any] | None = None) -> None: """Store session-level reward metadata without closing the session.""" async with self.request_lock: @@ -332,10 +456,14 @@ async def finalize(self) -> list[Trajectory]: if self.phase == SessionPhase.FINALIZED: raise RuntimeError(f"Session {self.handle.session_id} is finalized") self._touch() - self._materialize_active_trajectory() + if self._trie_enabled: + trajectories = self._materialize_trie_trajectories() + else: + self._materialize_active_trajectory() + trajectories = self.trajectories self.phase = SessionPhase.FINALIZED self._touch() - return [replace(trajectory, reward_info=dict(self.reward_info)) for trajectory in self.trajectories] + return [replace(trajectory, reward_info=dict(self.reward_info)) for trajectory in trajectories] async def abort(self) -> None: """Abort the session and prevent further generation.""" @@ -349,14 +477,28 @@ async def abort(self) -> None: def snapshot_state(self) -> dict[str, Any]: """Return a JSON-serializable snapshot for actor state inspection.""" - return { + snapshot = { "session_id": self.handle.session_id, "phase": self.phase.value, "created_at": self.created_at, "updated_at": self.updated_at, - "num_trajectories": len(self.trajectories), - "has_active_trajectory": self.active_trajectory is not None, } + if self._trie_enabled: + snapshot.update( + { + "num_branches": self._trie.num_branches(), + "num_inflight_generations": self._trie.num_inflight(), + "has_active_trajectory": self._trie.num_branches() > 0, + } + ) + else: + snapshot.update( + { + "num_trajectories": len(self.trajectories), + "has_active_trajectory": self.active_trajectory is not None, + } + ) + return snapshot def _is_request_context_prefix( self, @@ -407,6 +549,49 @@ def _build_materialized_trajectory( extra_fields=dict(extra_fields) if extra_fields else {}, ) + def _trajectory_from_buffer( + self, + buffer: TrajectoryBuffer, + messages: list[dict[str, Any]], + *, + image_data: list[Any] | None = None, + video_data: list[Any] | None = None, + extra_fields: dict[str, Any] | None = None, + ) -> Trajectory: + """Build a Trajectory from an explicit buffer + branch context. + + Trie-mode analogue of ``_build_materialized_trajectory`` that does not + read the session-global ``message_history``/``image_data``. + """ + return Trajectory( + prompt_ids=list(buffer.prompt_ids), + response_ids=list(buffer.response_ids), + response_mask=list(buffer.response_mask), + response_logprobs=list(buffer.response_logprobs) if buffer.response_logprobs else None, + reward_info={}, + num_turns=self._count_chat_turns(messages), + multi_modal_data=self._build_multi_modal_trajectory_data(image_data, video_data), + extra_fields=dict(extra_fields) if extra_fields else {}, + ) + + def _materialize_trie_trajectories(self) -> list[Trajectory]: + """Traverse the trie and emit one trajectory per terminal checkpoint.""" + trajectories: list[Trajectory] = [] + for node in self._trie.iter_export_nodes(): + checkpoint = node.checkpoint + messages = checkpoint.messages or self._trie.rebuild_messages(node) + images, videos = self._trie.collect_multi_modal(node) + trajectories.append( + self._trajectory_from_buffer( + checkpoint.trajectory_buffer, + messages, + image_data=images, + video_data=videos, + extra_fields=checkpoint.extra_fields, + ) + ) + return trajectories + def _count_chat_turns(self, message_history: list[dict[str, Any]]) -> int: return sum(1 for m in message_history if m.get("role") in ("user", "assistant")) + 1 From 54c62f4a2f558334f1cf0165ef374c6ad2b12f8b Mon Sep 17 00:00:00 2001 From: Changyi Yang Date: Thu, 18 Jun 2026 07:55:07 +0000 Subject: [PATCH 3/5] fix(gateway): address Gemini 2nd-pass review on trie - _prepare_generation_inputs_trie: abandon the pending node if encoding fails or is cancelled after trie.prepare (try/except BaseException around the encode step), preventing in-flight bookkeeping leaks. Added a session test that fails without the guard. - _media_digest: hash repr(_freeze(value)) in the scalar fallback so nested dict/list ordering is canonical across processes. Co-Authored-By: Claude Opus 4.8 --- .../gateway/test_session_trie_on_cpu.py | 23 +++++++++++++++++++ uni_agent/gateway/session/session.py | 16 ++++++++++++- uni_agent/gateway/session/trie.py | 4 +++- 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/uni_agent/gateway/test_session_trie_on_cpu.py b/tests/uni_agent/gateway/test_session_trie_on_cpu.py index f93788b7..ba03a555 100755 --- a/tests/uni_agent/gateway/test_session_trie_on_cpu.py +++ b/tests/uni_agent/gateway/test_session_trie_on_cpu.py @@ -113,3 +113,26 @@ async def scenario(): assert len(traj.response_ids) == len(traj.response_mask) == len(traj.response_logprobs) # mask has both generated (1) and continuation (0) tokens. assert set(traj.response_mask) == {0, 1} + + +def test_trie_abandons_pending_node_when_encode_fails(): + """An encode failure inside input prep must not leak the trie pending node + (fails if _prepare_generation_inputs_trie doesn't abandon on error).""" + + async def scenario(): + session = _session(trie_enabled=True) + + def boom(*args, **kwargs): + raise RuntimeError("encode failed") + + session._codec.encode_full = boom # force prep to raise after prepare() + raised = False + try: + await session.run_generation({"messages": [SYS, USER]}, SequencedBackend(["x"])) + except RuntimeError: + raised = True + return raised, session._trie.num_inflight() + + raised, inflight = _run(scenario()) + assert raised + assert inflight == 0, "pending node must be abandoned on encode failure" diff --git a/uni_agent/gateway/session/session.py b/uni_agent/gateway/session/session.py index 9f722673..e88be0b2 100644 --- a/uni_agent/gateway/session/session.py +++ b/uni_agent/gateway/session/session.py @@ -343,10 +343,24 @@ async def _prepare_generation_inputs_trie( encode paths as the legacy flow — only the state model differs. """ messages = request_context["messages"] + prepared = self._trie.prepare(messages) + # ``prepare`` has registered a pending node; if encoding fails (or the + # request is cancelled) before we hand the buffer to the backend, abandon + # it so the in-flight bookkeeping does not leak. + try: + return await self._encode_prepared(payload, request_context, prepared) + except BaseException: + self._trie.abandon(prepared.branch_handle) + raise + + async def _encode_prepared( + self, payload: dict[str, Any], request_context: dict[str, Any], prepared + ) -> EncodedData: + """Encode inputs for an already-prepared trie branch (see caller).""" + messages = request_context["messages"] tools = request_context["tools"] request_chat_template_kwargs = request_context["chat_template_kwargs"] - prepared = self._trie.prepare(messages) # Reuse the cloned checkpoint only when its tools match this request; # a tools change forces a full re-encode (mirrors the legacy # ``active_tool_schemas != tools`` gate). diff --git a/uni_agent/gateway/session/trie.py b/uni_agent/gateway/session/trie.py index 5d9cd822..208108ea 100755 --- a/uni_agent/gateway/session/trie.py +++ b/uni_agent/gateway/session/trie.py @@ -85,7 +85,9 @@ def _media_digest(value: Any) -> str: return value if isinstance(value, (bytes, bytearray)): return hashlib.sha256(bytes(value)).hexdigest()[:16] - return hashlib.sha256(repr(value).encode("utf-8")).hexdigest()[:16] + # _freeze canonicalizes nested dicts/lists so the repr is order-stable + # across processes (plain repr of a dict is not guaranteed canonical). + return hashlib.sha256(repr(_freeze(value)).encode("utf-8")).hexdigest()[:16] def canonicalize_content(content: Any) -> str | tuple | None: From f4294f1cd7367873514ad57331f355c8513201ea Mon Sep 17 00:00:00 2001 From: Changyi Yang Date: Thu, 18 Jun 2026 08:06:44 +0000 Subject: [PATCH 4/5] fix(gateway): address Gemini 3rd-pass review on trie - session.py: on a full re-encode mid-branch (e.g. tools change), store only this turn's delta media on the committed node instead of the whole prompt's media, so collect_multi_modal does not double-count media carried on ancestor checkpoints. Added a regression test (multimodal + tools-change mid-branch) that fails without the fix. - trie.py: defensively shallow-copy message dicts when storing them in nodes (materialize_prompt_suffix, upsert_assistant) and in the checkpoint's covered_messages, isolating trie state from external payload mutation. Co-Authored-By: Claude Opus 4.8 --- .../gateway/test_session_trie_on_cpu.py | 42 ++++++++++++++++++- uni_agent/gateway/session/session.py | 6 ++- uni_agent/gateway/session/trie.py | 12 ++++-- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/tests/uni_agent/gateway/test_session_trie_on_cpu.py b/tests/uni_agent/gateway/test_session_trie_on_cpu.py index ba03a555..737630d5 100755 --- a/tests/uni_agent/gateway/test_session_trie_on_cpu.py +++ b/tests/uni_agent/gateway/test_session_trie_on_cpu.py @@ -20,7 +20,12 @@ import asyncio -from tests.uni_agent.support import FakeTokenizer, SequencedBackend +from tests.uni_agent.support import ( + FakeProcessor, + FakeTokenizer, + SequencedBackend, + fake_vision_info_extractor, +) from uni_agent.gateway.session.codec import MessageCodec from uni_agent.gateway.session.session import GatewaySession from uni_agent.gateway.session.types import SessionHandle @@ -136,3 +141,38 @@ def boom(*args, **kwargs): raised, inflight = _run(scenario()) assert raised assert inflight == 0, "pending node must be abandoned on encode failure" + + +def test_trie_no_duplicate_multimodal_on_full_encode_midbranch(): + """A tools change mid-branch forces a full re-encode; the committed node must + store only this turn's delta media, not the whole history, so finalize does + not double-count ancestor media (fails if the node stores the full lists).""" + + async def scenario(): + codec = MessageCodec( + FakeTokenizer(), processor=FakeProcessor(), vision_info_extractor=fake_vision_info_extractor + ) + session = GatewaySession(SessionHandle(session_id="s"), codec, trie_enabled=True) + backend = SequencedBackend(["R0", "R1"]) + img_a = {"type": "image_url", "image_url": {"url": "http://x/a.png"}} + img_b = {"type": "image_url", "image_url": {"url": "http://x/b.png"}} + + # turn 1: image A, no tools + msgs1 = [SYS, {"role": "user", "content": [img_a, {"type": "text", "text": "a"}]}] + out1 = await session.run_generation({"messages": msgs1}, backend) + # turn 2: append assistant + a new user with image B AND change tools -> + # use_incremental=False -> full re-encode mid-branch + msgs2 = msgs1 + [ + {"role": "assistant", "content": out1.assistant_msg["content"]}, + {"role": "user", "content": [img_b, {"type": "text", "text": "b"}]}, + ] + await session.run_generation( + {"messages": msgs2, "tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}]}, + backend, + ) + return await session.finalize() + + trajectories = _run(scenario()) + assert len(trajectories) == 1 + images = trajectories[0].multi_modal_data["images"] + assert images == ["http://x/a.png", "http://x/b.png"], f"no duplicate media expected, got {images}" diff --git a/uni_agent/gateway/session/session.py b/uni_agent/gateway/session/session.py index e88be0b2..ebde4037 100644 --- a/uni_agent/gateway/session/session.py +++ b/uni_agent/gateway/session/session.py @@ -383,8 +383,10 @@ async def _encode_prepared( request_chat_template_kwargs=request_chat_template_kwargs, ) buffer = TrajectoryBuffer(prompt_ids=prompt_ids) - # Whole prompt is freshly encoded; the node owns all of its images. - new_image_data, new_video_data = image_data, video_data + # The backend gets the full prompt's media (image_data/video_data), + # but the node stores only this turn's delta media (already extracted + # above from delta_messages) so collect_multi_modal does not + # double-count media carried on ancestor checkpoints. else: buffer = prepared.trajectory_buffer image_data = list(prepared.image_data) if prepared.image_data is not None else None diff --git a/uni_agent/gateway/session/trie.py b/uni_agent/gateway/session/trie.py index 208108ea..1f03195a 100755 --- a/uni_agent/gateway/session/trie.py +++ b/uni_agent/gateway/session/trie.py @@ -297,7 +297,9 @@ def materialize_prompt_suffix( key = make_message_key(message) child = attach.children.get(key) if child is None: - child = TrieNode(key=key, message=message, parent=attach) + # Shallow-copy so external mutation of the request payload cannot + # corrupt the stored node. + child = TrieNode(key=key, message=dict(message), parent=attach) attach.children[key] = child attach = child return attach @@ -397,11 +399,13 @@ def upsert_assistant( """ key = make_message_key(assistant_msg) child = parent.children.get(key) + # Shallow-copy so later mutation of the assistant message cannot corrupt + # the stored node. if child is None: - child = TrieNode(key=key, message=assistant_msg, parent=parent) + child = TrieNode(key=key, message=dict(assistant_msg), parent=parent) parent.children[key] = child else: - child.message = assistant_msg + child.message = dict(assistant_msg) child.checkpoint = checkpoint # Mirror the checkpoint's multimodal payload onto the node, clearing any # stale data so a refresh (idempotent retry) stays consistent. @@ -431,7 +435,7 @@ def commit( # request prompt PLUS the generated assistant turn, so the stored prefix # must include ``assistant_msg`` (this is what the next turn's # ``checkpoint_messages`` slices against). - covered_messages = list(messages) + [assistant_msg] if messages is not None else None + covered_messages = [dict(m) for m in messages] + [dict(assistant_msg)] if messages is not None else None checkpoint = BranchCheckpoint( trajectory_buffer=trajectory_buffer, request_tools=request_tools, From 7f9f42ec69cafa1dbc27d5c9a1c7f0cec73afeab Mon Sep 17 00:00:00 2001 From: Changyi Yang Date: Thu, 18 Jun 2026 08:17:19 +0000 Subject: [PATCH 5/5] fix(gateway): address Gemini 4th-pass review on trie - session.py: catch BaseException around backend.generate so a cancellation (asyncio.CancelledError, which is not an Exception) still abandons the trie pending node before propagating. Added a regression test that fails without the guard. - trie.py: move 'import json' to module top (PEP 8) instead of inline. Co-Authored-By: Changyi Yang Co-Authored-By: Claude Opus 4.8 --- .../gateway/test_session_trie_on_cpu.py | 23 +++++++++++++++++++ uni_agent/gateway/session/session.py | 5 ++++ uni_agent/gateway/session/trie.py | 3 +-- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/tests/uni_agent/gateway/test_session_trie_on_cpu.py b/tests/uni_agent/gateway/test_session_trie_on_cpu.py index 737630d5..a749aa6b 100755 --- a/tests/uni_agent/gateway/test_session_trie_on_cpu.py +++ b/tests/uni_agent/gateway/test_session_trie_on_cpu.py @@ -176,3 +176,26 @@ async def scenario(): assert len(trajectories) == 1 images = trajectories[0].multi_modal_data["images"] assert images == ["http://x/a.png", "http://x/b.png"], f"no duplicate media expected, got {images}" + + +def test_trie_abandons_pending_node_on_cancellation(): + """A cancellation during backend.generate (CancelledError is BaseException, + not Exception) must still abandon the pending node (fails if only + ValueError/Exception are caught).""" + + class CancellingBackend: + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + raise asyncio.CancelledError() + + async def scenario(): + session = _session(trie_enabled=True) + cancelled = False + try: + await session.run_generation({"messages": [SYS, USER]}, CancellingBackend()) + except asyncio.CancelledError: + cancelled = True + return cancelled, session._trie.num_inflight() + + cancelled, inflight = _run(scenario()) + assert cancelled + assert inflight == 0, "pending node must be abandoned on cancellation" diff --git a/uni_agent/gateway/session/session.py b/uni_agent/gateway/session/session.py index ebde4037..b117d77e 100644 --- a/uni_agent/gateway/session/session.py +++ b/uni_agent/gateway/session/session.py @@ -193,6 +193,11 @@ async def run_generation(self, payload: dict[str, Any], backend) -> GenerationOu except Exception as e: self._abandon_pending(encoded) raise HTTPException(status_code=500, detail=f"{e.__class__.__name__}: {e}") from e + except BaseException: + # e.g. asyncio.CancelledError (not an Exception) — still release + # the trie's pending node before propagating. + self._abandon_pending(encoded) + raise response_ids = list(output.token_ids) encoded.buffer.response_ids.extend(response_ids) diff --git a/uni_agent/gateway/session/trie.py b/uni_agent/gateway/session/trie.py index 1f03195a..046672c3 100755 --- a/uni_agent/gateway/session/trie.py +++ b/uni_agent/gateway/session/trie.py @@ -25,6 +25,7 @@ from __future__ import annotations import hashlib +import json import uuid from dataclasses import dataclass, field from typing import Any, Iterator @@ -46,8 +47,6 @@ def canonicalize_tool_arguments(arguments: Any) -> tuple[str, Any]: if isinstance(arguments, (dict, list)): return ("json", _freeze(arguments)) if isinstance(arguments, str): - import json - try: return ("json", _freeze(json.loads(arguments))) except json.JSONDecodeError: