Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 48 additions & 23 deletions amplifier_foundation/session/diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

class IncompleteTurn(TypedDict):
after_line: int
after_index: int
missing: str


Expand Down Expand Up @@ -84,7 +85,9 @@ def build_tool_index(entries: list[dict]) -> dict:
if entry.get("role") == "assistant" and "tool_calls" in entry:
for tool_call in entry["tool_calls"]:
call_id = tool_call.get("id", "")
tool_name = tool_call.get("function", {}).get("name", "")
tool_name = tool_call.get(
"tool", tool_call.get("function", {}).get("name", "")
)
tool_uses[call_id] = {
"line_num": entry.get("line_num"),
"tool_name": tool_name,
Expand Down Expand Up @@ -176,9 +179,11 @@ def diagnose_transcript(entries: list[dict]) -> DiagnosisResult:
if tc_id != first_tc_id:
continue

# Skip if any tool_call from this assistant message is orphaned/misplaced
# Skip if ALL tool_calls from this message are orphaned/misplaced
# (step 5 of repair handles those). For partial orphans, continue
# to FM3 detection so the incomplete turn is caught in one pass.
all_tc_ids = [tc["id"] for tc in assistant_entry.get("tool_calls", [])]
if any(tid in orphaned_set or tid in misplaced_set for tid in all_tc_ids):
if all(tid in orphaned_set or tid in misplaced_set for tid in all_tc_ids):
continue

# Find the last tool_result for this assistant message's tool_calls
Expand All @@ -199,6 +204,7 @@ def diagnose_transcript(entries: list[dict]) -> DiagnosisResult:
incomplete_turns.append(
{
"after_line": entries[last_result_idx].get("line_num"),
"after_index": last_result_idx,
"missing": "assistant_response",
}
)
Expand All @@ -210,6 +216,7 @@ def diagnose_transcript(entries: list[dict]) -> DiagnosisResult:
incomplete_turns.append(
{
"after_line": entries[last_result_idx].get("line_num"),
"after_index": last_result_idx,
"missing": "assistant_response",
}
)
Expand Down Expand Up @@ -283,7 +290,14 @@ def repair_transcript(entries: list[dict], diagnosis: DiagnosisResult) -> list[d
orphaned_set = set(diagnosis["orphaned_tool_ids"])
misplaced_set = set(diagnosis["misplaced_tool_ids"])
broken_set = orphaned_set | misplaced_set
incomplete_after_lines = {t["after_line"] for t in diagnosis["incomplete_turns"]}
# Use entry indices (stable) for matching, with line_num fallback for
# diagnosis results produced before after_index was added.
incomplete_after_indices = {
t["after_index"] for t in diagnosis["incomplete_turns"] if "after_index" in t
}
incomplete_after_lines = {
t["after_line"] for t in diagnosis["incomplete_turns"] if "after_index" not in t
}

# 2. Build skip_indices — entry indices of misplaced tool results.
skip_indices: set[int] = set()
Expand All @@ -309,7 +323,7 @@ def repair_transcript(entries: list[dict], diagnosis: DiagnosisResult) -> list[d

for tc in broken_in_msg:
tc_id = tc["id"]
tc_name = tc.get("function", {}).get("name", "")
tc_name = tc.get("tool", tc.get("function", {}).get("name", ""))
result.append(_make_synthetic_tool_result(tc_id, tc_name))

# 5. If ALL tool_calls from this message were broken AND
Expand All @@ -324,12 +338,15 @@ def repair_transcript(entries: list[dict], diagnosis: DiagnosisResult) -> list[d
if next_entry is None or is_real_user_message(next_entry):
result.append(_make_synthetic_assistant_response())

# 6. After tool results at incomplete_after_lines: inject synthetic
# assistant response. (entry still has line_num; only the
# appended copy is stripped.)
elif (
entry.get("role") == "tool"
and entry.get("line_num") in incomplete_after_lines
# 6. After tool results at incomplete turn positions: inject synthetic
# assistant response. Match by entry index (stable) with line_num
# fallback for old diagnosis results.
elif entry.get("role") == "tool" and (
idx in incomplete_after_indices
or (
incomplete_after_lines
and entry.get("line_num") in incomplete_after_lines
)
):
result.append(_make_synthetic_assistant_response())

Expand Down Expand Up @@ -368,19 +385,27 @@ def rewind_transcript(entries: list[dict], diagnosis: DiagnosisResult) -> list[d
if tc_id in index["tool_uses"]:
issue_indices.append(index["tool_uses"][tc_id]["entry_index"])

# Incomplete turns → walk back from after_line to find assistant with tool_calls
# Incomplete turns → walk back to find assistant with tool_calls
for turn in diagnosis["incomplete_turns"]:
after_line = turn["after_line"]
# Find the entry with this line_num, then walk backwards
for idx in range(len(entries) - 1, -1, -1):
if entries[idx].get("line_num") == after_line:
# Walk backwards from here to find the assistant with tool_calls
for back_idx in range(idx, -1, -1):
e = entries[back_idx]
if e.get("role") == "assistant" and "tool_calls" in e:
issue_indices.append(back_idx)
break
break
after_index = turn.get("after_index")
if after_index is not None:
# Use entry index directly (stable, no line_num dependency)
for back_idx in range(after_index, -1, -1):
e = entries[back_idx]
if e.get("role") == "assistant" and "tool_calls" in e:
issue_indices.append(back_idx)
break
else:
# Fallback to line_num for old diagnosis results
after_line = turn["after_line"]
for idx in range(len(entries) - 1, -1, -1):
if entries[idx].get("line_num") == after_line:
for back_idx in range(idx, -1, -1):
e = entries[back_idx]
if e.get("role") == "assistant" and "tool_calls" in e:
issue_indices.append(back_idx)
break
break

if not issue_indices:
# Diagnosis says broken but referenced IDs weren't found in entries;
Expand Down
112 changes: 111 additions & 1 deletion tests/test_session_diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ def _make_entries_with_lines(messages: list[dict]) -> list[dict]:


def _tc(tc_id: str, name: str) -> dict:
"""Build a tool_call dict."""
"""Build a tool_call dict (OpenAI format)."""
return {"id": tc_id, "function": {"name": name}}


def _tc_amplifier(tc_id: str, name: str) -> dict:
"""Build a tool_call dict (Amplifier format — 'tool' key, no 'function' wrapper)."""
return {"id": tc_id, "tool": name, "arguments": {}}


def _tool_result(tc_id: str, name: str, content: str) -> dict:
"""Build a tool result dict."""
return {"role": "tool", "tool_call_id": tc_id, "name": name, "content": content}
Expand Down Expand Up @@ -167,6 +172,27 @@ def test_no_tool_calls_is_healthy(self):
result = diagnose_transcript(entries)
assert result["status"] == "healthy"

def test_detects_partial_orphan_incomplete_turn(self):
"""Partial orphan (some results present, some missing) also detects incomplete turn."""
entries = _make_entries_with_lines(
[
{"role": "user", "content": "Hello"},
{
"role": "assistant",
"tool_calls": [_tc("call_1", "tool_a"), _tc("call_2", "tool_b")],
},
_tool_result("call_1", "tool_a", "ok"), # tc_1 completed
# tc_2 never completed — crash happened here
]
)
result = diagnose_transcript(entries)
assert result["status"] == "broken"
assert "missing_tool_results" in result["failure_modes"]
assert "incomplete_assistant_turn" in result["failure_modes"]
assert result["orphaned_tool_ids"] == ["call_2"]
assert len(result["incomplete_turns"]) == 1
assert result["incomplete_turns"][0]["after_index"] == 2

def test_recommended_action(self):
"""recommended_action is 'none' for healthy and 'repair' for broken."""
healthy_entries = _make_entries_with_lines(
Expand Down Expand Up @@ -347,6 +373,90 @@ def test_orphans_roundtrip(self):
re_diagnosis = diagnose_transcript(re_entries)
assert re_diagnosis["status"] == "healthy"

def test_partial_orphan_roundtrip(self):
"""Partial orphan (some results present) is fully healed in one pass.

Regression test: before the fix, repair only injected a synthetic
result for the orphan but not the closing assistant response, requiring
a second diagnose-repair cycle.
"""
entries = _make_entries_with_lines(
[
{"role": "user", "content": "Hello"},
{
"role": "assistant",
"tool_calls": [_tc("call_1", "tool_a"), _tc("call_2", "tool_b")],
},
_tool_result("call_1", "tool_a", "ok"), # tc_1 completed
# tc_2 never completed — crash happened here
]
)
diagnosis = diagnose_transcript(entries)
repaired = repair_transcript(entries, diagnosis)

# Verify structure: should have synthetic result for tc_2 AND closing response
tool_results = [e for e in repaired if e.get("role") == "tool"]
assert len(tool_results) == 2 # real + synthetic
assistant_responses = [
e
for e in repaired
if e.get("role") == "assistant" and "tool_calls" not in e
]
assert len(assistant_responses) >= 1 # closing response injected

# Re-diagnose: must be healthy in ONE pass
re_entries = _make_entries_with_lines(repaired)
re_diagnosis = diagnose_transcript(re_entries)
assert re_diagnosis["status"] == "healthy"

def test_partial_orphan_roundtrip_without_line_num(self):
"""Partial orphan repair works even without line_num on entries.

The app-CLI loads transcripts without line_num. After a first repair
pass strips line_num, any subsequent diagnosis must still work via
after_index rather than after_line.
"""
entries = [
{"role": "user", "content": "Hello"},
{
"role": "assistant",
"tool_calls": [_tc("call_1", "tool_a"), _tc("call_2", "tool_b")],
},
{
"role": "tool",
"tool_call_id": "call_1",
"name": "tool_a",
"content": "ok",
},
]
diagnosis = diagnose_transcript(entries)
repaired = repair_transcript(entries, diagnosis)
# Re-diagnose WITHOUT adding line_num back (simulating app-CLI path)
re_diagnosis = diagnose_transcript(repaired)
assert re_diagnosis["status"] == "healthy"

def test_amplifier_format_tool_calls(self):
"""Tool calls using Amplifier format ('tool' key) are indexed and repaired correctly."""
entries = _make_entries_with_lines(
[
{"role": "user", "content": "Hello"},
{
"role": "assistant",
"tool_calls": [_tc_amplifier("call_1", "bash")],
},
]
)
# Index should find the tool name
index = build_tool_index(entries)
assert index["tool_uses"]["call_1"]["tool_name"] == "bash"

# Repair should use the correct name in synthetic results
diagnosis = diagnose_transcript(entries)
repaired = repair_transcript(entries, diagnosis)
synthetic = [e for e in repaired if e.get("role") == "tool"]
assert len(synthetic) == 1
assert synthetic[0]["name"] == "bash"

def test_combined_failures_roundtrip(self):
"""diagnose -> repair -> re-diagnose for ordering+orphans produces healthy."""
tc_orphan = "call_orphan"
Expand Down
Loading