diff --git a/amplifier_foundation/session/diagnosis.py b/amplifier_foundation/session/diagnosis.py index e1da032..f1a52bc 100644 --- a/amplifier_foundation/session/diagnosis.py +++ b/amplifier_foundation/session/diagnosis.py @@ -23,6 +23,7 @@ class IncompleteTurn(TypedDict): after_line: int + after_index: int missing: str @@ -84,7 +85,9 @@ def build_tool_index(entries: list[dict]) -> dict: if entry.get("role") == "assistant" and "tool_calls" in entry: for tool_call in entry["tool_calls"]: call_id = tool_call.get("id", "") - tool_name = tool_call.get("function", {}).get("name", "") + tool_name = tool_call.get( + "tool", tool_call.get("function", {}).get("name", "") + ) tool_uses[call_id] = { "line_num": entry.get("line_num"), "tool_name": tool_name, @@ -176,9 +179,11 @@ def diagnose_transcript(entries: list[dict]) -> DiagnosisResult: if tc_id != first_tc_id: continue - # Skip if any tool_call from this assistant message is orphaned/misplaced + # Skip if ALL tool_calls from this message are orphaned/misplaced + # (step 5 of repair handles those). For partial orphans, continue + # to FM3 detection so the incomplete turn is caught in one pass. all_tc_ids = [tc["id"] for tc in assistant_entry.get("tool_calls", [])] - if any(tid in orphaned_set or tid in misplaced_set for tid in all_tc_ids): + if all(tid in orphaned_set or tid in misplaced_set for tid in all_tc_ids): continue # Find the last tool_result for this assistant message's tool_calls @@ -199,6 +204,7 @@ def diagnose_transcript(entries: list[dict]) -> DiagnosisResult: incomplete_turns.append( { "after_line": entries[last_result_idx].get("line_num"), + "after_index": last_result_idx, "missing": "assistant_response", } ) @@ -210,6 +216,7 @@ def diagnose_transcript(entries: list[dict]) -> DiagnosisResult: incomplete_turns.append( { "after_line": entries[last_result_idx].get("line_num"), + "after_index": last_result_idx, "missing": "assistant_response", } ) @@ -283,7 +290,14 @@ def repair_transcript(entries: list[dict], diagnosis: DiagnosisResult) -> list[d orphaned_set = set(diagnosis["orphaned_tool_ids"]) misplaced_set = set(diagnosis["misplaced_tool_ids"]) broken_set = orphaned_set | misplaced_set - incomplete_after_lines = {t["after_line"] for t in diagnosis["incomplete_turns"]} + # Use entry indices (stable) for matching, with line_num fallback for + # diagnosis results produced before after_index was added. + incomplete_after_indices = { + t["after_index"] for t in diagnosis["incomplete_turns"] if "after_index" in t + } + incomplete_after_lines = { + t["after_line"] for t in diagnosis["incomplete_turns"] if "after_index" not in t + } # 2. Build skip_indices — entry indices of misplaced tool results. skip_indices: set[int] = set() @@ -309,7 +323,7 @@ def repair_transcript(entries: list[dict], diagnosis: DiagnosisResult) -> list[d for tc in broken_in_msg: tc_id = tc["id"] - tc_name = tc.get("function", {}).get("name", "") + tc_name = tc.get("tool", tc.get("function", {}).get("name", "")) result.append(_make_synthetic_tool_result(tc_id, tc_name)) # 5. If ALL tool_calls from this message were broken AND @@ -324,12 +338,15 @@ def repair_transcript(entries: list[dict], diagnosis: DiagnosisResult) -> list[d if next_entry is None or is_real_user_message(next_entry): result.append(_make_synthetic_assistant_response()) - # 6. After tool results at incomplete_after_lines: inject synthetic - # assistant response. (entry still has line_num; only the - # appended copy is stripped.) - elif ( - entry.get("role") == "tool" - and entry.get("line_num") in incomplete_after_lines + # 6. After tool results at incomplete turn positions: inject synthetic + # assistant response. Match by entry index (stable) with line_num + # fallback for old diagnosis results. + elif entry.get("role") == "tool" and ( + idx in incomplete_after_indices + or ( + incomplete_after_lines + and entry.get("line_num") in incomplete_after_lines + ) ): result.append(_make_synthetic_assistant_response()) @@ -368,19 +385,27 @@ def rewind_transcript(entries: list[dict], diagnosis: DiagnosisResult) -> list[d if tc_id in index["tool_uses"]: issue_indices.append(index["tool_uses"][tc_id]["entry_index"]) - # Incomplete turns → walk back from after_line to find assistant with tool_calls + # Incomplete turns → walk back to find assistant with tool_calls for turn in diagnosis["incomplete_turns"]: - after_line = turn["after_line"] - # Find the entry with this line_num, then walk backwards - for idx in range(len(entries) - 1, -1, -1): - if entries[idx].get("line_num") == after_line: - # Walk backwards from here to find the assistant with tool_calls - for back_idx in range(idx, -1, -1): - e = entries[back_idx] - if e.get("role") == "assistant" and "tool_calls" in e: - issue_indices.append(back_idx) - break - break + after_index = turn.get("after_index") + if after_index is not None: + # Use entry index directly (stable, no line_num dependency) + for back_idx in range(after_index, -1, -1): + e = entries[back_idx] + if e.get("role") == "assistant" and "tool_calls" in e: + issue_indices.append(back_idx) + break + else: + # Fallback to line_num for old diagnosis results + after_line = turn["after_line"] + for idx in range(len(entries) - 1, -1, -1): + if entries[idx].get("line_num") == after_line: + for back_idx in range(idx, -1, -1): + e = entries[back_idx] + if e.get("role") == "assistant" and "tool_calls" in e: + issue_indices.append(back_idx) + break + break if not issue_indices: # Diagnosis says broken but referenced IDs weren't found in entries; diff --git a/tests/test_session_diagnosis.py b/tests/test_session_diagnosis.py index cb5f0e8..576ddf9 100644 --- a/tests/test_session_diagnosis.py +++ b/tests/test_session_diagnosis.py @@ -25,10 +25,15 @@ def _make_entries_with_lines(messages: list[dict]) -> list[dict]: def _tc(tc_id: str, name: str) -> dict: - """Build a tool_call dict.""" + """Build a tool_call dict (OpenAI format).""" return {"id": tc_id, "function": {"name": name}} +def _tc_amplifier(tc_id: str, name: str) -> dict: + """Build a tool_call dict (Amplifier format — 'tool' key, no 'function' wrapper).""" + return {"id": tc_id, "tool": name, "arguments": {}} + + def _tool_result(tc_id: str, name: str, content: str) -> dict: """Build a tool result dict.""" return {"role": "tool", "tool_call_id": tc_id, "name": name, "content": content} @@ -167,6 +172,27 @@ def test_no_tool_calls_is_healthy(self): result = diagnose_transcript(entries) assert result["status"] == "healthy" + def test_detects_partial_orphan_incomplete_turn(self): + """Partial orphan (some results present, some missing) also detects incomplete turn.""" + entries = _make_entries_with_lines( + [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "tool_calls": [_tc("call_1", "tool_a"), _tc("call_2", "tool_b")], + }, + _tool_result("call_1", "tool_a", "ok"), # tc_1 completed + # tc_2 never completed — crash happened here + ] + ) + result = diagnose_transcript(entries) + assert result["status"] == "broken" + assert "missing_tool_results" in result["failure_modes"] + assert "incomplete_assistant_turn" in result["failure_modes"] + assert result["orphaned_tool_ids"] == ["call_2"] + assert len(result["incomplete_turns"]) == 1 + assert result["incomplete_turns"][0]["after_index"] == 2 + def test_recommended_action(self): """recommended_action is 'none' for healthy and 'repair' for broken.""" healthy_entries = _make_entries_with_lines( @@ -347,6 +373,90 @@ def test_orphans_roundtrip(self): re_diagnosis = diagnose_transcript(re_entries) assert re_diagnosis["status"] == "healthy" + def test_partial_orphan_roundtrip(self): + """Partial orphan (some results present) is fully healed in one pass. + + Regression test: before the fix, repair only injected a synthetic + result for the orphan but not the closing assistant response, requiring + a second diagnose-repair cycle. + """ + entries = _make_entries_with_lines( + [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "tool_calls": [_tc("call_1", "tool_a"), _tc("call_2", "tool_b")], + }, + _tool_result("call_1", "tool_a", "ok"), # tc_1 completed + # tc_2 never completed — crash happened here + ] + ) + diagnosis = diagnose_transcript(entries) + repaired = repair_transcript(entries, diagnosis) + + # Verify structure: should have synthetic result for tc_2 AND closing response + tool_results = [e for e in repaired if e.get("role") == "tool"] + assert len(tool_results) == 2 # real + synthetic + assistant_responses = [ + e + for e in repaired + if e.get("role") == "assistant" and "tool_calls" not in e + ] + assert len(assistant_responses) >= 1 # closing response injected + + # Re-diagnose: must be healthy in ONE pass + re_entries = _make_entries_with_lines(repaired) + re_diagnosis = diagnose_transcript(re_entries) + assert re_diagnosis["status"] == "healthy" + + def test_partial_orphan_roundtrip_without_line_num(self): + """Partial orphan repair works even without line_num on entries. + + The app-CLI loads transcripts without line_num. After a first repair + pass strips line_num, any subsequent diagnosis must still work via + after_index rather than after_line. + """ + entries = [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "tool_calls": [_tc("call_1", "tool_a"), _tc("call_2", "tool_b")], + }, + { + "role": "tool", + "tool_call_id": "call_1", + "name": "tool_a", + "content": "ok", + }, + ] + diagnosis = diagnose_transcript(entries) + repaired = repair_transcript(entries, diagnosis) + # Re-diagnose WITHOUT adding line_num back (simulating app-CLI path) + re_diagnosis = diagnose_transcript(repaired) + assert re_diagnosis["status"] == "healthy" + + def test_amplifier_format_tool_calls(self): + """Tool calls using Amplifier format ('tool' key) are indexed and repaired correctly.""" + entries = _make_entries_with_lines( + [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "tool_calls": [_tc_amplifier("call_1", "bash")], + }, + ] + ) + # Index should find the tool name + index = build_tool_index(entries) + assert index["tool_uses"]["call_1"]["tool_name"] == "bash" + + # Repair should use the correct name in synthetic results + diagnosis = diagnose_transcript(entries) + repaired = repair_transcript(entries, diagnosis) + synthetic = [e for e in repaired if e.get("role") == "tool"] + assert len(synthetic) == 1 + assert synthetic[0]["name"] == "bash" + def test_combined_failures_roundtrip(self): """diagnose -> repair -> re-diagnose for ordering+orphans produces healthy.""" tc_orphan = "call_orphan"