[gateway] Prefix Trie for multi-trajectory storage (M1) #64

Closed
ChangyiYang wants to merge 5 commits into verl-project:main from ChangyiYang:feat/gateway-prefix-trie

ChangyiYang commented Jun 18, 2026

Summary

Implements the per-session prefix trie that replaces the single linear active_trajectory, per RFC #51. Today a session keeps one message_history + one active trajectory, so only the latest branch can be re-attached; this generalizes it into a trie where an incoming request longest-prefix matches any path and continues from the nearest checkpoint — enabling sub-agent / best-of-N / context-condensation / warm-start.

Fully gated behind GatewayActorConfig.gateway_trie_enabled (default off); flag-off reproduces the legacy single-active behavior exactly (covered by a parity test).

What's in this PR (full M1)

Data model — gateway/trie.py (new)

  • MessageKey (+ canonicalize_content / make_message_key): hashable per-message identity; excludes reasoning_content (M1; TODO thinking-model replay) and keys multimodal content via a stable image/video digest (pixels live on the node, not the key).
  • BranchCheckpoint, TrieNode, PrefixTrie with the prepare → commit lifecycle and a finalize traversal (iter_export_nodes → terminal checkpoints only by default).
  • prepare materializes incoming prompt-side messages as pending/structural nodes (no checkpoint, never exported) and clones the nearest checkpoint into a request-local buffer.
  • TrajectoryBuffer re-homed to gateway/types.py so the trie stays free of the verl-importing codec.
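
To make the lookup concrete, here is a minimal sketch of longest-prefix matching with a nearest-checkpoint fallback. It is not the code in gateway/trie.py: `Node`, `make_key`, `insert`, `longest_prefix` and `nearest_checkpoint` are hypothetical names, the key is simplified to `(role, content)` instead of the PR's MessageKey, and the checkpoint is just a list of token ids.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    key: Optional[tuple] = None
    parent: Optional["Node"] = None
    children: dict = field(default_factory=dict)
    checkpoint: Optional[list] = None  # token ids saved at a committed assistant turn


def make_key(message: dict) -> tuple:
    # Simplified stand-in for MessageKey: role plus plain-text content.
    return (message["role"], message["content"])


def insert(root: Node, messages: list, checkpoint: Optional[list] = None) -> Node:
    """Materialize a path for `messages`, optionally attaching a checkpoint at its end."""
    node = root
    for message in messages:
        key = make_key(message)
        child = node.children.get(key)
        if child is None:
            child = Node(key=key, parent=node)
            node.children[key] = child
        node = child
    if checkpoint is not None:
        node.checkpoint = list(checkpoint)
    return node


def longest_prefix(root: Node, messages: list) -> tuple:
    """Walk the trie along `messages`; return the deepest matched node and its depth."""
    node, depth = root, 0
    for message in messages:
        child = node.children.get(make_key(message))
        if child is None:
            break
        node, depth = child, depth + 1
    return node, depth


def nearest_checkpoint(node: Optional[Node]) -> Optional[Node]:
    """Climb from `node` towards the root until a node carrying a checkpoint is found."""
    while node is not None:
        if node.checkpoint is not None:
            return node
        node = node.parent
    return None


system = {"role": "system", "content": "You are helpful."}
user_1 = {"role": "user", "content": "hi"}
assistant_1 = {"role": "assistant", "content": "hello"}
root = Node()
insert(root, [system, user_1, assistant_1], checkpoint=[1, 2, 3])

# A new request that shares the first three messages and then diverges.
request = [system, user_1, assistant_1, {"role": "user", "content": "branch B"}]
matched, depth = longest_prefix(root, request)
resume_from = nearest_checkpoint(matched)
```

In this toy version the request resumes from the checkpoint on `assistant_1`, and only the one message beyond `depth` would need incremental encoding.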

Session integration — gateway/session.py

  • _prepare_generation_inputs_trie: routes through trie.prepare (longest-prefix match + nearest-checkpoint clone) and reuses the existing encode_full / encode_incremental paths — only the state model changes.
  • commit attaches the assistant child via trie.commit; finalize traverses the trie → one branch-aware Trajectory per terminal checkpoint (reward stamped per the current one-session-one-reward contract).
  • Per-branch multimodal stored on the committed node (new-this-turn images) and reconstructed along the path. Failed generations abandon the pending node. snapshot_state reports num_branches / num_inflight_generations in trie mode.
  • Tools-change gate preserved (full re-encode), mirroring the legacy active_tool_schemas gate.

Wiring: config.py / gateway.py thread gateway_trie_enabled into sessions.

Tests (CPU-only)

  • tests/uni_agent/gateway/test_trie_on_cpu.py — 16 trie-structural tests covering the RFC appendix fork types: A sequential extension, B system-keyed split (parallel sub-agents), C context condensation, D best-of-N + idempotent retry, E warm-start; plus pending-node/no-export, clone isolation, terminal-vs-all export, multimodal digest + per-branch reconstruction.
  • tests/uni_agent/gateway/test_session_trie_on_cpu.py — 3 session-integration tests: flag parity (trie-on yields identical backend prompts + identical finalized trajectory as trie-off on a linear multi-turn conversation — the compatibility gate), best-of-N fan-out (N siblings → N trajectories), multi-turn reattach (single continuous branch).

Scope / deferrals

  • Tokenization boundary: reuse current remove_system_prompt (text via tokenizer, multimodal via processor). Per-model seam handling (Qwen trailing \n, GLM 4.7 ambiguous bos/eos) + a token-sequence verifier → M2 (TITO-style merge).
  • Concurrency: generation_lock kept (serial) → M3.
  • GC/eviction: none — the trie dies with the session at finalize/abort.
  • Reward: one-session-one-reward; reward contract unchanged.
  • No new caller API — branching is purely message-driven.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request replaces the gateway session's single linear active trajectory with a prefix trie (PrefixTrie) to natively support longest-prefix reattachment across multiple branches, moving TrajectoryBuffer to types.py to prevent circular imports and adding comprehensive CPU-only unit tests. The review feedback highlights several critical improvements in uni_agent/gateway/trie.py, including returning deep copies of messages in rebuild_messages to prevent internal state mutation, hashing large string payloads in _media_digest to avoid performance issues, removing the unused _PROMPT_SIDE_ROLES constant, ensuring stale multimodal data is cleared during node updates in upsert_assistant, and optimizing the trie traversal in iter_export_nodes from $O(N^2)$ to $O(N)$ complexity.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread uni_agent/gateway/trie.py Outdated
Comment on lines +324 to +326
    def rebuild_messages(self, node: TrieNode) -> list[dict[str, Any]]:
        """Rebuild the message transcript from the root to ``node`` (inclusive)."""
        return [n.message for n in self._path_to_root(node) if n.message is not None]

Contributor

high

rebuild_messages returns references to the original message dictionaries stored in the trie nodes. If the caller mutates these dictionaries (e.g., during normalization or template rendering), it will corrupt the trie's internal state and potentially break prefix matching (since MessageKey is built from the message content).

To prevent this, we should return deep copies of the messages.

Suggested change
-     def rebuild_messages(self, node: TrieNode) -> list[dict[str, Any]]:
-         """Rebuild the message transcript from the root to ``node`` (inclusive)."""
-         return [n.message for n in self._path_to_root(node) if n.message is not None]
+     def rebuild_messages(self, node: TrieNode) -> list[dict[str, Any]]:
+         """Rebuild the message transcript from the root to ``node`` (inclusive)."""
+         import copy
+         return [copy.deepcopy(n.message) for n in self._path_to_root(node) if n.message is not None]

Author

Good catch — fixed in d568b1b. Returning shallow copies (dict(n.message)) so a caller re-keying top-level fields can't mutate the stored node. Used shallow rather than deepcopy on purpose: the messages can carry large inline multimodal payloads and deep-copying them on every rebuild_messages would be costly, while current callers only read checkpoint_messages for length-slicing and num_turns, never mutating nested content. Note the child-dict keys are pre-computed frozen MessageKeys, so prefix matching itself isn't affected by node mutation; the real risk was transcript/checkpoint corruption, which the shallow copy guards against.
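
The trade-off between the two copy strategies can be seen in a few lines. This is an illustrative sketch with a made-up message dict, not code from the PR: a shallow `dict(...)` copy isolates top-level keys but still shares nested content with the stored message, while `copy.deepcopy` isolates everything at the cost of copying the payload.

```python
import copy

# Hypothetical stored message with nested (list-of-parts) content.
stored = {"role": "user", "content": [{"type": "text", "text": "hi"}]}

shallow = dict(stored)
shallow["role"] = "tool"                      # top-level re-key: `stored` is unaffected
shallow["content"][0]["text"] = "mutated"     # nested mutation: visible through `stored`

deep = copy.deepcopy(stored)
deep["content"][0]["text"] = "isolated"       # nested mutation: `stored` is unaffected
```

So the shallow copy is safe only under the assumption stated in the reply, namely that callers never mutate nested content.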

Comment on lines +85 to +86
    if isinstance(value, str):
        return value

Contributor

medium

In _media_digest, returning value verbatim when it is a str can lead to high memory usage and slow comparisons if the string is a very large base64-encoded data URI (which is common for inline multimodal payloads).

We should hash the string if it is a data URI or exceeds a reasonable length (e.g., 256 characters) to keep the MessageKey compact and uniform.

Suggested change
-     if isinstance(value, str):
-         return value
+     if isinstance(value, str):
+         if value.startswith("data:") or len(value) > 256:
+             return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16]
+         return value

Author

Done in d568b1b: _media_digest now hashes data: URIs and any string longer than 256 characters (SHA-256, truncated to 16 hex digits), keeping short URLs/ids verbatim. This keeps the MessageKey compact and comparisons cheap for inline base64 payloads.
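
A standalone sketch of the string branch described above, for readers who want to try it outside the module. The function name `media_digest_str` and the `max_verbatim` parameter are hypothetical; only the rule itself (hash `data:` URIs and strings over 256 characters, keep the rest verbatim) comes from this thread.

```python
import hashlib


def media_digest_str(value: str, max_verbatim: int = 256) -> str:
    """Return a compact, stable identity for a media reference given as a string."""
    if value.startswith("data:") or len(value) > max_verbatim:
        # 16 hex digits = 64 bits of SHA-256; compact, but not collision-proof.
        return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16]
    return value


short_url = "https://example.com/cat.png"
inline = "data:image/png;base64," + "A" * 10_000

short_digest = media_digest_str(short_url)   # kept verbatim
inline_digest = media_digest_str(inline)     # 16-character hex digest
```

Truncating to 64 bits is a size/collision trade-off; within a single session's trie the number of distinct media items is presumably small enough for that to be acceptable.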

Comment thread uni_agent/gateway/trie.py Outdated
Comment on lines +34 to +36
# Roles whose messages are append-only prompt-side turns (never the generated
# assistant). Used to classify which incoming messages become pending nodes.
_PROMPT_SIDE_ROLES = frozenset({"system", "user", "tool"})

Contributor

medium

The constant _PROMPT_SIDE_ROLES is defined but never used anywhere in the module. If it is no longer needed, it should be removed to keep the codebase clean.

Author

Removed in d568b1b — it was a leftover from an earlier draft and is no longer referenced.

Comment thread uni_agent/gateway/trie.py Outdated
Comment on lines +396 to +399
    if checkpoint.image_data:
        child.image_data = list(checkpoint.image_data)
    if checkpoint.video_data:
        child.video_data = list(checkpoint.video_data)

Contributor

medium

In upsert_assistant, when updating an existing node (e.g., during an idempotent retry or refresh), if the new checkpoint has image_data or video_data as None, the existing child.image_data and child.video_data are not cleared. This leaves stale multimodal data on the node, causing an inconsistency between the node's fields and its checkpoint.

We should explicitly set these fields to None (or copy them) if they are not present in the new checkpoint.

Suggested change
-     if checkpoint.image_data:
-         child.image_data = list(checkpoint.image_data)
-     if checkpoint.video_data:
-         child.video_data = list(checkpoint.video_data)
+     child.image_data = list(checkpoint.image_data) if checkpoint.image_data is not None else None
+     child.video_data = list(checkpoint.video_data) if checkpoint.video_data is not None else None

Author

Fixed in d568b1b: upsert_assistant now sets both fields unconditionally (list(...) if ... is not None else None), so a refresh or idempotent retry whose checkpoint has no multimodal data clears any stale payload and keeps the node consistent with its checkpoint.

Comment thread uni_agent/gateway/trie.py Outdated
Comment on lines +447 to +459
    def iter_export_nodes(self, *, export_all: bool = False) -> Iterator[TrieNode]:
        """Yield assistant checkpoint nodes for finalize.

        By default emits only *terminal* assistant checkpoints (no committed
        assistant descendant): rejected best-of-N siblings stay short leaves and
        a continued branch is represented by its deepest checkpoint. With
        ``export_all`` every committed checkpoint is emitted.
        """
        for node in self._iter_nodes(self.root):
            if node.checkpoint is None:
                continue
            if export_all or not self._has_committed_descendant(node):
                yield node

Contributor

medium

The current implementation of iter_export_nodes performs a full trie traversal (_iter_nodes), and for every node, it calls _has_committed_descendant which performs another traversal of all its descendants. For a trie of depth $N$, this results in an $O(N^2)$ time complexity, which can become a performance bottleneck for long conversations or large tries.

We can optimize this to $O(N)$ by performing a single bottom-up (post-order) traversal to collect the terminal checkpoints.

    def iter_export_nodes(self, *, export_all: bool = False) -> Iterator[TrieNode]:
        """Yield assistant checkpoint nodes for finalize.

        By default emits only *terminal* assistant checkpoints (no committed
        assistant descendant): rejected best-of-N siblings stay short leaves and
        a continued branch is represented by its deepest checkpoint. With
        ``export_all`` every committed checkpoint is emitted.
        """
        if export_all:
            for node in self._iter_nodes(self.root):
                if node.checkpoint is not None:
                    yield node
            return

        def _find_terminals(node: TrieNode) -> tuple[bool, list[TrieNode]]:
            has_ckpt = False
            terminals = []
            for child in node.children.values():
                child_has_ckpt, child_terminals = _find_terminals(child)
                has_ckpt = has_ckpt or child_has_ckpt
                terminals.extend(child_terminals)
            if node.checkpoint is not None:
                if not has_ckpt:
                    terminals.append(node)
                has_ckpt = True
            return has_ckpt, terminals

        _, terminal_nodes = _find_terminals(self.root)
        yield from terminal_nodes

Author

Agreed — fixed in d568b1b. Replaced the per-node descendant scan with a single O(N) post-order pass (_collect_terminals) that returns, per subtree, whether it already carries a checkpoint, so an ancestor knows if it's the deepest. export_all keeps a straight checkpoint-node walk. Semantics unchanged; the 19 trie/session tests still pass.
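
The terminal-only export rule is easy to check on a toy tree. The sketch below is not the PR's _collect_terminals: `Node`, `collect_terminals` and the node names are hypothetical, and the checkpoint is reduced to a boolean. It builds one continued branch, one rejected best-of-N sibling, and one pending node without a checkpoint.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    checkpoint: bool = False
    children: list = field(default_factory=list)


def collect_terminals(node: Node, out: list) -> bool:
    """Post-order pass; returns True if the subtree rooted at `node` holds a checkpoint."""
    below = False
    for child in node.children:
        # Call first, then `or`, so every child is visited even once `below` is True.
        below = collect_terminals(child, out) or below
    if node.checkpoint:
        if not below:
            out.append(node.name)  # deepest checkpoint on its branch
        return True
    return below


a2 = Node("a2", checkpoint=True)                      # continuation of a1
u2 = Node("u2", children=[a2])
u2_pending = Node("u2_pending")                       # prompt-side node, never committed
a1 = Node("a1", checkpoint=True, children=[u2, u2_pending])
a1_alt = Node("a1_alt", checkpoint=True)              # rejected best-of-N sibling
u1 = Node("u1", children=[a1, a1_alt])
root = Node("root", children=[u1])

terminals: list = []
collect_terminals(root, terminals)
```

Each node is visited exactly once, so the pass is linear in the number of nodes. `a1` is skipped because `a2` sits below it, while the sibling `a1_alt` is exported as its own short branch.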

ChangyiYang changed the title from "[WIP][gateway] Prefix Trie for multi-trajectory storage (M1)" to "[gateway] Prefix Trie for multi-trajectory storage (M1)" on Jun 18, 2026
@ChangyiYang

Author

@codex review

ChangyiYang force-pushed the feat/gateway-prefix-trie branch 2 times, most recently from 5ff54a1 to fd9db15, on June 18, 2026 07:34
@ChangyiYang

Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a prefix trie implementation (uni_agent/gateway/trie.py) to support multi-trajectory session storage, enabling features like multi-branch reattachment, best-of-N, and sub-agents. It updates the gateway session, configuration, and types to integrate this trie-backed state model while maintaining backward compatibility with the legacy linear trajectory flow. Feedback on the changes highlights two key improvements: wrapping the trie preparation logic in a try...except BaseException block to prevent pending node leaks during exceptions or cancellations, and using _freeze(value) in _media_digest to ensure stable, order-canonical hashing of arbitrary fallback structures.

Comment thread uni_agent/gateway/session/session.py Outdated
Comment on lines +352 to +438
        prepared = self._trie.prepare(messages)
        # Reuse the cloned checkpoint only when its tools match this request;
        # a tools change forces a full re-encode (mirrors the legacy
        # ``active_tool_schemas != tools`` gate).
        use_incremental = prepared.trajectory_buffer is not None and prepared.request_tools == tools

        # Multimodal introduced by *this* turn (messages beyond the checkpoint),
        # stored on the committed node for per-branch reconstruction.
        delta_messages = messages[len(prepared.checkpoint_messages) :]
        new_image_data, new_video_data = (None, None)
        if delta_messages:
            new_image_data, new_video_data = await self._codec.extract_multi_modal_data(delta_messages)

        if not use_incremental:
            image_data, video_data = await self._codec.extract_multi_modal_data(messages)
            prompt_ids = self._codec.encode_full(
                messages,
                tools=tools,
                image_data=image_data,
                video_data=video_data,
                request_chat_template_kwargs=request_chat_template_kwargs,
            )
            buffer = TrajectoryBuffer(prompt_ids=prompt_ids)
            # Whole prompt is freshly encoded; the node owns all of its images.
            new_image_data, new_video_data = image_data, video_data
        else:
            buffer = prepared.trajectory_buffer
            image_data = list(prepared.image_data) if prepared.image_data is not None else None
            video_data = list(prepared.video_data) if prepared.video_data is not None else None
            if delta_messages:
                if new_image_data:
                    image_data = (image_data or []) + new_image_data
                if new_video_data:
                    video_data = (video_data or []) + new_video_data
                incremental_ids = self._codec.encode_incremental(
                    delta_messages,
                    image_data=new_image_data,
                    video_data=new_video_data,
                    request_chat_template_kwargs=request_chat_template_kwargs,
                )
                if (
                    self._response_length is not None
                    and len(buffer.response_mask) + len(incremental_ids) >= self._response_length
                ):
                    context_ids = buffer.prompt_ids + buffer.response_ids
                    return EncodedData(
                        buffer=buffer,
                        context_ids=context_ids,
                        sampling_params={},
                        messages=list(messages),
                        tools=tools,
                        image_data=image_data,
                        video_data=video_data,
                        materialized_trajectory=None,
                        length_exhausted_trajectory=self._trajectory_from_buffer(
                            buffer, messages, extra_fields={"finish_reason": "length"}
                        ),
                        branch_handle=prepared.branch_handle,
                        new_image_data=new_image_data,
                        new_video_data=new_video_data,
                    )
                buffer.response_ids.extend(incremental_ids)
                buffer.response_mask.extend([0] * len(incremental_ids))
                if buffer.response_logprobs:
                    buffer.response_logprobs.extend([0.0] * len(incremental_ids))

        context_ids = buffer.prompt_ids + buffer.response_ids
        sampling_params = self._codec.build_sampling_params(payload)
        remaining_response_budget = (
            self._response_length - len(buffer.response_mask) if self._response_length is not None else None
        )
        if remaining_response_budget is not None and "max_tokens" in sampling_params:
            sampling_params["max_tokens"] = min(sampling_params["max_tokens"], remaining_response_budget)
        return EncodedData(
            buffer=buffer,
            context_ids=context_ids,
            sampling_params=sampling_params,
            messages=list(messages),
            tools=tools,
            image_data=image_data,
            video_data=video_data,
            materialized_trajectory=None,
            length_exhausted_trajectory=None,
            branch_handle=prepared.branch_handle,
            new_image_data=new_image_data,
            new_video_data=new_video_data,
        )

Contributor

high

If any exception or cancellation occurs during incremental encoding or multimodal extraction inside _prepare_generation_inputs_trie after self._trie.prepare has been called, the pending node will be leaked. Wrapping the preparation logic in a try...except BaseException: block ensures that the pending node is properly abandoned on any failure or cancellation.

        prepared = self._trie.prepare(messages)
        try:
            # Reuse the cloned checkpoint only when its tools match this request;
            # a tools change forces a full re-encode (mirrors the legacy
            # ``active_tool_schemas != tools`` gate).
            use_incremental = prepared.trajectory_buffer is not None and prepared.request_tools == tools

            # Multimodal introduced by *this* turn (messages beyond the checkpoint),
            # stored on the committed node for per-branch reconstruction.
            delta_messages = messages[len(prepared.checkpoint_messages) :]
            new_image_data, new_video_data = (None, None)
            if delta_messages:
                new_image_data, new_video_data = await self._codec.extract_multi_modal_data(delta_messages)

            if not use_incremental:
                image_data, video_data = await self._codec.extract_multi_modal_data(messages)
                prompt_ids = self._codec.encode_full(
                    messages,
                    tools=tools,
                    image_data=image_data,
                    video_data=video_data,
                    request_chat_template_kwargs=request_chat_template_kwargs,
                )
                buffer = TrajectoryBuffer(prompt_ids=prompt_ids)
                # Whole prompt is freshly encoded; the node owns all of its images.
                new_image_data, new_video_data = image_data, video_data
            else:
                buffer = prepared.trajectory_buffer
                image_data = list(prepared.image_data) if prepared.image_data is not None else None
                video_data = list(prepared.video_data) if prepared.video_data is not None else None
                if delta_messages:
                    if new_image_data:
                        image_data = (image_data or []) + new_image_data
                    if new_video_data:
                        video_data = (video_data or []) + new_video_data
                    incremental_ids = self._codec.encode_incremental(
                        delta_messages,
                        image_data=new_image_data,
                        video_data=new_video_data,
                        request_chat_template_kwargs=request_chat_template_kwargs,
                    )
                    if (
                        self._response_length is not None
                        and len(buffer.response_mask) + len(incremental_ids) >= self._response_length
                    ):
                        context_ids = buffer.prompt_ids + buffer.response_ids
                        return EncodedData(
                            buffer=buffer,
                            context_ids=context_ids,
                            sampling_params={},
                            messages=list(messages),
                            tools=tools,
                            image_data=image_data,
                            video_data=video_data,
                            materialized_trajectory=None,
                            length_exhausted_trajectory=self._trajectory_from_buffer(
                                buffer, messages, extra_fields={"finish_reason": "length"}
                            ),
                            branch_handle=prepared.branch_handle,
                            new_image_data=new_image_data,
                            new_video_data=new_video_data,
                        )
                    buffer.response_ids.extend(incremental_ids)
                    buffer.response_mask.extend([0] * len(incremental_ids))
                    if buffer.response_logprobs:
                        buffer.response_logprobs.extend([0.0] * len(incremental_ids))

            context_ids = buffer.prompt_ids + buffer.response_ids
            sampling_params = self._codec.build_sampling_params(payload)
            remaining_response_budget = (
                self._response_length - len(buffer.response_mask) if self._response_length is not None else None
            )
            if remaining_response_budget is not None and "max_tokens" in sampling_params:
                sampling_params["max_tokens"] = min(sampling_params["max_tokens"], remaining_response_budget)
            return EncodedData(
                buffer=buffer,
                context_ids=context_ids,
                sampling_params=sampling_params,
                messages=list(messages),
                tools=tools,
                image_data=image_data,
                video_data=video_data,
                materialized_trajectory=None,
                length_exhausted_trajectory=None,
                branch_handle=prepared.branch_handle,
                new_image_data=new_image_data,
                new_video_data=new_video_data,
            )
        except BaseException:
            self._trie.abandon(prepared.branch_handle)
            raise

Author

Good catch — fixed in 54c62f4. _prepare_generation_inputs_trie now wraps the post-prepare encoding in try/except BaseException and calls trie.abandon(prepared.branch_handle) before re-raising, so a failure or cancellation during extract_multi_modal_data/encode_incremental no longer leaks the pending node. Added test_trie_abandons_pending_node_when_encode_fails (forces encode_full to raise, asserts num_inflight()==0); verified it fails without the guard.
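
Why `BaseException` rather than `Exception`: `asyncio.CancelledError` derives from `BaseException` (since Python 3.8), so a narrower handler would skip the cleanup when the request task is cancelled. The sketch below reproduces the guard on a stand-in tracker. `PendingTracker`, `prepare_inputs` and the two encode coroutines are hypothetical and only mimic the prepare/abandon bookkeeping.

```python
import asyncio


class PendingTracker:
    """Stand-in for the trie's in-flight bookkeeping."""

    def __init__(self) -> None:
        self.inflight: set = set()
        self._next_handle = 0

    def prepare(self) -> int:
        handle = self._next_handle
        self._next_handle += 1
        self.inflight.add(handle)
        return handle

    def abandon(self, handle: int) -> None:
        self.inflight.discard(handle)


async def prepare_inputs(tracker: PendingTracker, encode):
    handle = tracker.prepare()
    try:
        return handle, await encode()
    except BaseException:
        # Covers ordinary errors and task cancellation alike.
        tracker.abandon(handle)
        raise


async def failing_encode():
    raise ValueError("encode failed")


async def slow_encode():
    await asyncio.sleep(10)


async def demo() -> tuple:
    tracker = PendingTracker()
    try:
        await prepare_inputs(tracker, failing_encode)
    except ValueError:
        pass
    after_failure = len(tracker.inflight)

    task = asyncio.create_task(prepare_inputs(tracker, slow_encode))
    await asyncio.sleep(0)  # let the task run up to its first suspension point
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    return after_failure, len(tracker.inflight)


leaked_after_failure, leaked_after_cancel = asyncio.run(demo())
```

Both counts come out as zero here. With `except Exception` the cancellation case would leave one handle in `inflight`.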

Comment thread uni_agent/gateway/session/trie.py Outdated
        return value
    if isinstance(value, (bytes, bytearray)):
        return hashlib.sha256(bytes(value)).hexdigest()[:16]
    return hashlib.sha256(repr(value).encode("utf-8")).hexdigest()[:16]

Contributor

medium

Hashing repr(value) directly on arbitrary fallback structures (such as lists of dictionaries) is not order-canonical: a dictionary's repr() follows insertion order, so two equal dictionaries whose keys were inserted in a different order produce different digests and therefore different MessageKeys. Wrapping value in _freeze(value) recursively normalizes nested dictionaries and lists into sorted, hashable tuples, ensuring a stable and order-canonical representation.

Suggested change
-     return hashlib.sha256(repr(value).encode("utf-8")).hexdigest()[:16]
+     return hashlib.sha256(repr(_freeze(value)).encode("utf-8")).hexdigest()[:16]

Author

Done in 54c62f4 — the scalar fallback now hashes repr(_freeze(value)) so nested dict/list ordering is canonicalized before hashing.
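
The effect of canonicalizing before hashing can be shown with two equal dicts built in different key order. The `freeze` and `digest` functions below are hypothetical stand-ins written for this illustration. The module's actual `_freeze` is not shown in this thread and may differ, for example in how it treats non-string keys.

```python
import hashlib


def freeze(value):
    """Recursively turn dicts into key-sorted tuples and lists into tuples."""
    if isinstance(value, dict):
        # Assumes keys are strings (or at least have distinct str() forms).
        return tuple(sorted((str(k), freeze(v)) for k, v in value.items()))
    if isinstance(value, (list, tuple)):
        return tuple(freeze(v) for v in value)
    return value


def digest(value) -> str:
    return hashlib.sha256(repr(freeze(value)).encode("utf-8")).hexdigest()[:16]


a = {"url": "x", "detail": "low"}
b = {"detail": "low", "url": "x"}

same_value = a == b                      # equal as dicts
same_repr = repr(a) == repr(b)           # repr follows insertion order
same_digest = digest(a) == digest(b)     # canonical after freezing
```

The two dicts compare equal but print differently, which is exactly the case where hashing the raw `repr` would split one logical message into two trie keys.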

ChangyiYang and others added 2 commits June 18, 2026 07:50
Add a per-session prefix trie to replace the single linear active trajectory
(issue verl-project#51): every committed assistant turn becomes a node that may carry a
checkpoint, and an incoming request longest-prefix matches any path and
continues from the nearest checkpoint.

- session/trie.py: MessageKey (+canonicalize_content/make_message_key),
  BranchCheckpoint, TrieNode, PrefixTrie (prepare/commit/finalize-traversal).
  MessageKey excludes reasoning_content (M1) and hashes multimodal content via
  a stable image/video digest (data:/long strings are hashed for compactness).
- session/types.py: home TrajectoryBuffer here so trie stays free of the
  verl-importing codec; re-exported via session/__init__.
- CPU tests cover the RFC appendix fork types (sequential / system-split /
  condensation / best-of-N + idempotent retry / warm-start) plus pending-node,
  clone-isolation, export, multimodal-digest, and regression guards that fail
  on the pre-review behavior.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Wire the trie into GatewaySession end-to-end, gated by
GatewayActorConfig.gateway_trie_enabled (default False; flag-off reproduces the
legacy single-active behavior exactly, covered by a parity test).

- session.py: trie-mode prepare routes through trie.prepare (longest-prefix
  match + nearest-checkpoint clone) and reuses the existing
  encode_full/encode_incremental paths; commit attaches the assistant child;
  finalize traverses the trie -> one branch-aware trajectory per terminal
  checkpoint. Per-branch multimodal stored on the node; failed generations
  abandon the pending node; snapshot_state reports num_branches/num_inflight.
  tools-change gate mirrors the legacy active_tool_schemas gate.
- config.py/gateway.py: thread gateway_trie_enabled into sessions.
- session-integration tests: flag parity, best-of-N fan-out, multi-turn reattach.

Tokenization stays on remove_system_prompt (TITO -> M2); generation_lock kept
(concurrency -> M3); no GC (trie dies with the session).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

ChangyiYang force-pushed the feat/gateway-prefix-trie branch from fd9db15 to 7d40dc9 on June 18, 2026 07:50

- _prepare_generation_inputs_trie: abandon the pending node if encoding fails
  or is cancelled after trie.prepare (try/except BaseException around the
  encode step), preventing in-flight bookkeeping leaks. Added a session test
  that fails without the guard.
- _media_digest: hash repr(_freeze(value)) in the scalar fallback so nested
  dict/list ordering is canonical across processes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@ChangyiYang

Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a prefix trie (PrefixTrie) for multi-trajectory session storage, enabling advanced capabilities like multi-branch reattachment, best-of-N trajectories, and sub-agent workflows under a new gateway_trie_enabled configuration flag. The feedback focuses on enhancing robustness by defensively copying message dictionaries when storing them in trie nodes and checkpoints to prevent external mutations from corrupting the trie's internal state. Additionally, a high-severity bug was identified where a full encode path incorrectly overwrites the turn-level multimodal delta lists with full historical lists, which would cause duplicate entries during path reconstruction.

Comment thread uni_agent/gateway/session/session.py Outdated
Comment on lines +386 to +387
# Whole prompt is freshly encoded; the node owns all of its images.
new_image_data, new_video_data = image_data, video_data

Contributor

high

When use_incremental is False (e.g., due to a tools change), a full encode is performed. In this case, image_data and video_data contain the multimodal payloads for the entire conversation history. Overwriting new_image_data and new_video_data with these full lists causes the committed node to store all historical images/videos. When collect_multi_modal subsequently traverses the path to reconstruct the multimodal sequence, it will collect these images/videos from both the current node and its ancestors, leading to duplicate entries.

To prevent duplication, the node should only store the delta images/videos introduced in the current turn (which are already extracted into new_image_data and new_video_data from delta_messages). Removing the overwrite ensures that only the delta is stored on the node, while the backend still receives the full lists for the full encode.

Suggested change
-     # Whole prompt is freshly encoded; the node owns all of its images.
-     new_image_data, new_video_data = image_data, video_data
+     # Whole prompt is freshly encoded; the node only stores the delta images
+     # to avoid duplicates during path reconstruction.

Author

Good catch (real bug) — fixed in f4294f1. On a full re-encode mid-branch the node now stores only this turn's delta media (new_image_data/new_video_data already extracted from delta_messages); the backend still gets the full lists. Removed the delta→full overwrite. Added test_trie_no_duplicate_multimodal_on_full_encode_midbranch (image turn 1, then a tools change + new image on turn 2 forcing full re-encode) and verified it fails without the fix (yields the duplicated ancestor image).
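
The duplication is easiest to see with a two-turn path. This is a toy model, not the session code: `Node` and `collect_images` are hypothetical and only mimic the idea of concatenating per-node media from the root down to a node.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    parent: Optional["Node"] = None
    image_data: list = field(default_factory=list)  # media introduced at this node


def collect_images(node: Optional[Node]) -> list:
    """Concatenate per-node images along the path from the root to `node`."""
    path = []
    while node is not None:
        path.append(node)
        node = node.parent
    images: list = []
    for n in reversed(path):
        images.extend(n.image_data)
    return images


turn_1 = Node(image_data=["img1"])

# Pre-fix behavior: a full re-encode stored the whole history on the new node.
buggy_turn_2 = Node(parent=turn_1, image_data=["img1", "img2"])
# Post-fix behavior: the node stores only this turn's delta.
fixed_turn_2 = Node(parent=turn_1, image_data=["img2"])

buggy_images = collect_images(buggy_turn_2)
fixed_images = collect_images(fixed_turn_2)
```

With the overwrite in place the ancestor image appears twice in the reconstructed sequence. Storing only the delta restores the one-entry-per-image invariant.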

Comment thread uni_agent/gateway/session/trie.py Outdated
key = make_message_key(message)
child = attach.children.get(key)
if child is None:
    child = TrieNode(key=key, message=message, parent=attach)

Contributor

medium

To prevent external mutations of the request payload from accidentally corrupting the trie's internal state, it is highly recommended to perform a defensive copy of the message dictionary when storing it in the TrieNode.

Suggested change
- child = TrieNode(key=key, message=message, parent=attach)
+ child = TrieNode(key=key, message=dict(message), parent=attach)

Author

Done in f4294f1: dict(message) shallow copy when creating the node.
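
As a side note on what a shallow copy does and does not isolate, here is a minimal sketch using only the standard library. The message shape is an assumed example, and `copy.deepcopy` is shown only as a stricter alternative; it is not something the PR adopts.

```python
import copy

message = {"role": "user", "content": [{"type": "text", "text": "hi"}]}

stored_shallow = dict(message)         # what the suggested change does
stored_deep = copy.deepcopy(message)   # stricter alternative, not part of the PR

# Top-level mutation of the request payload does not reach the stored copy.
message["role"] = "system"

# Nested mutation still reaches the shallow copy because the list is shared.
message["content"][0]["text"] = "changed"

shallow_role = stored_shallow["role"]
shallow_text = stored_shallow["content"][0]["text"]
deep_text = stored_deep["content"][0]["text"]
```

So `dict(message)` protects against reassigned top-level keys, but list-valued content (for example multimodal parts) remains shared with the caller unless it is copied as well.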

Comment thread uni_agent/gateway/session/trie.py Outdated
Comment on lines +401 to +404
    child = TrieNode(key=key, message=assistant_msg, parent=parent)
    parent.children[key] = child
else:
    child.message = assistant_msg

Contributor

medium

Similarly, perform a defensive copy of assistant_msg when creating or updating an assistant node to isolate the trie's internal state from any subsequent modifications to the assistant message dictionary.

Suggested change
-     child = TrieNode(key=key, message=assistant_msg, parent=parent)
-     parent.children[key] = child
- else:
-     child.message = assistant_msg
+ if child is None:
+     child = TrieNode(key=key, message=dict(assistant_msg), parent=parent)
+     parent.children[key] = child
+ else:
+     child.message = dict(assistant_msg)

Author

Done in f4294f1: dict(assistant_msg) on both create and refresh.

Comment thread uni_agent/gateway/session/trie.py Outdated
# request prompt PLUS the generated assistant turn, so the stored prefix
# must include ``assistant_msg`` (this is what the next turn's
# ``checkpoint_messages`` slices against).
covered_messages = list(messages) + [assistant_msg] if messages is not None else None

Contributor

medium

Defensively copy the messages in covered_messages to ensure that the checkpoint's stored message history is completely isolated from any external mutations.

Suggested change
- covered_messages = list(messages) + [assistant_msg] if messages is not None else None
+ covered_messages = [dict(m) for m in messages] + [dict(assistant_msg)] if messages is not None else None

Author

Done in f4294f1: covered_messages now stores [dict(m) for m in messages] + [dict(assistant_msg)].

- session.py: on a full re-encode mid-branch (e.g. tools change), store only
  this turn's delta media on the committed node instead of the whole prompt's
  media, so collect_multi_modal does not double-count media carried on ancestor
  checkpoints. Added a regression test (multimodal + tools-change mid-branch)
  that fails without the fix.
- trie.py: defensively shallow-copy message dicts when storing them in nodes
  (materialize_prompt_suffix, upsert_assistant) and in the checkpoint's
  covered_messages, isolating trie state from external payload mutation.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@ChangyiYang

Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a prefix-trie-backed gateway session storage (PrefixTrie) to support multi-trajectory sessions, including multi-turn conversations, best-of-N trajectories, and sub-agents, controlled by a new gateway_trie_enabled configuration flag. It also adds comprehensive unit and integration tests to verify the trie's structural behavior and parity with the legacy linear trajectory flow. The review feedback highlights a critical issue where request cancellations (raising BaseException) could leak pending trie nodes, and suggests moving an inline json import to the top of the file to adhere to PEP 8 guidelines.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines 190 to 195
except ValueError as e:
    self._abandon_pending(encoded)
    raise HTTPException(status_code=400, detail=str(e)) from e
except Exception as e:
    self._abandon_pending(encoded)
    raise HTTPException(status_code=500, detail=f"{e.__class__.__name__}: {e}") from e

Contributor

high

If the generation request is cancelled during await backend.generate(...), an asyncio.CancelledError (which inherits from BaseException, not Exception) will be raised. Since the current except blocks only catch ValueError and Exception, a cancellation will bypass them, leaking the pending/in-flight node in the trie. Adding an except BaseException: block ensures that _abandon_pending is always called to clean up the trie state upon cancellation.

Suggested change
- except ValueError as e:
-     self._abandon_pending(encoded)
-     raise HTTPException(status_code=400, detail=str(e)) from e
- except Exception as e:
-     self._abandon_pending(encoded)
-     raise HTTPException(status_code=500, detail=f"{e.__class__.__name__}: {e}") from e
+ except ValueError as e:
+     self._abandon_pending(encoded)
+     raise HTTPException(status_code=400, detail=str(e)) from e
+ except Exception as e:
+     self._abandon_pending(encoded)
+     raise HTTPException(status_code=500, detail=f"{e.__class__.__name__}: {e}") from e
+ except BaseException:
+     self._abandon_pending(encoded)
+     raise

Author

Good catch — fixed in 7f9f42e. Added an except BaseException: self._abandon_pending(encoded); raise after the ValueError/Exception handlers around backend.generate, so a cancellation (CancelledError) also releases the pending node. Added test_trie_abandons_pending_node_on_cancellation (backend raises CancelledError, asserts num_inflight()==0); verified it fails without the guard.
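
The cancellation path described in this thread can be checked in isolation. The following is a minimal sketch, not the session code: `generate_with_cleanup`, `slow_backend`, and the `abandoned` list are hypothetical stand-ins for the handler chain and for `_abandon_pending`. It relies on `asyncio.CancelledError` deriving from `BaseException` (Python 3.8 and later).

```python
import asyncio

abandoned = []  # records which cleanup branch ran (stand-in for _abandon_pending)


async def generate_with_cleanup(backend_call, handle):
    """Hypothetical wrapper mirroring the handler order discussed above."""
    try:
        return await backend_call()
    except Exception as e:
        abandoned.append((handle, "exception"))
        raise RuntimeError(str(e)) from e
    except BaseException:
        # asyncio.CancelledError lands here, not in the Exception handler.
        abandoned.append((handle, "cancelled"))
        raise


async def slow_backend():
    await asyncio.sleep(10)


async def main():
    task = asyncio.create_task(generate_with_cleanup(slow_backend, "req-1"))
    await asyncio.sleep(0)  # let the task start and block inside the backend
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    return task.cancelled()


was_cancelled = asyncio.run(main())
```

Re-raising in the `BaseException` branch matters: swallowing the cancellation would leave the task looking as if it had completed normally.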

Comment on lines +25 to +30
from __future__ import annotations

import hashlib
import uuid
from dataclasses import dataclass, field
from typing import Any, Iterator

Contributor

medium

According to PEP 8, imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants. Let's import json here at the top level.

Suggested change
- from __future__ import annotations
- import hashlib
- import uuid
- from dataclasses import dataclass, field
- from typing import Any, Iterator
+ from __future__ import annotations
+ import hashlib
+ import json
+ import uuid
+ from dataclasses import dataclass, field
+ from typing import Any, Iterator
References
  1. PEP 8: Imports should always be put at the top of the file, just after any module comments and docstrings, and before module globals and constants.

Author

Done in 7f9f42e: import json moved to the module top.

Comment on lines +48 to +54
if isinstance(arguments, str):
    import json

    try:
        return ("json", _freeze(json.loads(arguments)))
    except json.JSONDecodeError:
        return ("raw", arguments)

Contributor

medium

Remove the inline import json statement now that it has been moved to the top of the file, adhering to PEP 8 guidelines.

Suggested change
- if isinstance(arguments, str):
-     import json
-     try:
-         return ("json", _freeze(json.loads(arguments)))
-     except json.JSONDecodeError:
-         return ("raw", arguments)
+ if isinstance(arguments, str):
+     try:
+         return ("json", _freeze(json.loads(arguments)))
+     except json.JSONDecodeError:
+         return ("raw", arguments)
References
  1. PEP 8: Imports should always be put at the top of the file, just after any module comments and docstrings, and before module globals and constants.

Author

Done in 7f9f42e — inline import removed, uses the top-level json.
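
For context on why the arguments string is parsed before keying, here is a minimal sketch with the import at module top. The real `_freeze` is not shown in this thread, so the version below is a hypothetical reconstruction, as are the `canonicalize_arguments` name and the `"obj"` branch for non-string input.

```python
import json
from typing import Any


def _freeze(value: Any) -> Any:
    """Hypothetical helper: turn JSON-like data into a hashable structure."""
    if isinstance(value, dict):
        return tuple(sorted((k, _freeze(v)) for k, v in value.items()))
    if isinstance(value, list):
        return tuple(_freeze(v) for v in value)
    return value


def canonicalize_arguments(arguments: Any) -> Any:
    """Key tool-call arguments by parsed content when they are valid JSON."""
    if isinstance(arguments, str):
        try:
            return ("json", _freeze(json.loads(arguments)))
        except json.JSONDecodeError:
            return ("raw", arguments)
    # Assumed fallback for already-parsed arguments; not taken from the PR.
    return ("obj", _freeze(arguments))


a = canonicalize_arguments('{"a": 1, "b": [1, 2]}')
b = canonicalize_arguments('{"b":[1,2],"a":1}')
raw = canonicalize_arguments("not json")
```

With this shape, two argument strings that differ only in whitespace or key order produce the same hashable key, while unparseable strings fall back to their raw text.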

- session.py: catch BaseException around backend.generate so a cancellation
  (asyncio.CancelledError, which is not an Exception) still abandons the trie
  pending node before propagating. Added a regression test that fails without
  the guard.
- trie.py: move 'import json' to module top (PEP 8) instead of inline.

Co-Authored-By: Changyi Yang <changyiyang2023@gmail.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@ChangyiYang

Author

@gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a prefix trie (PrefixTrie) for multi-trajectory session storage in the gateway, enabling multi-branch reattachment, best-of-N, sub-agent, and warm-start capabilities behind a gateway_trie_enabled flag. Feedback focuses on resolving a discrepancy in the num_turns calculation for length-exhausted trajectories between the legacy and trie flows. To achieve exact parity, the reviewer suggests tracking checkpoint_messages in EncodedData and passing it to the commit method during length exhaustion.

Comment on lines +67 to +69
branch_handle: BranchHandle | None = None
new_image_data: list[Any] | None = None
new_video_data: list[Any] | None = None

Contributor

high

To support exact parity in num_turns for length-exhausted trajectories, we need to track the checkpoint messages (the history before the current turn) in EncodedData so they can be passed to commit when length-exhausted.

Suggested change
- branch_handle: BranchHandle | None = None
- new_image_data: list[Any] | None = None
- new_video_data: list[Any] | None = None
+ branch_handle: BranchHandle | None = None
+ new_image_data: list[Any] | None = None
+ new_video_data: list[Any] | None = None
+ checkpoint_messages: list[dict[str, Any]] | None = None

Comment on lines +156 to +166
if self._trie_enabled:
    self._trie.commit(
        encoded.branch_handle,
        encoded.buffer,
        empty_msg,
        request_tools=encoded.tools,
        messages=encoded.messages,
        image_data=encoded.new_image_data,
        video_data=encoded.new_video_data,
        extra_fields={"finish_reason": "length"},
    )

Contributor

high

There is a discrepancy in num_turns calculation for length-exhausted trajectories between the legacy and trie flows. In the legacy flow, _build_materialized_trajectory is called before self.message_history is updated, so num_turns is based on the history before the current turn. In the trie flow, the trajectory is materialized during finalize using checkpoint.messages which includes the current turn's messages plus the empty assistant message, leading to a discrepancy of 2 turns. To achieve exact parity, we should pass encoded.checkpoint_messages (the history before the current turn) to commit when length-exhausted.

Suggested change
- if self._trie_enabled:
-     self._trie.commit(
-         encoded.branch_handle,
-         encoded.buffer,
-         empty_msg,
-         request_tools=encoded.tools,
-         messages=encoded.messages,
-         image_data=encoded.new_image_data,
-         video_data=encoded.new_video_data,
-         extra_fields={"finish_reason": "length"},
-     )
+ if self._trie_enabled:
+     self._trie.commit(
+         encoded.branch_handle,
+         encoded.buffer,
+         empty_msg,
+         request_tools=encoded.tools,
+         messages=encoded.checkpoint_messages,
+         image_data=encoded.new_image_data,
+         video_data=encoded.new_video_data,
+         extra_fields={"finish_reason": "length"},
+     )
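
The size of the gap described in this comment can be illustrated with simple counting. This is only a sketch under a loud assumption: the thread does not show how `num_turns` is computed, so the `num_turns` function below (one turn per stored message) is hypothetical, as are the example messages.

```python
def num_turns(messages: list) -> int:
    """Hypothetical metric: assumes one turn per stored message."""
    return len(messages)


history_before = [
    {"role": "user", "content": "q1"},
    {"role": "assistant", "content": "a1"},
]
incoming = {"role": "user", "content": "q2"}
empty_assistant = {"role": "assistant", "content": ""}

# Legacy flow: the trajectory is built before the history is updated.
legacy = num_turns(history_before)

# Trie flow before the suggested change: the full request prompt plus the
# empty assistant message is what the checkpoint ends up covering.
request_messages = history_before + [incoming]
trie = num_turns(request_messages + [empty_assistant])

gap = trie - legacy
```

With one new prompt message in the current turn, the two flows differ by the new message plus the empty assistant message, which matches the discrepancy of 2 reported above.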

Comment on lines +424 to +430
length_exhausted_trajectory=self._trajectory_from_buffer(
buffer, messages, extra_fields={"finish_reason": "length"}
),
branch_handle=prepared.branch_handle,
new_image_data=new_image_data,
new_video_data=new_video_data,
)

Contributor

medium

Populate the checkpoint_messages field in EncodedData during early length exhaustion in trie mode.

                        length_exhausted_trajectory=self._trajectory_from_buffer(
                            buffer, messages, extra_fields={"finish_reason": "length"}
                        ),
                        branch_handle=prepared.branch_handle,
                        new_image_data=new_image_data,
                        new_video_data=new_video_data,
                        checkpoint_messages=prepared.checkpoint_messages,

Comment on lines +453 to +456
branch_handle=prepared.branch_handle,
new_image_data=new_image_data,
new_video_data=new_video_data,
)

Contributor

medium

Populate the checkpoint_messages field in EncodedData during normal generation prep in trie mode.

            branch_handle=prepared.branch_handle,
            new_image_data=new_image_data,
            new_video_data=new_video_data,
            checkpoint_messages=prepared.checkpoint_messages,

@gxlvera

gxlvera commented Jun 18, 2026

nice work bro

@gxlvera

gxlvera commented Jun 18, 2026

With the prefix trie, I think we may be able to drop the generation lock and keep only the request lock. The reasoning is laid out in detail in #51:

"One shared state requires a generation lock and serial LLM calls. With a trie, each call owns a cloned branch state; tokenize and commit can interleave—supporting sub-agents, best-of-n, etc.

@gxlvera

gxlvera commented Jun 18, 2026

A complete PR is here: #65. This one will be closed soon.
