[gateway] Prefix Trie for multi-trajectory storage (M1)#64
Conversation
There was a problem hiding this comment.
Code Review
This pull request replaces the gateway session's single linear active trajectory with a prefix trie (PrefixTrie) to natively support longest-prefix reattachment across multiple branches, moving TrajectoryBuffer to types.py to prevent circular imports and adding comprehensive CPU-only unit tests. The review feedback highlights several critical improvements in uni_agent/gateway/trie.py, including returning deep copies of messages in rebuild_messages to prevent internal state mutation, hashing large string payloads in _media_digest to avoid performance issues, removing the unused _PROMPT_SIDE_ROLES constant, ensuring stale multimodal data is cleared during node updates in upsert_assistant, and optimizing the trie traversal in iter_export_nodes from
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def rebuild_messages(self, node: TrieNode) -> list[dict[str, Any]]: | ||
| """Rebuild the message transcript from the root to ``node`` (inclusive).""" | ||
| return [n.message for n in self._path_to_root(node) if n.message is not None] |
There was a problem hiding this comment.
rebuild_messages returns references to the original message dictionaries stored in the trie nodes. If the caller mutates these dictionaries (e.g., during normalization or template rendering), it will corrupt the trie's internal state and potentially break prefix matching (since MessageKey is built from the message content).
To prevent this, we should return deep copies of the messages.
| def rebuild_messages(self, node: TrieNode) -> list[dict[str, Any]]: | |
| """Rebuild the message transcript from the root to ``node`` (inclusive).""" | |
| return [n.message for n in self._path_to_root(node) if n.message is not None] | |
| def rebuild_messages(self, node: TrieNode) -> list[dict[str, Any]]: | |
| """Rebuild the message transcript from the root to ``node`` (inclusive).""" | |
| import copy | |
| return [copy.deepcopy(n.message) for n in self._path_to_root(node) if n.message is not None] |
There was a problem hiding this comment.
Good catch — fixed in d568b1b. Returning shallow copies (dict(n.message)) so a caller re-keying top-level fields can't mutate the stored node. Used shallow rather than deepcopy on purpose: the messages can carry large inline multimodal payloads and deep-copying them on every rebuild_messages would be costly, while current callers only read checkpoint_messages for length-slicing and num_turns, never mutating nested content. Note the child-dict keys are pre-computed frozen MessageKeys, so prefix matching itself isn't affected by node mutation; the real risk was transcript/checkpoint corruption, which the shallow copy guards against.
| if isinstance(value, str): | ||
| return value |
There was a problem hiding this comment.
In _media_digest, returning value verbatim when it is a str can lead to high memory usage and slow comparisons if the string is a very large base64-encoded data URI (which is common for inline multimodal payloads).
We should hash the string if it is a data URI or exceeds a reasonable length (e.g., 256 characters) to keep the MessageKey compact and uniform.
| if isinstance(value, str): | |
| return value | |
| if isinstance(value, str): | |
| if value.startswith("data:") or len(value) > 256: | |
| return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16] | |
| return value |
There was a problem hiding this comment.
Done in d568b1b — _media_digest now hashes data: URIs and any string >256 chars (sha256, 16-hex), keeping short URLs/ids verbatim. Keeps the MessageKey compact and comparisons cheap for inline base64 payloads.
| # Roles whose messages are append-only prompt-side turns (never the generated | ||
| # assistant). Used to classify which incoming messages become pending nodes. | ||
| _PROMPT_SIDE_ROLES = frozenset({"system", "user", "tool"}) |
There was a problem hiding this comment.
Removed in d568b1b — it was a leftover from an earlier draft and is no longer referenced.
| if checkpoint.image_data: | ||
| child.image_data = list(checkpoint.image_data) | ||
| if checkpoint.video_data: | ||
| child.video_data = list(checkpoint.video_data) |
There was a problem hiding this comment.
In upsert_assistant, when updating an existing node (e.g., during an idempotent retry or refresh), if the new checkpoint has image_data or video_data as None, the existing child.image_data and child.video_data are not cleared. This leaves stale multimodal data on the node, causing an inconsistency between the node's fields and its checkpoint.
We should explicitly set these fields to None (or copy them) if they are not present in the new checkpoint.
| if checkpoint.image_data: | |
| child.image_data = list(checkpoint.image_data) | |
| if checkpoint.video_data: | |
| child.video_data = list(checkpoint.video_data) | |
| child.image_data = list(checkpoint.image_data) if checkpoint.image_data is not None else None | |
| child.video_data = list(checkpoint.video_data) if checkpoint.video_data is not None else None |
There was a problem hiding this comment.
Fixed in d568b1b — upsert_assistant now sets both fields unconditionally (list(...) if ... is not None else None), so a refresh/idempotent retry whose checkpoint has no multimodal data clears any stale payload and keeps the node consistent with its checkpoint.
| def iter_export_nodes(self, *, export_all: bool = False) -> Iterator[TrieNode]: | ||
| """Yield assistant checkpoint nodes for finalize. | ||
|
|
||
| By default emits only *terminal* assistant checkpoints (no committed | ||
| assistant descendant): rejected best-of-N siblings stay short leaves and | ||
| a continued branch is represented by its deepest checkpoint. With | ||
| ``export_all`` every committed checkpoint is emitted. | ||
| """ | ||
| for node in self._iter_nodes(self.root): | ||
| if node.checkpoint is None: | ||
| continue | ||
| if export_all or not self._has_committed_descendant(node): | ||
| yield node |
There was a problem hiding this comment.
The current implementation of iter_export_nodes performs a full trie traversal (_iter_nodes), and for every node, it calls _has_committed_descendant which performs another traversal of all its descendants. For a trie of depth
We can optimize this to
def iter_export_nodes(self, *, export_all: bool = False) -> Iterator[TrieNode]:
"""Yield assistant checkpoint nodes for finalize.
By default emits only *terminal* assistant checkpoints (no committed
assistant descendant): rejected best-of-N siblings stay short leaves and
a continued branch is represented by its deepest checkpoint. With
``export_all`` every committed checkpoint is emitted.
"""
if export_all:
for node in self._iter_nodes(self.root):
if node.checkpoint is not None:
yield node
return
def _find_terminals(node: TrieNode) -> tuple[bool, list[TrieNode]]:
has_ckpt = False
terminals = []
for child in node.children.values():
child_has_ckpt, child_terminals = _find_terminals(child)
has_ckpt = has_ckpt or child_has_ckpt
terminals.extend(child_terminals)
if node.checkpoint is not None:
if not has_ckpt:
terminals.append(node)
has_ckpt = True
return has_ckpt, terminals
_, terminal_nodes = _find_terminals(self.root)
yield from terminal_nodesThere was a problem hiding this comment.
Agreed — fixed in d568b1b. Replaced the per-node descendant scan with a single O(N) post-order pass (_collect_terminals) that returns, per subtree, whether it already carries a checkpoint, so an ancestor knows if it's the deepest. export_all keeps a straight checkpoint-node walk. Semantics unchanged; the 19 trie/session tests still pass.
|
@codex review |
5ff54a1 to
fd9db15
Compare
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a prefix trie implementation (uni_agent/gateway/trie.py) to support multi-trajectory session storage, enabling features like multi-branch reattachment, best-of-N, and sub-agents. It updates the gateway session, configuration, and types to integrate this trie-backed state model while maintaining backward compatibility with the legacy linear trajectory flow. Feedback on the changes highlights two key improvements: wrapping the trie preparation logic in a try...except BaseException block to prevent pending node leaks during exceptions or cancellations, and using _freeze(value) in _media_digest to ensure stable, order-canonical hashing of arbitrary fallback structures.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| prepared = self._trie.prepare(messages) | ||
| # Reuse the cloned checkpoint only when its tools match this request; | ||
| # a tools change forces a full re-encode (mirrors the legacy | ||
| # ``active_tool_schemas != tools`` gate). | ||
| use_incremental = prepared.trajectory_buffer is not None and prepared.request_tools == tools | ||
|
|
||
| # Multimodal introduced by *this* turn (messages beyond the checkpoint), | ||
| # stored on the committed node for per-branch reconstruction. | ||
| delta_messages = messages[len(prepared.checkpoint_messages) :] | ||
| new_image_data, new_video_data = (None, None) | ||
| if delta_messages: | ||
| new_image_data, new_video_data = await self._codec.extract_multi_modal_data(delta_messages) | ||
|
|
||
| if not use_incremental: | ||
| image_data, video_data = await self._codec.extract_multi_modal_data(messages) | ||
| prompt_ids = self._codec.encode_full( | ||
| messages, | ||
| tools=tools, | ||
| image_data=image_data, | ||
| video_data=video_data, | ||
| request_chat_template_kwargs=request_chat_template_kwargs, | ||
| ) | ||
| buffer = TrajectoryBuffer(prompt_ids=prompt_ids) | ||
| # Whole prompt is freshly encoded; the node owns all of its images. | ||
| new_image_data, new_video_data = image_data, video_data | ||
| else: | ||
| buffer = prepared.trajectory_buffer | ||
| image_data = list(prepared.image_data) if prepared.image_data is not None else None | ||
| video_data = list(prepared.video_data) if prepared.video_data is not None else None | ||
| if delta_messages: | ||
| if new_image_data: | ||
| image_data = (image_data or []) + new_image_data | ||
| if new_video_data: | ||
| video_data = (video_data or []) + new_video_data | ||
| incremental_ids = self._codec.encode_incremental( | ||
| delta_messages, | ||
| image_data=new_image_data, | ||
| video_data=new_video_data, | ||
| request_chat_template_kwargs=request_chat_template_kwargs, | ||
| ) | ||
| if ( | ||
| self._response_length is not None | ||
| and len(buffer.response_mask) + len(incremental_ids) >= self._response_length | ||
| ): | ||
| context_ids = buffer.prompt_ids + buffer.response_ids | ||
| return EncodedData( | ||
| buffer=buffer, | ||
| context_ids=context_ids, | ||
| sampling_params={}, | ||
| messages=list(messages), | ||
| tools=tools, | ||
| image_data=image_data, | ||
| video_data=video_data, | ||
| materialized_trajectory=None, | ||
| length_exhausted_trajectory=self._trajectory_from_buffer( | ||
| buffer, messages, extra_fields={"finish_reason": "length"} | ||
| ), | ||
| branch_handle=prepared.branch_handle, | ||
| new_image_data=new_image_data, | ||
| new_video_data=new_video_data, | ||
| ) | ||
| buffer.response_ids.extend(incremental_ids) | ||
| buffer.response_mask.extend([0] * len(incremental_ids)) | ||
| if buffer.response_logprobs: | ||
| buffer.response_logprobs.extend([0.0] * len(incremental_ids)) | ||
|
|
||
| context_ids = buffer.prompt_ids + buffer.response_ids | ||
| sampling_params = self._codec.build_sampling_params(payload) | ||
| remaining_response_budget = ( | ||
| self._response_length - len(buffer.response_mask) if self._response_length is not None else None | ||
| ) | ||
| if remaining_response_budget is not None and "max_tokens" in sampling_params: | ||
| sampling_params["max_tokens"] = min(sampling_params["max_tokens"], remaining_response_budget) | ||
| return EncodedData( | ||
| buffer=buffer, | ||
| context_ids=context_ids, | ||
| sampling_params=sampling_params, | ||
| messages=list(messages), | ||
| tools=tools, | ||
| image_data=image_data, | ||
| video_data=video_data, | ||
| materialized_trajectory=None, | ||
| length_exhausted_trajectory=None, | ||
| branch_handle=prepared.branch_handle, | ||
| new_image_data=new_image_data, | ||
| new_video_data=new_video_data, | ||
| ) |
There was a problem hiding this comment.
If any exception or cancellation occurs during incremental encoding or multimodal extraction inside _prepare_generation_inputs_trie after self._trie.prepare has been called, the pending node will be leaked. Wrapping the preparation logic in a try...except BaseException: block ensures that the pending node is properly abandoned on any failure or cancellation.
prepared = self._trie.prepare(messages)
try:
# Reuse the cloned checkpoint only when its tools match this request;
# a tools change forces a full re-encode (mirrors the legacy
# ``active_tool_schemas != tools`` gate).
use_incremental = prepared.trajectory_buffer is not None and prepared.request_tools == tools
# Multimodal introduced by *this* turn (messages beyond the checkpoint),
# stored on the committed node for per-branch reconstruction.
delta_messages = messages[len(prepared.checkpoint_messages) :]
new_image_data, new_video_data = (None, None)
if delta_messages:
new_image_data, new_video_data = await self._codec.extract_multi_modal_data(delta_messages)
if not use_incremental:
image_data, video_data = await self._codec.extract_multi_modal_data(messages)
prompt_ids = self._codec.encode_full(
messages,
tools=tools,
image_data=image_data,
video_data=video_data,
request_chat_template_kwargs=request_chat_template_kwargs,
)
buffer = TrajectoryBuffer(prompt_ids=prompt_ids)
# Whole prompt is freshly encoded; the node owns all of its images.
new_image_data, new_video_data = image_data, video_data
else:
buffer = prepared.trajectory_buffer
image_data = list(prepared.image_data) if prepared.image_data is not None else None
video_data = list(prepared.video_data) if prepared.video_data is not None else None
if delta_messages:
if new_image_data:
image_data = (image_data or []) + new_image_data
if new_video_data:
video_data = (video_data or []) + new_video_data
incremental_ids = self._codec.encode_incremental(
delta_messages,
image_data=new_image_data,
video_data=new_video_data,
request_chat_template_kwargs=request_chat_template_kwargs,
)
if (
self._response_length is not None
and len(buffer.response_mask) + len(incremental_ids) >= self._response_length
):
context_ids = buffer.prompt_ids + buffer.response_ids
return EncodedData(
buffer=buffer,
context_ids=context_ids,
sampling_params={},
messages=list(messages),
tools=tools,
image_data=image_data,
video_data=video_data,
materialized_trajectory=None,
length_exhausted_trajectory=self._trajectory_from_buffer(
buffer, messages, extra_fields={"finish_reason": "length"}
),
branch_handle=prepared.branch_handle,
new_image_data=new_image_data,
new_video_data=new_video_data,
)
buffer.response_ids.extend(incremental_ids)
buffer.response_mask.extend([0] * len(incremental_ids))
if buffer.response_logprobs:
buffer.response_logprobs.extend([0.0] * len(incremental_ids))
context_ids = buffer.prompt_ids + buffer.response_ids
sampling_params = self._codec.build_sampling_params(payload)
remaining_response_budget = (
self._response_length - len(buffer.response_mask) if self._response_length is not None else None
)
if remaining_response_budget is not None and "max_tokens" in sampling_params:
sampling_params["max_tokens"] = min(sampling_params["max_tokens"], remaining_response_budget)
return EncodedData(
buffer=buffer,
context_ids=context_ids,
sampling_params=sampling_params,
messages=list(messages),
tools=tools,
image_data=image_data,
video_data=video_data,
materialized_trajectory=None,
length_exhausted_trajectory=None,
branch_handle=prepared.branch_handle,
new_image_data=new_image_data,
new_video_data=new_video_data,
)
except BaseException:
self._trie.abandon(prepared.branch_handle)
raiseThere was a problem hiding this comment.
Good catch — fixed in 54c62f4. _prepare_generation_inputs_trie now wraps the post-prepare encoding in try/except BaseException and calls trie.abandon(prepared.branch_handle) before re-raising, so a failure or cancellation during extract_multi_modal_data/encode_incremental no longer leaks the pending node. Added test_trie_abandons_pending_node_when_encode_fails (forces encode_full to raise, asserts num_inflight()==0); verified it fails without the guard.
| return value | ||
| if isinstance(value, (bytes, bytearray)): | ||
| return hashlib.sha256(bytes(value)).hexdigest()[:16] | ||
| return hashlib.sha256(repr(value).encode("utf-8")).hexdigest()[:16] |
There was a problem hiding this comment.
Using repr(value) directly on arbitrary fallback structures (such as lists of dictionaries) is unstable because dictionary key ordering in repr() is not guaranteed to be canonical across different Python processes or runs. Wrapping value in _freeze(value) recursively normalizes any nested dictionaries and lists into sorted, hashable tuples, ensuring a stable and order-canonical representation.
| return hashlib.sha256(repr(value).encode("utf-8")).hexdigest()[:16] | |
| return hashlib.sha256(repr(_freeze(value)).encode("utf-8")).hexdigest()[:16] |
There was a problem hiding this comment.
Done in 54c62f4 — the scalar fallback now hashes repr(_freeze(value)) so nested dict/list ordering is canonicalized before hashing.
Add a per-session prefix trie to replace the single linear active trajectory (issue verl-project#51): every committed assistant turn becomes a node that may carry a checkpoint, and an incoming request longest-prefix matches any path and continues from the nearest checkpoint. - session/trie.py: MessageKey (+canonicalize_content/make_message_key), BranchCheckpoint, TrieNode, PrefixTrie (prepare/commit/finalize-traversal). MessageKey excludes reasoning_content (M1) and hashes multimodal content via a stable image/video digest (data:/long strings are hashed for compactness). - session/types.py: home TrajectoryBuffer here so trie stays free of the verl-importing codec; re-exported via session/__init__. - CPU tests cover the RFC appendix fork types (sequential / system-split / condensation / best-of-N + idempotent retry / warm-start) plus pending-node, clone-isolation, export, multimodal-digest, and regression guards that fail on the pre-review behavior. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Wire the trie into GatewaySession end-to-end, gated by GatewayActorConfig.gateway_trie_enabled (default False; flag-off reproduces the legacy single-active behavior exactly, covered by a parity test). - session.py: trie-mode prepare routes through trie.prepare (longest-prefix match + nearest-checkpoint clone) and reuses the existing encode_full/encode_incremental paths; commit attaches the assistant child; finalize traverses the trie -> one branch-aware trajectory per terminal checkpoint. Per-branch multimodal stored on the node; failed generations abandon the pending node; snapshot_state reports num_branches/num_inflight. tools-change gate mirrors the legacy active_tool_schemas gate. - config.py/gateway.py: thread gateway_trie_enabled into sessions. - session-integration tests: flag parity, best-of-N fan-out, multi-turn reattach. Tokenization stays on remove_system_prompt (TITO -> M2); generation_lock kept (concurrency -> M3); no GC (trie dies with the session). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
fd9db15 to
7d40dc9
Compare
- _prepare_generation_inputs_trie: abandon the pending node if encoding fails or is cancelled after trie.prepare (try/except BaseException around the encode step), preventing in-flight bookkeeping leaks. Added a session test that fails without the guard. - _media_digest: hash repr(_freeze(value)) in the scalar fallback so nested dict/list ordering is canonical across processes. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a prefix trie (PrefixTrie) for multi-trajectory session storage, enabling advanced capabilities like multi-branch reattachment, best-of-N trajectories, and sub-agent workflows under a new gateway_trie_enabled configuration flag. The feedback focuses on enhancing robustness by defensively copying message dictionaries when storing them in trie nodes and checkpoints to prevent external mutations from corrupting the trie's internal state. Additionally, a high-severity bug was identified where a full encode path incorrectly overwrites the turn-level multimodal delta lists with full historical lists, which would cause duplicate entries during path reconstruction.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| # Whole prompt is freshly encoded; the node owns all of its images. | ||
| new_image_data, new_video_data = image_data, video_data |
There was a problem hiding this comment.
When use_incremental is False (e.g., due to a tools change), a full encode is performed. In this case, image_data and video_data contain the multimodal payloads for the entire conversation history. Overwriting new_image_data and new_video_data with these full lists causes the committed node to store all historical images/videos. When collect_multi_modal subsequently traverses the path to reconstruct the multimodal sequence, it will collect these images/videos from both the current node and its ancestors, leading to duplicate entries.
To prevent duplication, the node should only store the delta images/videos introduced in the current turn (which are already extracted into new_image_data and new_video_data from delta_messages). Removing the overwrite ensures that only the delta is stored on the node, while the backend still receives the full lists for the full encode.
| # Whole prompt is freshly encoded; the node owns all of its images. | |
| new_image_data, new_video_data = image_data, video_data | |
| # Whole prompt is freshly encoded; the node only stores the delta images | |
| # to avoid duplicates during path reconstruction. |
There was a problem hiding this comment.
Good catch (real bug) — fixed in f4294f1. On a full re-encode mid-branch the node now stores only this turn's delta media (new_image_data/new_video_data already extracted from delta_messages); the backend still gets the full lists. Removed the delta→full overwrite. Added test_trie_no_duplicate_multimodal_on_full_encode_midbranch (image turn 1, then a tools change + new image on turn 2 forcing full re-encode) and verified it fails without the fix (yields the duplicated ancestor image).
| key = make_message_key(message) | ||
| child = attach.children.get(key) | ||
| if child is None: | ||
| child = TrieNode(key=key, message=message, parent=attach) |
There was a problem hiding this comment.
To prevent external mutations of the request payload from accidentally corrupting the trie's internal state, it is highly recommended to perform a defensive copy of the message dictionary when storing it in the TrieNode.
| child = TrieNode(key=key, message=message, parent=attach) | |
| child = TrieNode(key=key, message=dict(message), parent=attach) |
There was a problem hiding this comment.
Done in f4294f1 — dict(message) shallow copy when creating the node.
| child = TrieNode(key=key, message=assistant_msg, parent=parent) | ||
| parent.children[key] = child | ||
| else: | ||
| child.message = assistant_msg |
There was a problem hiding this comment.
Similarly, perform a defensive copy of assistant_msg when creating or updating an assistant node to isolate the trie's internal state from any subsequent modifications to the assistant message dictionary.
| child = TrieNode(key=key, message=assistant_msg, parent=parent) | |
| parent.children[key] = child | |
| else: | |
| child.message = assistant_msg | |
| if child is None: | |
| child = TrieNode(key=key, message=dict(assistant_msg), parent=parent) | |
| parent.children[key] = child | |
| else: | |
| child.message = dict(assistant_msg) |
There was a problem hiding this comment.
Done in f4294f1 — dict(assistant_msg) on both create and refresh.
| # request prompt PLUS the generated assistant turn, so the stored prefix | ||
| # must include ``assistant_msg`` (this is what the next turn's | ||
| # ``checkpoint_messages`` slices against). | ||
| covered_messages = list(messages) + [assistant_msg] if messages is not None else None |
There was a problem hiding this comment.
Defensively copy the messages in covered_messages to ensure that the checkpoint's stored message history is completely isolated from any external mutations.
| covered_messages = list(messages) + [assistant_msg] if messages is not None else None | |
| covered_messages = [dict(m) for m in messages] + [dict(assistant_msg)] if messages is not None else None |
There was a problem hiding this comment.
Done in f4294f1 — covered_messages now stores [dict(m) for m in messages] + [dict(assistant_msg)].
- session.py: on a full re-encode mid-branch (e.g. tools change), store only this turn's delta media on the committed node instead of the whole prompt's media, so collect_multi_modal does not double-count media carried on ancestor checkpoints. Added a regression test (multimodal + tools-change mid-branch) that fails without the fix. - trie.py: defensively shallow-copy message dicts when storing them in nodes (materialize_prompt_suffix, upsert_assistant) and in the checkpoint's covered_messages, isolating trie state from external payload mutation. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a prefix-trie-backed gateway session storage (PrefixTrie) to support multi-trajectory sessions, including multi-turn conversations, best-of-N trajectories, and sub-agents, controlled by a new gateway_trie_enabled configuration flag. It also adds comprehensive unit and integration tests to verify the trie's structural behavior and parity with the legacy linear trajectory flow. The review feedback highlights a critical issue where request cancellations (raising BaseException) could leak pending trie nodes, and suggests moving an inline json import to the top of the file to adhere to PEP 8 guidelines.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| except ValueError as e: | ||
| self._abandon_pending(encoded) | ||
| raise HTTPException(status_code=400, detail=str(e)) from e | ||
| except Exception as e: | ||
| self._abandon_pending(encoded) | ||
| raise HTTPException(status_code=500, detail=f"{e.__class__.__name__}: {e}") from e |
There was a problem hiding this comment.
If the generation request is cancelled during await backend.generate(...), an asyncio.CancelledError (which inherits from BaseException, not Exception) will be raised. Since the current except blocks only catch ValueError and Exception, a cancellation will bypass them, leaking the pending/in-flight node in the trie. Adding an except BaseException: block ensures that _abandon_pending is always called to clean up the trie state upon cancellation.
| except ValueError as e: | |
| self._abandon_pending(encoded) | |
| raise HTTPException(status_code=400, detail=str(e)) from e | |
| except Exception as e: | |
| self._abandon_pending(encoded) | |
| raise HTTPException(status_code=500, detail=f"{e.__class__.__name__}: {e}") from e | |
| except ValueError as e: | |
| self._abandon_pending(encoded) | |
| raise HTTPException(status_code=400, detail=str(e)) from e | |
| except Exception as e: | |
| self._abandon_pending(encoded) | |
| raise HTTPException(status_code=500, detail=f"{e.__class__.__name__}: {e}") from e | |
| except BaseException: | |
| self._abandon_pending(encoded) | |
| raise |
There was a problem hiding this comment.
Good catch — fixed in 7f9f42e. Added an except BaseException: self._abandon_pending(encoded); raise after the ValueError/Exception handlers around backend.generate, so a cancellation (CancelledError) also releases the pending node. Added test_trie_abandons_pending_node_on_cancellation (backend raises CancelledError, asserts num_inflight()==0); verified it fails without the guard.
| from __future__ import annotations | ||
|
|
||
| import hashlib | ||
| import uuid | ||
| from dataclasses import dataclass, field | ||
| from typing import Any, Iterator |
There was a problem hiding this comment.
According to PEP 8, imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants. Let's import json here at the top level.
| from __future__ import annotations | |
| import hashlib | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from typing import Any, Iterator | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from typing import Any, Iterator |
References
- PEP 8: Imports should always be put at the top of the file, just after any module comments and docstrings, and before module globals and constants. (link)
There was a problem hiding this comment.
Done in 7f9f42e — import json moved to the module top.
| if isinstance(arguments, str): | ||
| import json | ||
|
|
||
| try: | ||
| return ("json", _freeze(json.loads(arguments))) | ||
| except json.JSONDecodeError: | ||
| return ("raw", arguments) |
There was a problem hiding this comment.
Remove the inline import json statement now that it has been moved to the top of the file, adhering to PEP 8 guidelines.
| if isinstance(arguments, str): | |
| import json | |
| try: | |
| return ("json", _freeze(json.loads(arguments))) | |
| except json.JSONDecodeError: | |
| return ("raw", arguments) | |
| if isinstance(arguments, str): | |
| try: | |
| return ("json", _freeze(json.loads(arguments))) | |
| except json.JSONDecodeError: | |
| return ("raw", arguments) |
References
- PEP 8: Imports should always be put at the top of the file, just after any module comments and docstrings, and before module globals and constants. (link)
There was a problem hiding this comment.
Done in 7f9f42e — inline import removed, uses the top-level json.
- session.py: catch BaseException around backend.generate so a cancellation (asyncio.CancelledError, which is not an Exception) still abandons the trie pending node before propagating. Added a regression test that fails without the guard. - trie.py: move 'import json' to module top (PEP 8) instead of inline. Co-Authored-By: Changyi Yang <changyiyang2023@gmail.com> Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
@gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a prefix trie (PrefixTrie) for multi-trajectory session storage in the gateway, enabling multi-branch reattachment, best-of-N, sub-agent, and warm-start capabilities behind a gateway_trie_enabled flag. Feedback focuses on resolving a discrepancy in the num_turns calculation for length-exhausted trajectories between the legacy and trie flows. To achieve exact parity, the reviewer suggests tracking checkpoint_messages in EncodedData and passing it to the commit method during length exhaustion.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| branch_handle: BranchHandle | None = None | ||
| new_image_data: list[Any] | None = None | ||
| new_video_data: list[Any] | None = None |
There was a problem hiding this comment.
To support exact parity in num_turns for length-exhausted trajectories, we need to track the checkpoint messages (the history before the current turn) in EncodedData so they can be passed to commit when length-exhausted.
| branch_handle: BranchHandle | None = None | |
| new_image_data: list[Any] | None = None | |
| new_video_data: list[Any] | None = None | |
| branch_handle: BranchHandle | None = None | |
| new_image_data: list[Any] | None = None | |
| new_video_data: list[Any] | None = None | |
| checkpoint_messages: list[dict[str, Any]] | None = None |
| if self._trie_enabled: | ||
| self._trie.commit( | ||
| encoded.branch_handle, | ||
| encoded.buffer, | ||
| empty_msg, | ||
| request_tools=encoded.tools, | ||
| messages=encoded.messages, | ||
| image_data=encoded.new_image_data, | ||
| video_data=encoded.new_video_data, | ||
| extra_fields={"finish_reason": "length"}, | ||
| ) |
There was a problem hiding this comment.
There is a discrepancy in num_turns calculation for length-exhausted trajectories between the legacy and trie flows. In the legacy flow, _build_materialized_trajectory is called before self.message_history is updated, so num_turns is based on the history before the current turn. In the trie flow, the trajectory is materialized during finalize using checkpoint.messages which includes the current turn's messages plus the empty assistant message, leading to a discrepancy of 2 turns. To achieve exact parity, we should pass encoded.checkpoint_messages (the history before the current turn) to commit when length-exhausted.
| if self._trie_enabled: | |
| self._trie.commit( | |
| encoded.branch_handle, | |
| encoded.buffer, | |
| empty_msg, | |
| request_tools=encoded.tools, | |
| messages=encoded.messages, | |
| image_data=encoded.new_image_data, | |
| video_data=encoded.new_video_data, | |
| extra_fields={"finish_reason": "length"}, | |
| ) | |
| if self._trie_enabled: | |
| self._trie.commit( | |
| encoded.branch_handle, | |
| encoded.buffer, | |
| empty_msg, | |
| request_tools=encoded.tools, | |
| messages=encoded.checkpoint_messages, | |
| image_data=encoded.new_image_data, | |
| video_data=encoded.new_video_data, | |
| extra_fields={"finish_reason": "length"}, | |
| ) |
| length_exhausted_trajectory=self._trajectory_from_buffer( | ||
| buffer, messages, extra_fields={"finish_reason": "length"} | ||
| ), | ||
| branch_handle=prepared.branch_handle, | ||
| new_image_data=new_image_data, | ||
| new_video_data=new_video_data, | ||
| ) |
There was a problem hiding this comment.
Populate the checkpoint_messages field in EncodedData during early length exhaustion in trie mode.
length_exhausted_trajectory=self._trajectory_from_buffer(
buffer, messages, extra_fields={"finish_reason": "length"}
),
branch_handle=prepared.branch_handle,
new_image_data=new_image_data,
new_video_data=new_video_data,
checkpoint_messages=prepared.checkpoint_messages,| branch_handle=prepared.branch_handle, | ||
| new_image_data=new_image_data, | ||
| new_video_data=new_video_data, | ||
| ) |
|
nice work bro |
|
With prefix trie, I think maybe we can get rid of generation lock but only use request lock? the reason is stated in details in: #51 as |
|
a complete pr is here:#65, this one will be closed soon |
Summary
Implements the per-session prefix trie that replaces the single linear
active_trajectory, per RFC #51. Today a session keeps onemessage_history+ one active trajectory, so only the latest branch can be re-attached; this generalizes it into a trie where an incoming request longest-prefix matches any path and continues from the nearest checkpoint — enabling sub-agent / best-of-N / context-condensation / warm-start.Fully gated behind
GatewayActorConfig.gateway_trie_enabled(default off); flag-off reproduces the legacy single-active behavior exactly (covered by a parity test).What's in this PR (full M1)
Data model —
gateway/trie.py(new)MessageKey(+canonicalize_content/make_message_key): hashable per-message identity; excludesreasoning_content(M1; TODO thinking-model replay) and keys multimodal content via a stable image/video digest (pixels live on the node, not the key).BranchCheckpoint,TrieNode,PrefixTriewith theprepare → commitlifecycle and afinalizetraversal (iter_export_nodes→ terminal checkpoints only by default).preparematerializes incoming prompt-side messages as pending/structural nodes (no checkpoint, never exported) and clones the nearest checkpoint into a request-local buffer.TrajectoryBufferre-homed togateway/types.pyso the trie stays free of the verl-importing codec.Session integration —
gateway/session.py_prepare_generation_inputs_trie: routes throughtrie.prepare(longest-prefix match + nearest-checkpoint clone) and reuses the existingencode_full/encode_incrementalpaths — only the state model changes.trie.commit;finalizetraverses the trie → one branch-awareTrajectoryper terminal checkpoint (reward stamped per the current one-session-one-reward contract).abandonthe pending node.snapshot_statereportsnum_branches/num_inflight_generationsin trie mode.active_tool_schemasgate.Wiring:
config.py/gateway.pythreadgateway_trie_enabledinto sessions.Tests (CPU-only)
tests/uni_agent/gateway/test_trie_on_cpu.py— 16 trie-structural tests covering the RFC appendix fork types: A sequential extension, B system-keyed split (parallel sub-agents), C context condensation, D best-of-N + idempotent retry, E warm-start; plus pending-node/no-export, clone isolation, terminal-vs-all export, multimodal digest + per-branch reconstruction.tests/uni_agent/gateway/test_session_trie_on_cpu.py— 3 session-integration tests: flag parity (trie-on yields identical backend prompts + identical finalized trajectory as trie-off on a linear multi-turn conversation — the compatibility gate), best-of-N fan-out (N siblings → N trajectories), multi-turn reattach (single continuous branch).Scope / deferrals
remove_system_prompt(text via tokenizer, multimodal via processor). Per-model seam handling (Qwen trailing\n, GLM 4.7 ambiguous bos/eos) + a token-sequence verifier → M2 (TITO-style merge).generation_lockkept (serial) → M3.