From 2d48bdfab17d83bbbcaed34956196b9749d791f7 Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Wed, 3 Dec 2025 23:30:40 -0500 Subject: [PATCH 01/13] Add dynamic agent-mode query pipeline * Introduced ContextRegistry, AgentOrchestrator, and AgentToolkit to run a budgeted investigation and synthesis loop over existing artifacts. * Wired agent mode into config and CLI with reasoning and tool limits, plus a new model and tools for search, navigational reading, grep, and section summaries. * Current agent loop is a first cut with simple prompts, no registry scoring or eviction policy, and some hardcoded paths and basic error handling. --- config/config.yaml | 7 +- src/agent/__init__.py | 5 + src/agent/context_manager.py | 50 ++++++ src/agent/orchestrator.py | 303 +++++++++++++++++++++++++++++++++ src/agent/tools.py | 316 +++++++++++++++++++++++++++++++++++ src/config.py | 5 + src/main.py | 72 +++++++- 7 files changed, 756 insertions(+), 2 deletions(-) create mode 100644 src/agent/__init__.py create mode 100644 src/agent/context_manager.py create mode 100644 src/agent/orchestrator.py create mode 100644 src/agent/tools.py diff --git a/config/config.yaml b/config/config.yaml index 0905f546..578ea5a5 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -12,4 +12,9 @@ chunk_overlap: 200 use_hyde: false hyde_max_tokens: 300 use_indexed_chunks: false -rerank_mode: "cross_encoder" \ No newline at end of file +rerank_mode: "cross_encoder" + +# Agent mode settings +use_agent: false +agent_reasoning_limit: 5 +agent_tool_limit: 20 diff --git a/src/agent/__init__.py b/src/agent/__init__.py new file mode 100644 index 00000000..285e6e13 --- /dev/null +++ b/src/agent/__init__.py @@ -0,0 +1,5 @@ +from src.agent.context_manager import ContextRegistry +from src.agent.orchestrator import AgentOrchestrator + +__all__ = ["ContextRegistry", "AgentOrchestrator"] + diff --git a/src/agent/context_manager.py b/src/agent/context_manager.py new file mode 100644 index 00000000..0271115a --- /dev/null +++ b/src/agent/context_manager.py @@ -0,0 +1,50 @@ +""" +Context registry for managing observations during agent investigation. +""" + +from typing import Dict, List, Optional + + +class ContextRegistry: + """Keyed registry for agent observations. Each tool execution returns a ref_id.""" + + def __init__(self): + self._observations: Dict[str, str] = {} + self._counter: int = 0 + + def add_observation(self, text: str) -> str: + """Add an observation and return its ref_id.""" + self._counter += 1 + ref_id = f"obs_{self._counter}" + self._observations[ref_id] = text + return ref_id + + def get(self, ref_id: str) -> Optional[str]: + """Get a single observation by ref_id.""" + return self._observations.get(ref_id) + + def get_context(self, keep_ids: List[str]) -> str: + """Return concatenated context for the specified ref_ids.""" + parts = [] + for ref_id in keep_ids: + if ref_id in self._observations: + parts.append(f"[{ref_id}]\n{self._observations[ref_id]}") + return "\n\n".join(parts) + + def prune(self, discard_ids: List[str]) -> None: + """Remove observations by ref_id.""" + for ref_id in discard_ids: + self._observations.pop(ref_id, None) + + def list_ids(self) -> List[str]: + """Return all current observation ref_ids.""" + return list(self._observations.keys()) + + def clear(self) -> None: + """Clear all observations.""" + self._observations.clear() + self._counter = 0 + + def __len__(self) -> int: + return len(self._observations) + diff --git a/src/agent/orchestrator.py b/src/agent/orchestrator.py new file mode 100644 index 00000000..226a8043 --- /dev/null +++ b/src/agent/orchestrator.py @@ -0,0 +1,303 @@ +""" +Agent orchestrator for the dynamic context budgeted agent. + +Manages the investigation loop: +1. Investigation phase: SLM queries tools and manages context registry +2. Synthesis phase: Generate final answer from curated context +""" + +import json +import re +from dataclasses import dataclass +from typing import Optional, Dict, List, Any + +from src.agent.context_manager import ContextRegistry +from src.agent.tools import AgentToolkit +from src.generator import get_llama_model, ANSWER_END + + +@dataclass +class AgentConfig: + reasoning_limit: int = 5 + tool_limit: int = 20 + max_reasoning_tokens: int = 500 + max_generation_tokens: int = 400 + + +@dataclass +class AgentStep: + thought: str + tool_name: Optional[str] + tool_args: Dict[str, Any] + context_action: Dict[str, Any] + signal: str + + +AGENT_SYSTEM_PROMPT = """You are an investigative agent that retrieves information to answer questions. + +You work in a loop: think → use a tool → observe → repeat until ready. + +{tool_descriptions} + +## Output Format (strict JSON) +```json +{{ + "thought": "Your reasoning about current state and next steps", + "tool_name": "name_of_tool or null if done", + "tool_args": {{"arg1": "value1"}}, + "context_action": {{ + "keep": ["obs_1", "obs_3"], + "discard": ["obs_2"], + "notes": "Why keeping these" + }}, + "signal": "continue or finish" +}} +``` + +## Rules +- Use search_index first to find relevant chunks +- Use read_content to get full text of promising chunks +- Use grep_text for exact matches (code, variables, specific terms) +- Signal "finish" when you have enough information +- Keep only observations needed for the final answer +- Discard observations that are irrelevant or redundant + +Current observations in registry: {observation_ids} +""" + +SYNTHESIS_PROMPT = """Based on the following curated context, answer the question concisely. + +Context: +{context} + +Question: {question} + +Answer:""" + + +def parse_agent_response(text: str) -> Optional[AgentStep]: + """Extract JSON from agent response.""" + text = text.strip() + + # Try to find JSON in markdown code block + json_match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL) + if json_match: + json_str = json_match.group(1).strip() + else: + # Try to extract raw JSON object + json_match = re.search(r"\{.*\}", text, re.DOTALL) + if json_match: + json_str = json_match.group(0) + else: + return None + + try: + data = json.loads(json_str) + except json.JSONDecodeError: + return None + + return AgentStep( + thought=data.get("thought", ""), + tool_name=data.get("tool_name"), + tool_args=data.get("tool_args", {}), + context_action=data.get("context_action", {}), + signal=data.get("signal", "continue"), + ) + + +class AgentOrchestrator: + """Main agent loop coordinating tools, context, and LLM calls.""" + + def __init__( + self, + toolkit: AgentToolkit, + model_path: str, + config: Optional[AgentConfig] = None, + ): + self.toolkit = toolkit + self.model_path = model_path + self.config = config or AgentConfig() + self.registry = ContextRegistry() + + def _build_investigation_prompt(self, question: str, history: List[str]) -> str: + """Build prompt for investigation step.""" + system = AGENT_SYSTEM_PROMPT.format( + tool_descriptions=AgentToolkit.get_tool_descriptions(), + observation_ids=self.registry.list_ids() or "[]", + ) + + history_text = "\n".join(history) if history else "No history yet." + + return f"""<|im_start|>system +{system} +<|im_end|> +<|im_start|>user +Question: {question} + +Investigation history: +{history_text} + +What's your next step? +<|im_end|> +<|im_start|>assistant +```json +""" + + def _run_reasoning_step(self, prompt: str) -> str: + """Run a single LLM call for reasoning.""" + model = get_llama_model(self.model_path) + result = model.create_completion( + prompt, + max_tokens=self.config.max_reasoning_tokens, + temperature=0.1, + stop=["```\n", "<|im_end|>"], + ) + return result["choices"][0]["text"] + + def _apply_context_action(self, action: Dict[str, Any]) -> None: + """Apply keep/discard actions to the registry.""" + discard_ids = action.get("discard", []) + if discard_ids: + self.registry.prune(discard_ids) + + def investigate(self, question: str) -> List[str]: + """ + Run investigation phase. + Returns list of observation IDs to use for synthesis. + """ + history: List[str] = [] + reasoning_count = 0 + tool_count = 0 + + while reasoning_count < self.config.reasoning_limit: + reasoning_count += 1 + + prompt = self._build_investigation_prompt(question, history) + response = self._run_reasoning_step(prompt) + + step = parse_agent_response(response) + if step is None: + history.append(f"Step {reasoning_count}: [Parse error] {response[:200]}") + continue + + history.append(f"Step {reasoning_count}: {step.thought}") + + if step.signal == "finish" or step.tool_name is None: + keep_ids = step.context_action.get("keep", self.registry.list_ids()) + return keep_ids + + if tool_count >= self.config.tool_limit: + history.append(f"Tool limit ({self.config.tool_limit}) reached.") + return step.context_action.get("keep", self.registry.list_ids()) + + tool_count += 1 + observation = self.toolkit.execute(step.tool_name, step.tool_args) + ref_id = self.registry.add_observation(observation) + history.append(f" Tool: {step.tool_name}({step.tool_args}) → {ref_id}") + + self._apply_context_action(step.context_action) + + return self.registry.list_ids() + + def synthesize(self, question: str, keep_ids: List[str]) -> str: + """Generate final answer from curated context.""" + context = self.registry.get_context(keep_ids) + + prompt = f"""<|im_start|>system +You are a helpful assistant. Answer questions based on the provided context. +<|im_end|> +<|im_start|>user +{SYNTHESIS_PROMPT.format(context=context, question=question)} +<|im_end|> +<|im_start|>assistant +""" + + model = get_llama_model(self.model_path) + result = model.create_completion( + prompt, + max_tokens=self.config.max_generation_tokens, + temperature=0.2, + stop=[ANSWER_END, "<|im_end|>"], + ) + return result["choices"][0]["text"].strip() + + def run(self, question: str) -> Dict[str, Any]: + """ + Full agent run: investigate → synthesize. + Returns dict with answer, observations, and metadata. + """ + self.registry.clear() + + keep_ids = self.investigate(question) + answer = self.synthesize(question, keep_ids) + + return { + "answer": answer, + "kept_observations": keep_ids, + "total_observations": len(self.registry), + "context_used": self.registry.get_context(keep_ids), + } + + def stream_run(self, question: str): + """ + Generator version of run for streaming output. + Yields status updates during investigation, then final answer. + """ + self.registry.clear() + + yield {"type": "status", "message": "Starting investigation..."} + + history: List[str] = [] + reasoning_count = 0 + tool_count = 0 + keep_ids = [] + + while reasoning_count < self.config.reasoning_limit: + reasoning_count += 1 + + prompt = self._build_investigation_prompt(question, history) + response = self._run_reasoning_step(prompt) + + step = parse_agent_response(response) + if step is None: + history.append(f"Step {reasoning_count}: [Parse error] {response[:100]}") + yield {"type": "status", "message": f"Parse error at step {reasoning_count}, retrying..."} + continue + + yield {"type": "thought", "step": reasoning_count, "thought": step.thought} + history.append(f"Step {reasoning_count}: {step.thought}") + + if step.signal == "finish" or step.tool_name is None: + keep_ids = step.context_action.get("keep", self.registry.list_ids()) + break + + if tool_count >= self.config.tool_limit: + keep_ids = step.context_action.get("keep", self.registry.list_ids()) + break + + tool_count += 1 + yield { + "type": "tool", + "tool_name": step.tool_name, + "tool_args": step.tool_args, + } + + observation = self.toolkit.execute(step.tool_name, step.tool_args) + ref_id = self.registry.add_observation(observation) + history.append(f" Tool: {step.tool_name} → {ref_id}") + + self._apply_context_action(step.context_action) + + if not keep_ids: + keep_ids = self.registry.list_ids() + + yield {"type": "status", "message": "Generating answer..."} + + answer = self.synthesize(question, keep_ids) + + yield { + "type": "answer", + "answer": answer, + "kept_observations": keep_ids, + } + diff --git a/src/agent/tools.py b/src/agent/tools.py new file mode 100644 index 00000000..10db4143 --- /dev/null +++ b/src/agent/tools.py @@ -0,0 +1,316 @@ +""" +Agent tools for dynamic context retrieval. + +Tools: +- IndexScout: Semantic search returning metadata (chunk IDs, relevance) +- NavigationalReader: Read chunk slices with relative offsets +- GrepSearch: Regex search across raw markdown +- SectionSummarizer: Get section content by name +""" + +import json +import re +from dataclasses import dataclass +from pathlib import Path +from typing import List, Dict, Optional, Tuple + +import faiss +import numpy as np + +from src.embedder import SentenceTransformer +from src.retriever import _get_embedder + + +@dataclass +class ChunkMetadata: + chunk_id: int + score: float + source: str + preview: str + + +@dataclass +class GrepMatch: + line_number: int + content: str + context_before: List[str] + context_after: List[str] + + +class IndexScout: + """Semantic search that returns metadata without full text.""" + + def __init__( + self, + faiss_index: faiss.Index, + chunks: List[str], + sources: List[str], + embed_model: str, + ): + self.faiss_index = faiss_index + self.chunks = chunks + self.sources = sources + self.embedder = _get_embedder(embed_model) + + def search_index(self, query: str, top_k: int = 10) -> List[ChunkMetadata]: + """Search index and return metadata list without full chunk text.""" + q_vec = self.embedder.encode([query]).astype("float32") + + if q_vec.shape[1] != self.faiss_index.d: + raise ValueError( + f"Embedding dim mismatch: index={self.faiss_index.d} vs query={q_vec.shape[1]}" + ) + + distances, indices = self.faiss_index.search(q_vec, top_k) + + results = [] + for idx, dist in zip(indices[0], distances[0]): + if idx < 0 or idx >= len(self.chunks): + continue + score = 1.0 / (1.0 + float(dist)) + preview = self.chunks[idx][:100].replace("\n", " ") + results.append( + ChunkMetadata( + chunk_id=int(idx), + score=score, + source=self.sources[idx] if idx < len(self.sources) else "unknown", + preview=preview, + ) + ) + return results + + def format_result(self, results: List[ChunkMetadata]) -> str: + """Format search results as readable text for the agent.""" + if not results: + return "No results found." + lines = [] + for r in results: + lines.append( + f"Chunk {r.chunk_id} (score={r.score:.3f}, source={r.source}): {r.preview}..." + ) + return "\n".join(lines) + + +class NavigationalReader: + """Read chunks with relative offset navigation.""" + + def __init__(self, chunks: List[str], sources: List[str]): + self.chunks = chunks + self.sources = sources + + def read_content( + self, + target_chunk_id: int, + relative_start: int = 0, + relative_end: int = 0, + ) -> Tuple[str, List[int]]: + """ + Fetch chunks[target + start : target + end + 1]. + Returns (concatenated_text, list_of_chunk_ids). + """ + start_idx = max(0, target_chunk_id + relative_start) + end_idx = min(len(self.chunks), target_chunk_id + relative_end + 1) + + if start_idx >= len(self.chunks) or end_idx <= start_idx: + return "", [] + + chunk_ids = list(range(start_idx, end_idx)) + texts = [] + for cid in chunk_ids: + src = self.sources[cid] if cid < len(self.sources) else "unknown" + texts.append(f"[Chunk {cid} | {src}]\n{self.chunks[cid]}") + + return "\n\n".join(texts), chunk_ids + + def format_result(self, text: str, chunk_ids: List[int]) -> str: + """Format for agent consumption.""" + if not text: + return "No content found for the specified range." + return f"Read chunks {chunk_ids}:\n{text}" + + +class GrepSearch: + """Regex search across raw markdown content.""" + + def __init__(self, markdown_path: str): + self.markdown_path = Path(markdown_path) + self._lines: Optional[List[str]] = None + + def _load_lines(self) -> List[str]: + if self._lines is None: + with open(self.markdown_path, "r", encoding="utf-8") as f: + self._lines = f.readlines() + return self._lines + + def grep_text( + self, + pattern: str, + context_lines: int = 2, + max_matches: int = 10, + ) -> List[GrepMatch]: + """ + Search for pattern in markdown file. + Returns matches with surrounding context. + """ + lines = self._load_lines() + compiled = re.compile(pattern, re.IGNORECASE) + matches = [] + + for i, line in enumerate(lines): + if compiled.search(line): + start = max(0, i - context_lines) + end = min(len(lines), i + context_lines + 1) + matches.append( + GrepMatch( + line_number=i + 1, + content=line.rstrip(), + context_before=[l.rstrip() for l in lines[start:i]], + context_after=[l.rstrip() for l in lines[i + 1 : end]], + ) + ) + if len(matches) >= max_matches: + break + + return matches + + def format_result(self, matches: List[GrepMatch]) -> str: + """Format grep results for agent.""" + if not matches: + return "No matches found." + lines = [] + for m in matches: + lines.append(f"Line {m.line_number}: {m.content}") + if m.context_before: + lines.append(f" Before: {' | '.join(m.context_before[-2:])}") + if m.context_after: + lines.append(f" After: {' | '.join(m.context_after[:2])}") + return "\n".join(lines) + + +class SectionSummarizer: + """Retrieve section content from extracted_sections.json.""" + + def __init__(self, sections_path: str, max_chars: int = 1000): + self.sections_path = Path(sections_path) + self.max_chars = max_chars + self._sections: Optional[List[Dict]] = None + + def _load_sections(self) -> List[Dict]: + if self._sections is None: + with open(self.sections_path, "r", encoding="utf-8") as f: + self._sections = json.load(f) + return self._sections + + def get_section_summary(self, section_name: str) -> Optional[Dict]: + """ + Find section by name (case-insensitive partial match). + Returns heading and truncated content. + """ + sections = self._load_sections() + section_name_lower = section_name.lower() + + for section in sections: + heading = section.get("heading", "") + if section_name_lower in heading.lower(): + content = section.get("content", "") + return { + "heading": heading, + "content": content[: self.max_chars], + "full_length": len(content), + } + return None + + def list_sections(self, limit: int = 20) -> List[str]: + """List available section headings.""" + sections = self._load_sections() + return [s.get("heading", "Untitled") for s in sections[:limit]] + + def format_result(self, result: Optional[Dict]) -> str: + """Format section result for agent.""" + if result is None: + return "Section not found." + truncated = "(truncated)" if result["full_length"] > self.max_chars else "" + return f"{result['heading']}\n{result['content']} {truncated}" + + +class AgentToolkit: + """Container for all agent tools, initialized from artifacts.""" + + def __init__( + self, + faiss_index: faiss.Index, + chunks: List[str], + sources: List[str], + embed_model: str, + markdown_path: str, + sections_path: str, + ): + self.index_scout = IndexScout(faiss_index, chunks, sources, embed_model) + self.reader = NavigationalReader(chunks, sources) + self.grep = GrepSearch(markdown_path) + self.summarizer = SectionSummarizer(sections_path) + + def execute(self, tool_name: str, tool_args: Dict) -> str: + """Execute a tool by name with given arguments.""" + if tool_name == "search_index": + results = self.index_scout.search_index( + query=tool_args["query"], + top_k=tool_args.get("top_k", 10), + ) + return self.index_scout.format_result(results) + + elif tool_name == "read_content": + text, chunk_ids = self.reader.read_content( + target_chunk_id=tool_args["target_chunk_id"], + relative_start=tool_args.get("relative_start", 0), + relative_end=tool_args.get("relative_end", 0), + ) + return self.reader.format_result(text, chunk_ids) + + elif tool_name == "grep_text": + matches = self.grep.grep_text( + pattern=tool_args["pattern"], + context_lines=tool_args.get("context_lines", 2), + max_matches=tool_args.get("max_matches", 10), + ) + return self.grep.format_result(matches) + + elif tool_name == "get_section_summary": + result = self.summarizer.get_section_summary( + section_name=tool_args["section_name"] + ) + return self.summarizer.format_result(result) + + elif tool_name == "list_sections": + sections = self.summarizer.list_sections( + limit=tool_args.get("limit", 20) + ) + return "\n".join(sections) + + else: + raise ValueError(f"Unknown tool: {tool_name}") + + @staticmethod + def get_tool_descriptions() -> str: + """Return tool descriptions for the agent prompt.""" + return """Available tools: + +1. search_index(query: str, top_k: int = 10) + - Semantic search returning chunk metadata (IDs, scores, sources, previews) + - Use to find relevant sections before reading full content + +2. read_content(target_chunk_id: int, relative_start: int = 0, relative_end: int = 0) + - Read chunks with relative offsets from target + - Example: target=100, start=-1, end=2 reads chunks 99-102 + +3. grep_text(pattern: str, context_lines: int = 2, max_matches: int = 10) + - Regex search across raw markdown + - Use for exact phrases, variable names, specific terms + +4. get_section_summary(section_name: str) + - Get section content by heading name (partial match) + - Returns truncated content for overview + +5. list_sections(limit: int = 20) + - List available section headings""" + diff --git a/src/config.py b/src/config.py index af173cd1..d19275f3 100644 --- a/src/config.py +++ b/src/config.py @@ -48,6 +48,11 @@ class RAGConfig: extracted_index_path: os.PathLike = "data/extracted_index.json" page_to_chunk_map_path: os.PathLike = "index/sections/textbook_index_page_to_chunk_map.json" + # agent mode + use_agent: bool = False + agent_reasoning_limit: int = 5 + agent_tool_limit: int = 20 + # ---------- factory + validation ---------- @classmethod def from_yaml(cls, path: os.PathLike) -> RAGConfig: diff --git a/src/main.py b/src/main.py index 998d3a11..e7ba807c 100644 --- a/src/main.py +++ b/src/main.py @@ -21,6 +21,9 @@ from rich.console import Console from rich.markdown import Markdown +from src.agent.tools import AgentToolkit +from src.agent.orchestrator import AgentOrchestrator, AgentConfig + ANSWER_NOT_FOUND = "I'm sorry, but I don't have enough information to answer that question." def parse_args() -> argparse.Namespace: @@ -276,6 +279,70 @@ def get_keywords(question: str) -> list: keywords = [word.strip('.,!?()[]') for word in words if word not in stopwords] return keywords + +def run_agent_chat_session(args: argparse.Namespace, cfg: RAGConfig): + """Agent-based chat with dynamic context management.""" + console = Console() + + print("Welcome to Tokensmith (Agent Mode)! Initializing...") + artifacts_dir = cfg.get_artifacts_directory() + faiss_index, bm25_index, chunks, sources, meta = load_artifacts( + artifacts_dir=artifacts_dir, + index_prefix=args.index_prefix + ) + + model_path = args.model_path or cfg.gen_model + + toolkit = AgentToolkit( + faiss_index=faiss_index, + chunks=chunks, + sources=sources, + embed_model=cfg.embed_model, + markdown_path="data/book_with_pages.md", + sections_path="data/extracted_sections.json", + ) + + agent_config = AgentConfig( + reasoning_limit=cfg.agent_reasoning_limit, + tool_limit=cfg.agent_tool_limit, + max_generation_tokens=cfg.max_gen_tokens, + ) + + orchestrator = AgentOrchestrator( + toolkit=toolkit, + model_path=model_path, + config=agent_config, + ) + + print("Initialization complete. Agent mode active.") + print("Type 'exit' or 'quit' to end the session.") + + while True: + try: + q = input("\nAsk > ").strip() + if not q: + continue + if q.lower() in {"exit", "quit"}: + print("Goodbye!") + break + + console.print("\n[dim]Investigating...[/dim]") + for event in orchestrator.stream_run(q): + if event["type"] == "thought": + console.print(f"[dim]Step {event['step']}: {event['thought']}[/dim]") + elif event["type"] == "tool": + console.print(f"[cyan] → {event['tool_name']}[/cyan]") + elif event["type"] == "answer": + console.print("\n[bold cyan]==================== ANSWER ====================[/bold cyan]\n") + console.print(Markdown(event["answer"])) + console.print("\n[bold cyan]=================================================[/bold cyan]\n") + console.print(f"[dim]Used observations: {event['kept_observations']}[/dim]") + + except KeyboardInterrupt: + print("\nGoodbye!") + break + + def run_chat_session(args: argparse.Namespace, cfg: RAGConfig): """ Initializes artifacts and runs the main interactive chat loop. @@ -373,7 +440,10 @@ def main(): if args.mode == "index": run_index_mode(args, cfg) elif args.mode == "chat": - run_chat_session(args, cfg) + if cfg.use_agent: + run_agent_chat_session(args, cfg) + else: + run_chat_session(args, cfg) if __name__ == "__main__": From b84ef4607f888adfccef8807551b79d5c902c578 Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Wed, 3 Dec 2025 23:41:18 -0500 Subject: [PATCH 02/13] Add logging --- src/agent/__init__.py | 3 +- src/agent/logger.py | 99 +++++++++++++++++++++++++++++++++++++++ src/agent/orchestrator.py | 48 +++++++++++++++++-- 3 files changed, 144 insertions(+), 6 deletions(-) create mode 100644 src/agent/logger.py diff --git a/src/agent/__init__.py b/src/agent/__init__.py index 285e6e13..e2b1f96c 100644 --- a/src/agent/__init__.py +++ b/src/agent/__init__.py @@ -1,5 +1,6 @@ from src.agent.context_manager import ContextRegistry from src.agent.orchestrator import AgentOrchestrator +from src.agent.logger import AgentLogger -__all__ = ["ContextRegistry", "AgentOrchestrator"] +__all__ = ["ContextRegistry", "AgentOrchestrator", "AgentLogger"] diff --git a/src/agent/logger.py b/src/agent/logger.py new file mode 100644 index 00000000..3b640889 --- /dev/null +++ b/src/agent/logger.py @@ -0,0 +1,99 @@ +""" +Logging for agent pipeline - captures all LLM inputs/outputs. +""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Optional + + +class AgentLogger: + """Logs all LLM interactions in the agent pipeline.""" + + def __init__(self, session_id: Optional[str] = None): + self.session_id = session_id or datetime.now().strftime("%Y%m%d_%H%M%S") + self.logs_dir = Path("logs") / "agent" + self.logs_dir.mkdir(parents=True, exist_ok=True) + self.log_file = self.logs_dir / f"agent_{self.session_id}.jsonl" + self.step_count = 0 + self.query_count = 0 + + def _write(self, data: Dict[str, Any]) -> None: + data["timestamp"] = datetime.now().isoformat() + with open(self.log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(data, ensure_ascii=False) + "\n") + + def log_session_start(self, config: Dict[str, Any]) -> None: + self._write({ + "event": "session_start", + "session_id": self.session_id, + "config": config, + }) + + def log_query_start(self, question: str) -> None: + self.query_count += 1 + self.step_count = 0 + self._write({ + "event": "query_start", + "query_id": self.query_count, + "question": question, + }) + + def log_reasoning_step( + self, + prompt: str, + response: str, + parsed_step: Optional[Dict[str, Any]], + ) -> None: + self.step_count += 1 + self._write({ + "event": "reasoning_step", + "query_id": self.query_count, + "step": self.step_count, + "prompt": prompt, + "response": response, + "parsed": parsed_step, + "parse_success": parsed_step is not None, + }) + + def log_tool_execution( + self, + tool_name: str, + tool_args: Dict[str, Any], + result: str, + ref_id: str, + ) -> None: + self._write({ + "event": "tool_execution", + "query_id": self.query_count, + "step": self.step_count, + "tool_name": tool_name, + "tool_args": tool_args, + "result": result, + "ref_id": ref_id, + }) + + def log_synthesis( + self, + prompt: str, + response: str, + keep_ids: list, + ) -> None: + self._write({ + "event": "synthesis", + "query_id": self.query_count, + "prompt": prompt, + "response": response, + "keep_ids": keep_ids, + }) + + def log_query_complete(self, answer: str, metadata: Dict[str, Any]) -> None: + self._write({ + "event": "query_complete", + "query_id": self.query_count, + "total_steps": self.step_count, + "answer": answer, + "metadata": metadata, + }) + diff --git a/src/agent/orchestrator.py b/src/agent/orchestrator.py index 226a8043..d0f6175a 100644 --- a/src/agent/orchestrator.py +++ b/src/agent/orchestrator.py @@ -13,6 +13,7 @@ from src.agent.context_manager import ContextRegistry from src.agent.tools import AgentToolkit +from src.agent.logger import AgentLogger from src.generator import get_llama_model, ANSWER_END @@ -118,6 +119,12 @@ def __init__( self.model_path = model_path self.config = config or AgentConfig() self.registry = ContextRegistry() + self.logger = AgentLogger() + self.logger.log_session_start({ + "model_path": model_path, + "reasoning_limit": self.config.reasoning_limit, + "tool_limit": self.config.tool_limit, + }) def _build_investigation_prompt(self, question: str, history: List[str]) -> str: """Build prompt for investigation step.""" @@ -199,11 +206,10 @@ def investigate(self, question: str) -> List[str]: return self.registry.list_ids() - def synthesize(self, question: str, keep_ids: List[str]) -> str: - """Generate final answer from curated context.""" + def _build_synthesis_prompt(self, question: str, keep_ids: List[str]) -> str: + """Build prompt for synthesis step.""" context = self.registry.get_context(keep_ids) - - prompt = f"""<|im_start|>system + return f"""<|im_start|>system You are a helpful assistant. Answer questions based on the provided context. <|im_end|> <|im_start|>user @@ -212,6 +218,9 @@ def synthesize(self, question: str, keep_ids: List[str]) -> str: <|im_start|>assistant """ + def synthesize(self, question: str, keep_ids: List[str]) -> str: + """Generate final answer from curated context.""" + prompt = self._build_synthesis_prompt(question, keep_ids) model = get_llama_model(self.model_path) result = model.create_completion( prompt, @@ -221,6 +230,20 @@ def synthesize(self, question: str, keep_ids: List[str]) -> str: ) return result["choices"][0]["text"].strip() + def _synthesize_with_logging(self, question: str, keep_ids: List[str]) -> str: + """Synthesize with logging.""" + prompt = self._build_synthesis_prompt(question, keep_ids) + model = get_llama_model(self.model_path) + result = model.create_completion( + prompt, + max_tokens=self.config.max_generation_tokens, + temperature=0.2, + stop=[ANSWER_END, "<|im_end|>"], + ) + response = result["choices"][0]["text"].strip() + self.logger.log_synthesis(prompt, response, keep_ids) + return response + def run(self, question: str) -> Dict[str, Any]: """ Full agent run: investigate → synthesize. @@ -259,6 +282,15 @@ def stream_run(self, question: str): response = self._run_reasoning_step(prompt) step = parse_agent_response(response) + parsed_dict = { + "thought": step.thought, + "tool_name": step.tool_name, + "tool_args": step.tool_args, + "context_action": step.context_action, + "signal": step.signal, + } if step else None + self.logger.log_reasoning_step(prompt, response, parsed_dict) + if step is None: history.append(f"Step {reasoning_count}: [Parse error] {response[:100]}") yield {"type": "status", "message": f"Parse error at step {reasoning_count}, retrying..."} @@ -284,6 +316,7 @@ def stream_run(self, question: str): observation = self.toolkit.execute(step.tool_name, step.tool_args) ref_id = self.registry.add_observation(observation) + self.logger.log_tool_execution(step.tool_name, step.tool_args, observation, ref_id) history.append(f" Tool: {step.tool_name} → {ref_id}") self._apply_context_action(step.context_action) @@ -293,7 +326,12 @@ def stream_run(self, question: str): yield {"type": "status", "message": "Generating answer..."} - answer = self.synthesize(question, keep_ids) + answer = self._synthesize_with_logging(question, keep_ids) + + self.logger.log_query_complete(answer, { + "kept_observations": keep_ids, + "total_observations": len(self.registry), + }) yield { "type": "answer", From b7e22d040e188b25e38162f1e6fff2a107ad7bd3 Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Wed, 3 Dec 2025 23:49:15 -0500 Subject: [PATCH 03/13] add full obs context instead of preview --- src/agent/orchestrator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/agent/orchestrator.py b/src/agent/orchestrator.py index d0f6175a..512d29e3 100644 --- a/src/agent/orchestrator.py +++ b/src/agent/orchestrator.py @@ -128,15 +128,21 @@ def __init__( def _build_investigation_prompt(self, question: str, history: List[str]) -> str: """Build prompt for investigation step.""" + obs_ids = self.registry.list_ids() + current_obs = self.registry.get_context(obs_ids) if obs_ids else "None yet." + system = AGENT_SYSTEM_PROMPT.format( tool_descriptions=AgentToolkit.get_tool_descriptions(), - observation_ids=self.registry.list_ids() or "[]", + observation_ids=obs_ids or "[]", ) history_text = "\n".join(history) if history else "No history yet." return f"""<|im_start|>system {system} + +Current observations: +{current_obs} <|im_end|> <|im_start|>user Question: {question} From a4ee92bbf6cc21be425d8bce4791f6b22a27b5e4 Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Wed, 3 Dec 2025 21:48:46 -0500 Subject: [PATCH 04/13] add index based scoring --- src/config.py | 1 + src/retriever.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/config.py b/src/config.py index d19275f3..e0976e6f 100644 --- a/src/config.py +++ b/src/config.py @@ -31,6 +31,7 @@ class RAGConfig: # generation max_gen_tokens: int = 400 gen_model: str = "models/qwen2.5-1.5b-instruct-q5_k_m.gguf" + temperature: float = 0.7 # testing system_prompt_mode: str = "baseline" diff --git a/src/retriever.py b/src/retriever.py index 32ac7a80..e314b76d 100644 --- a/src/retriever.py +++ b/src/retriever.py @@ -267,7 +267,6 @@ def _lemmatize_word(word: str, lemmatizer) -> str: @staticmethod def _extract_keywords(query: str) -> List[str]: """Extract keywords from query by removing stopwords and lemmatizing.""" - stopwords = { "the", "is", "at", "which", "on", "for", "a", "an", "and", "or", "in", "to", "of", "by", "with", "that", "this", "it", "as", "are", "was", @@ -282,4 +281,4 @@ def _extract_keywords(query: str) -> List[str]: if not cleaned or cleaned in stopwords: continue keywords.append(IndexKeywordRetriever._lemmatize_word(cleaned, lemmatizer)) - return keywords \ No newline at end of file + return keywords From 1521a8532acb96cc9404b5ebf68d58029ab89a3c Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Thu, 4 Dec 2025 12:33:35 -0500 Subject: [PATCH 05/13] fix merge bugs --- src/retriever.py | 3 +++ tests/test_benchmarks.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/retriever.py b/src/retriever.py index e314b76d..01105916 100644 --- a/src/retriever.py +++ b/src/retriever.py @@ -213,6 +213,8 @@ def __init__(self, extracted_index_path: os.PathLike, page_to_chunk_map_path: os self.phrase_to_pages = {} self.token_to_phrases = {} + print(f"IndexKeywordRetriever initialized with {len(self.phrase_to_pages)} phrases and {len(self.token_to_phrases)} tokens") + if os.path.exists(page_to_chunk_map_path): with open(page_to_chunk_map_path, 'r') as f: self.page_to_chunk_map = json.load(f) @@ -233,6 +235,7 @@ def get_scores(self, query: str, pool_size: int, chunks: List[str]) -> Dict[int, # Get all phrases containing this keyword token matching_phrases = self.token_to_phrases[keyword] + print(f'Found {len(matching_phrases)} matching phrases for keyword: {keyword}') for phrase in matching_phrases: page_numbers = self.phrase_to_pages[phrase] diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 36f76c6a..e3eefc0b 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -210,6 +210,9 @@ def get_tokensmith_answer(question, config, golden_chunks=None): use_indexed_chunks=config.get("use_indexed_chunks", False), extracted_index_path=config.get("extracted_index_path", "data/extracted_index.json"), page_to_chunk_map_path=config.get("page_to_chunk_map_path", "index/sections/textbook_index_page_to_chunk_map.json"), + use_agent=config.get("use_agent", False), + agent_reasoning_limit=config.get("agent_reasoning_limit", 5), + agent_tool_limit=config.get("agent_tool_limit", 20), ) # Print status From 4459ebbd8722b830952543acfab46896f8482e36 Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Thu, 4 Dec 2025 15:53:11 -0500 Subject: [PATCH 06/13] WIP summary generation --- src/agent/generate_summaries.py | 130 ++++++++++++++++++++ src/agent/tools.py | 203 ++++++++++++++++++-------------- src/main.py | 2 +- 3 files changed, 243 insertions(+), 92 deletions(-) create mode 100644 src/agent/generate_summaries.py diff --git a/src/agent/generate_summaries.py b/src/agent/generate_summaries.py new file mode 100644 index 00000000..d073b9f7 --- /dev/null +++ b/src/agent/generate_summaries.py @@ -0,0 +1,130 @@ +""" +Generate dense summaries using sliding window recursive approach. +""" + +import json +from pathlib import Path +from typing import List, Dict +from llama_cpp import Llama + +SUMMARY_MODEL_PATH = "models/Qwen3-4B-Instruct-2507-Q5_K_M.gguf" +CHUNK_SIZE = 3000 +OVERLAP = 500 +MAX_SUMMARY_TOKENS = 150 + + +def load_sections(sections_path: str) -> List[Dict]: + with open(sections_path, "r", encoding="utf-8") as f: + return json.load(f) + + +def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]: + """Split text into overlapping chunks.""" + if len(text) <= chunk_size: + return [text] + + chunks = [] + start = 0 + while start < len(text): + end = start + chunk_size + chunk = text[start:end] + chunks.append(chunk) + if end >= len(text): + break + start = end - overlap + return chunks + + +def summarize_chunk(text: str, model: Llama) -> str: + """Generate dense summary of a single chunk.""" + prompt = f"""<|im_start|>system +You are a technical summarizer. Extract only key facts and concepts. Be dense and factual. No fluff. +<|im_end|> +<|im_start|>user +Summarize this text. Include only essential information: + +{text} +<|im_end|> +<|im_start|>assistant +""" + result = model.create_completion( + prompt, + max_tokens=MAX_SUMMARY_TOKENS, + temperature=0.0, + stop=["<|im_end|>"], + ) + return result["choices"][0]["text"].strip() + + +def summarize_recursive(text: str, model: Llama) -> str: + """Recursively summarize large text using sliding window.""" + if len(text) <= CHUNK_SIZE: + return summarize_chunk(text, model) + + chunks = chunk_text(text, CHUNK_SIZE, OVERLAP) + chunk_summaries = [] + + for chunk in chunks: + summary = summarize_chunk(chunk, model) + chunk_summaries.append(summary) + + combined = "\n\n".join(chunk_summaries) + + if len(combined) <= CHUNK_SIZE: + return summarize_chunk(combined, model) + + return summarize_recursive(combined, model) + + +def generate_all_summaries( + sections_path: str, + output_path: str, + model_path: str = SUMMARY_MODEL_PATH, +) -> None: + """Generate summaries for all sections using recursive sliding window.""" + if not Path(model_path).exists(): + raise FileNotFoundError(f"Model not found: {model_path}") + + model = Llama(model_path=model_path, n_ctx=4096, verbose=False) + + sections = load_sections(sections_path) + summaries = [] + + for i, section in enumerate(sections, 1): + heading = section.get("heading", "") + content = section.get("content", "") + + if not content: + summaries.append({ + "heading": heading, + "summary": "", + "content_length": 0, + }) + continue + + print(f"[{i}/{len(sections)}] {heading[:60]}") + + summary = summarize_recursive(content, model) + + summaries.append({ + "heading": heading, + "summary": summary, + "content_length": len(content), + }) + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w", encoding="utf-8") as f: + json.dump(summaries, f, indent=2, ensure_ascii=False) + + print(f"\nGenerated {len(summaries)} summaries → {output_path}") + + +if __name__ == "__main__": + generate_all_summaries( + sections_path="data/extracted_sections.json", + output_path="data/section_summaries.json", + model_path=SUMMARY_MODEL_PATH, + ) + diff --git a/src/agent/tools.py b/src/agent/tools.py index 10db4143..02b730d3 100644 --- a/src/agent/tools.py +++ b/src/agent/tools.py @@ -38,7 +38,7 @@ class GrepMatch: class IndexScout: - """Semantic search that returns metadata without full text.""" + """Semantic search that returns structured metadata.""" def __init__( self, @@ -68,7 +68,7 @@ def search_index(self, query: str, top_k: int = 10) -> List[ChunkMetadata]: if idx < 0 or idx >= len(self.chunks): continue score = 1.0 / (1.0 + float(dist)) - preview = self.chunks[idx][:100].replace("\n", " ") + preview = self.chunks[idx][:150].replace("\n", " ") results.append( ChunkMetadata( chunk_id=int(idx), @@ -80,14 +80,13 @@ def search_index(self, query: str, top_k: int = 10) -> List[ChunkMetadata]: return results def format_result(self, results: List[ChunkMetadata]) -> str: - """Format search results as readable text for the agent.""" + """Format as machine-readable structured output.""" if not results: return "No results found." - lines = [] - for r in results: - lines.append( - f"Chunk {r.chunk_id} (score={r.score:.3f}, source={r.source}): {r.preview}..." - ) + lines = ["Search results (use chunk_id for read_content):"] + for i, r in enumerate(results): + lines.append(f" [{i}] chunk_id={r.chunk_id} score={r.score:.3f} source={r.source}") + lines.append(f" preview: {r.preview}") return "\n".join(lines) @@ -118,15 +117,15 @@ def read_content( texts = [] for cid in chunk_ids: src = self.sources[cid] if cid < len(self.sources) else "unknown" - texts.append(f"[Chunk {cid} | {src}]\n{self.chunks[cid]}") + texts.append(f"--- Chunk {cid} (source: {src}) ---\n{self.chunks[cid]}") return "\n\n".join(texts), chunk_ids def format_result(self, text: str, chunk_ids: List[int]) -> str: """Format for agent consumption.""" if not text: - return "No content found for the specified range." - return f"Read chunks {chunk_ids}:\n{text}" + return "ERROR: No content found for specified range." + return f"Content from chunks {chunk_ids}:\n\n{text}" class GrepSearch: @@ -153,7 +152,11 @@ def grep_text( Returns matches with surrounding context. """ lines = self._load_lines() - compiled = re.compile(pattern, re.IGNORECASE) + try: + compiled = re.compile(pattern, re.IGNORECASE) + except re.error as e: + raise ValueError(f"Invalid regex pattern: {pattern} ({e})") + matches = [] for i, line in enumerate(lines): @@ -176,61 +179,70 @@ def grep_text( def format_result(self, matches: List[GrepMatch]) -> str: """Format grep results for agent.""" if not matches: - return "No matches found." - lines = [] + return f"No matches found for pattern." + lines = [f"Found {len(matches)} matches:"] for m in matches: - lines.append(f"Line {m.line_number}: {m.content}") + lines.append(f"\n Line {m.line_number}: {m.content}") if m.context_before: - lines.append(f" Before: {' | '.join(m.context_before[-2:])}") + for ctx in m.context_before[-2:]: + lines.append(f" (before) {ctx}") if m.context_after: - lines.append(f" After: {' | '.join(m.context_after[:2])}") + for ctx in m.context_after[:2]: + lines.append(f" (after) {ctx}") return "\n".join(lines) class SectionSummarizer: - """Retrieve section content from extracted_sections.json.""" - - def __init__(self, sections_path: str, max_chars: int = 1000): - self.sections_path = Path(sections_path) - self.max_chars = max_chars - self._sections: Optional[List[Dict]] = None - - def _load_sections(self) -> List[Dict]: - if self._sections is None: - with open(self.sections_path, "r", encoding="utf-8") as f: - self._sections = json.load(f) - return self._sections + """Retrieve section summaries from generated summaries file.""" + + def __init__(self, summaries_path: str): + self.summaries_path = Path(summaries_path) + self._summaries: Optional[List[Dict]] = None + + def _load_summaries(self) -> List[Dict]: + if self._summaries is None: + if not self.summaries_path.exists(): + raise FileNotFoundError( + f"Summaries file not found: {self.summaries_path}\n" + "Run: python -m src.agent.generate_summaries" + ) + with open(self.summaries_path, "r", encoding="utf-8") as f: + self._summaries = json.load(f) + return self._summaries def get_section_summary(self, section_name: str) -> Optional[Dict]: - """ - Find section by name (case-insensitive partial match). - Returns heading and truncated content. - """ - sections = self._load_sections() + """Find section by name and return its summary.""" + summaries = self._load_summaries() section_name_lower = section_name.lower() - for section in sections: - heading = section.get("heading", "") + for summ in summaries: + heading = summ.get("heading", "") if section_name_lower in heading.lower(): - content = section.get("content", "") return { "heading": heading, - "content": content[: self.max_chars], - "full_length": len(content), + "summary": summ.get("summary", ""), + "content_length": summ.get("content_length", 0), } return None - def list_sections(self, limit: int = 20) -> List[str]: - """List available section headings.""" - sections = self._load_sections() - return [s.get("heading", "Untitled") for s in sections[:limit]] + def list_sections(self, limit: int = 30) -> List[str]: + """List available section headings with summaries.""" + summaries = self._load_summaries() + results = [] + for s in summaries[:limit]: + heading = s.get("heading", "Untitled") + summary = s.get("summary", "") + if summary: + results.append(f"{heading}: {summary[:100]}") + else: + results.append(heading) + return results def format_result(self, result: Optional[Dict]) -> str: - """Format section result for agent.""" + """Format section summary for agent.""" if result is None: - return "Section not found." - truncated = "(truncated)" if result["full_length"] > self.max_chars else "" - return f"{result['heading']}\n{result['content']} {truncated}" + return "ERROR: Section not found. Use list_sections to see available sections." + return f"Section: {result['heading']}\nSummary: {result['summary']}\nFull length: {result['content_length']} chars" class AgentToolkit: @@ -243,52 +255,56 @@ def __init__( sources: List[str], embed_model: str, markdown_path: str, - sections_path: str, + summaries_path: str, ): self.index_scout = IndexScout(faiss_index, chunks, sources, embed_model) self.reader = NavigationalReader(chunks, sources) self.grep = GrepSearch(markdown_path) - self.summarizer = SectionSummarizer(sections_path) + self.summarizer = SectionSummarizer(summaries_path) def execute(self, tool_name: str, tool_args: Dict) -> str: """Execute a tool by name with given arguments.""" - if tool_name == "search_index": - results = self.index_scout.search_index( - query=tool_args["query"], - top_k=tool_args.get("top_k", 10), - ) - return self.index_scout.format_result(results) + try: + if tool_name == "search_index": + results = self.index_scout.search_index( + query=tool_args["query"], + top_k=tool_args.get("top_k", 10), + ) + return self.index_scout.format_result(results) - elif tool_name == "read_content": - text, chunk_ids = self.reader.read_content( - target_chunk_id=tool_args["target_chunk_id"], - relative_start=tool_args.get("relative_start", 0), - relative_end=tool_args.get("relative_end", 0), - ) - return self.reader.format_result(text, chunk_ids) + elif tool_name == "read_content": + text, chunk_ids = self.reader.read_content( + target_chunk_id=tool_args["target_chunk_id"], + relative_start=tool_args.get("relative_start", 0), + relative_end=tool_args.get("relative_end", 0), + ) + return self.reader.format_result(text, chunk_ids) - elif tool_name == "grep_text": - matches = self.grep.grep_text( - pattern=tool_args["pattern"], - context_lines=tool_args.get("context_lines", 2), - max_matches=tool_args.get("max_matches", 10), - ) - return self.grep.format_result(matches) + elif tool_name == "grep_text": + matches = self.grep.grep_text( + pattern=tool_args["pattern"], + context_lines=tool_args.get("context_lines", 2), + max_matches=tool_args.get("max_matches", 10), + ) + return self.grep.format_result(matches) - elif tool_name == "get_section_summary": - result = self.summarizer.get_section_summary( - section_name=tool_args["section_name"] - ) - return self.summarizer.format_result(result) + elif tool_name == "get_section_summary": + result = self.summarizer.get_section_summary( + section_name=tool_args["section_name"] + ) + return self.summarizer.format_result(result) - elif tool_name == "list_sections": - sections = self.summarizer.list_sections( - limit=tool_args.get("limit", 20) - ) - return "\n".join(sections) + elif tool_name == "list_sections": + sections = self.summarizer.list_sections( + limit=tool_args.get("limit", 30) + ) + return "Available sections:\n" + "\n".join(sections) + + else: + return f"ERROR: Unknown tool '{tool_name}'. Available: search_index, read_content, grep_text, get_section_summary, list_sections" - else: - raise ValueError(f"Unknown tool: {tool_name}") + except Exception as e: + return f"ERROR executing {tool_name}: {type(e).__name__}: {str(e)}" @staticmethod def get_tool_descriptions() -> str: @@ -296,21 +312,26 @@ def get_tool_descriptions() -> str: return """Available tools: 1. search_index(query: str, top_k: int = 10) - - Semantic search returning chunk metadata (IDs, scores, sources, previews) - - Use to find relevant sections before reading full content + - Returns: Structured list with chunk_id, score, source, preview + - Use chunk_id from results for read_content + - Best for: Finding relevant content by semantic similarity 2. read_content(target_chunk_id: int, relative_start: int = 0, relative_end: int = 0) - - Read chunks with relative offsets from target - - Example: target=100, start=-1, end=2 reads chunks 99-102 + - Returns: Full text of specified chunk range + - relative_start=-1, relative_end=1 reads 3 chunks (before, target, after) + - Use chunk_id from search_index results 3. grep_text(pattern: str, context_lines: int = 2, max_matches: int = 10) - - Regex search across raw markdown - - Use for exact phrases, variable names, specific terms + - Returns: Line numbers and matches with context + - Use for: Exact terms, code snippets, specific phrases + - Pattern is case-insensitive regex 4. get_section_summary(section_name: str) - - Get section content by heading name (partial match) - - Returns truncated content for overview + - Returns: AI-generated summary of section + - Use for: Quick overview before reading full content + - Partial match on section heading -5. list_sections(limit: int = 20) - - List available section headings""" +5. list_sections(limit: int = 30) + - Returns: Available section headings with brief summaries + - Use for: Understanding document structure""" diff --git a/src/main.py b/src/main.py index e7ba807c..ceb14df4 100644 --- a/src/main.py +++ b/src/main.py @@ -299,7 +299,7 @@ def run_agent_chat_session(args: argparse.Namespace, cfg: RAGConfig): sources=sources, embed_model=cfg.embed_model, markdown_path="data/book_with_pages.md", - sections_path="data/extracted_sections.json", + summaries_path="data/section_summaries.json", ) agent_config = AgentConfig( From 21d47b1596d7cd9af0305029d56e23145afe4e0a Mon Sep 17 00:00:00 2001 From: RajShah-1 Date: Fri, 5 Dec 2025 09:17:52 -0500 Subject: [PATCH 07/13] preliminary summary generator --- src/agent/generate_summaries.py | 267 +++++++++++++++++++------------- 1 file changed, 162 insertions(+), 105 deletions(-) diff --git a/src/agent/generate_summaries.py b/src/agent/generate_summaries.py index d073b9f7..05b68eff 100644 --- a/src/agent/generate_summaries.py +++ b/src/agent/generate_summaries.py @@ -1,130 +1,187 @@ """ -Generate dense summaries using sliding window recursive approach. +generate_summaries.py + +Offline summarization pipeline using a large "Thinking" model (Qwen2.5-72B or QwQ-32B). +Uses a recursive rolling window approach to maintain context across long sections. """ import json +import re from pathlib import Path -from typing import List, Dict +from typing import List, Dict, Optional from llama_cpp import Llama -SUMMARY_MODEL_PATH = "models/Qwen3-4B-Instruct-2507-Q5_K_M.gguf" -CHUNK_SIZE = 3000 -OVERLAP = 500 -MAX_SUMMARY_TOKENS = 150 - - -def load_sections(sections_path: str) -> List[Dict]: - with open(sections_path, "r", encoding="utf-8") as f: - return json.load(f) - - -def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]: - """Split text into overlapping chunks.""" - if len(text) <= chunk_size: - return [text] - - chunks = [] - start = 0 - while start < len(text): - end = start + chunk_size - chunk = text[start:end] - chunks.append(chunk) - if end >= len(text): - break - start = end - overlap - return chunks - - -def summarize_chunk(text: str, model: Llama) -> str: - """Generate dense summary of a single chunk.""" - prompt = f"""<|im_start|>system -You are a technical summarizer. Extract only key facts and concepts. Be dense and factual. No fluff. +# --- Configuration --- +# Resolve model path relative to project root (parent of src/) +_SCRIPT_DIR = Path(__file__).parent +_PROJECT_ROOT = _SCRIPT_DIR.parent.parent +MODEL_PATH = str(_PROJECT_ROOT / "models" / "Qwen3-30B-A3B-Q6_K.gguf") + +# Context settings +CTX_SIZE = 16384 # Qwen handles 32k+, P6000s have room +MAX_PARA_TOKENS = 1000 # Budget for input paragraph chunks +SUMMARY_BUDGET = 500 # Budget for the output summary +RECURSION_DEPTH = 0 # Track depth to prevent infinite loops + +class ThinkingSummarizer: + def __init__(self, model_path: str, n_ctx: int = 8192): + print(f"Loading model: {model_path}...") + # split_mode=1 (layer split) is usually best for llama.cpp on dual GPUs + # tensor_split=[24, 24] assumes equal VRAM on both P6000s + self.llm = Llama( + model_path=model_path, + n_ctx=n_ctx, + n_gpu_layers=-1, # Offload all layers to GPU + tensor_split=[24, 24], # Split evenly across 2 cards + verbose=False, + n_batch=512 + ) + + def clean_thinking_tokens(self, text: str) -> str: + """Remove ... blocks if using a reasoning model.""" + # Remove thinking blocks + text = re.sub(r'.*?', '', text, flags=re.DOTALL) + # Remove generic clutter + return text.strip() + + def generate_update(self, current_summary: str, new_text: str) -> str: + """ + Updates the running summary with new information. + """ + # Compact prompt for high-intelligence models + prompt = f"""<|im_start|>system +You are an expert technical summarizer. You are maintaining a dense, running summary of a technical document. <|im_end|> <|im_start|>user -Summarize this text. Include only essential information: +Current Summary: +{current_summary or "(None)"} + +New Content to Incorporate: +{new_text} -{text} +Task: Update the 'Current Summary' to include key information from 'New Content'. +Constraints: +1. Keep the total length under {SUMMARY_BUDGET} tokens. +2. Do not lose previous key details. +3. Output ONLY the updated summary. <|im_end|> <|im_start|>assistant """ - result = model.create_completion( - prompt, - max_tokens=MAX_SUMMARY_TOKENS, - temperature=0.0, - stop=["<|im_end|>"], - ) - return result["choices"][0]["text"].strip() - - -def summarize_recursive(text: str, model: Llama) -> str: - """Recursively summarize large text using sliding window.""" - if len(text) <= CHUNK_SIZE: - return summarize_chunk(text, model) - - chunks = chunk_text(text, CHUNK_SIZE, OVERLAP) - chunk_summaries = [] - - for chunk in chunks: - summary = summarize_chunk(chunk, model) - chunk_summaries.append(summary) + output = self.llm.create_completion( + prompt, + max_tokens=SUMMARY_BUDGET + 100, # Buffer + temperature=0.3, # Low temp for factual consistency + stop=["<|im_end|>"] + ) + result = output["choices"][0]["text"] + return self.clean_thinking_tokens(result) + + def summarize_recursive(self, text: str, current_summary: str = "") -> str: + """ + Recursively processes text chunks. + 1. If text + summary fits context, update summary. + 2. If not, split text and recurse. + """ + # Heuristic token estimation (4 chars ~= 1 token) + est_tokens = (len(text) + len(current_summary)) / 3.5 + + # Base Case: Fits in processing window + if est_tokens < (CTX_SIZE - SUMMARY_BUDGET - 1000): + return self.generate_update(current_summary, text) + + # Recursive Case: Split text in half + # Find the nearest sentence boundary in the middle + mid = len(text) // 2 + split_match = re.search(r'[.!?]\s', text[mid:]) + + if split_match: + split_idx = mid + split_match.end() + else: + split_idx = mid # Fallback hard split + + part1 = text[:split_idx] + part2 = text[split_idx:] + + # Process first half + updated_summary = self.summarize_recursive(part1, current_summary) + + # Process second half using result of first + final_summary = self.summarize_recursive(part2, updated_summary) + + return final_summary + + def process_section(self, section_text: str) -> str: + """ + Reads a section paragraph-by-paragraph to build a rolling summary. + """ + paragraphs = section_text.split('\n\n') + running_summary = "" + + # Buffer paragraphs to reduce LLM calls (batching small paras) + buffer = "" + + for i, para in enumerate(paragraphs): + para = para.strip() + if not para: + continue + + buffer += "\n" + para + + # If buffer gets large enough, process it + if len(buffer) > 1500: # ~400 tokens + print(f" > Processing chunk {i+1}/{len(paragraphs)}...") + running_summary = self.summarize_recursive(buffer, running_summary) + buffer = "" + + # Process remaining buffer + if buffer: + running_summary = self.summarize_recursive(buffer, running_summary) + + return running_summary + +def main(): + input_path = _PROJECT_ROOT / "data" / "extracted_sections.json" + output_path = _PROJECT_ROOT / "data" / "section_summaries.json" + + if not input_path.exists(): + print(f"Error: {input_path} not found. Run extraction first.") + return + + # Initialize Summarizer + summarizer = ThinkingSummarizer(MODEL_PATH, n_ctx=CTX_SIZE) + + with open(input_path, "r", encoding="utf-8") as f: + sections = json.load(f) - combined = "\n\n".join(chunk_summaries) - - if len(combined) <= CHUNK_SIZE: - return summarize_chunk(combined, model) - - return summarize_recursive(combined, model) - - -def generate_all_summaries( - sections_path: str, - output_path: str, - model_path: str = SUMMARY_MODEL_PATH, -) -> None: - """Generate summaries for all sections using recursive sliding window.""" - if not Path(model_path).exists(): - raise FileNotFoundError(f"Model not found: {model_path}") - - model = Llama(model_path=model_path, n_ctx=4096, verbose=False) - - sections = load_sections(sections_path) summaries = [] + total = len(sections) for i, section in enumerate(sections, 1): - heading = section.get("heading", "") + heading = section.get("heading", "Untitled") content = section.get("content", "") - - if not content: - summaries.append({ - "heading": heading, - "summary": "", - "content_length": 0, - }) - continue - - print(f"[{i}/{len(sections)}] {heading[:60]}") - - summary = summarize_recursive(content, model) + + print(f"\n[{i}/{total}] Summarizing: {heading} ({len(content)} chars)") + + if not content.strip(): + summary_text = "No content." + else: + try: + summary_text = summarizer.process_section(content) + except Exception as e: + print(f"Error processing section '{heading}': {e}") + summary_text = "Error generating summary." summaries.append({ "heading": heading, - "summary": summary, - "content_length": len(content), + "summary": summary_text, + "content_length": len(content) }) + + # Incremental save + with open(output_path, "w", encoding="utf-8") as f: + json.dump(summaries, f, indent=2, ensure_ascii=False) - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - with open(output_path, "w", encoding="utf-8") as f: - json.dump(summaries, f, indent=2, ensure_ascii=False) - - print(f"\nGenerated {len(summaries)} summaries → {output_path}") - + print(f"\nDone! Summaries saved to {output_path}") if __name__ == "__main__": - generate_all_summaries( - sections_path="data/extracted_sections.json", - output_path="data/section_summaries.json", - model_path=SUMMARY_MODEL_PATH, - ) - + main() \ No newline at end of file From 05027da4a7fc623adede9149a3e38a6c331162fa Mon Sep 17 00:00:00 2001 From: RajShah-1 Date: Fri, 5 Dec 2025 19:24:09 -0500 Subject: [PATCH 08/13] checkpoint summary --- .gitignore | 4 + src/agent/generate_summaries.py | 509 ++++++++++++++++++++++++++------ src/retriever.py | 3 - 3 files changed, 424 insertions(+), 92 deletions(-) diff --git a/.gitignore b/.gitignore index 96a52ddb..5369814a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,9 @@ # RajShah-1 gitignore code-dump.txt +nohup.out +out.log +ideas.md +index/ # --- Python --- __pycache__/ diff --git a/src/agent/generate_summaries.py b/src/agent/generate_summaries.py index 05b68eff..5f2e7bce 100644 --- a/src/agent/generate_summaries.py +++ b/src/agent/generate_summaries.py @@ -1,54 +1,100 @@ """ -generate_summaries.py +Offline pipeline to: +1. Generate dense section summaries using a large "Thinking" model. +2. Build a SQLite navigation index with section/paragraph one-line summaries. -Offline summarization pipeline using a large "Thinking" model (Qwen2.5-72B or QwQ-32B). -Uses a recursive rolling window approach to maintain context across long sections. +Artifacts: +- Input: data/extracted_sections.json +- Output: data/section_summaries.json + data/nav_index.sqlite3 """ import json import re +import sqlite3 from pathlib import Path -from typing import List, Dict, Optional +from typing import List, Dict + from llama_cpp import Llama -# --- Configuration --- -# Resolve model path relative to project root (parent of src/) +# ---------- Paths / Config ---------- + _SCRIPT_DIR = Path(__file__).parent _PROJECT_ROOT = _SCRIPT_DIR.parent.parent + +EXTRACTED_SECTIONS_PATH = _PROJECT_ROOT / "data" / "extracted_sections.json" +SECTION_SUMMARIES_PATH = _PROJECT_ROOT / "data" / "section_summaries.json" +NAV_DB_PATH = _PROJECT_ROOT / "data" / "nav_index.sqlite3" + MODEL_PATH = str(_PROJECT_ROOT / "models" / "Qwen3-30B-A3B-Q6_K.gguf") -# Context settings -CTX_SIZE = 16384 # Qwen handles 32k+, P6000s have room -MAX_PARA_TOKENS = 1000 # Budget for input paragraph chunks -SUMMARY_BUDGET = 500 # Budget for the output summary -RECURSION_DEPTH = 0 # Track depth to prevent infinite loops +CTX_SIZE = 40960 # matches earlier summarizer +SUMMARY_BUDGET = 500 # tokens for dense section summary +ONE_LINE_MAX_TOKENS = 64 # for single-sentence summaries +ONE_LINE_MAX_CHARS = 200 # hard cap on character length + + +# ---------- Thinking Summarizer (integrated) ---------- class ThinkingSummarizer: - def __init__(self, model_path: str, n_ctx: int = 8192): - print(f"Loading model: {model_path}...") - # split_mode=1 (layer split) is usually best for llama.cpp on dual GPUs - # tensor_split=[24, 24] assumes equal VRAM on both P6000s + """ + Large-model summarizer. + + - summarize_recursive(): dense section summary using rolling recursion. + - process_section(): rolling paragraph summarization with buffer. + - one_line(): single-sentence description using a separate prompt. + """ + + def __init__(self, model_path: str, n_ctx: int): + print(f"[summ] Loading model: {model_path}") self.llm = Llama( model_path=model_path, n_ctx=n_ctx, - n_gpu_layers=-1, # Offload all layers to GPU - tensor_split=[24, 24], # Split evenly across 2 cards + n_gpu_layers=-1, + tensor_split=[24, 24], # adjust for your GPUs verbose=False, - n_batch=512 + n_batch=512, ) + # ----- shared cleanup for thinking tags ----- + def clean_thinking_tokens(self, text: str) -> str: - """Remove ... blocks if using a reasoning model.""" - # Remove thinking blocks - text = re.sub(r'.*?', '', text, flags=re.DOTALL) - # Remove generic clutter - return text.strip() + """Strip / blocks while keeping the final answer.""" + if not text: + return "" + + flags = re.DOTALL | re.IGNORECASE + + # If there is a closing , keep everything AFTER the last one + m = re.search(r'\s*(.*)$', text, flags=flags) + if m: + text = m.group(1) + + # Same idea for and + for tag in ("thinking", "reasoning"): + m = re.search(rf'\s*(.*)$', text, flags=flags) + if m: + text = m.group(1) + + # Remove remaining bare tags + text = re.sub(r']*>', '', text, flags=flags) + text = re.sub(r']*>', '', text, flags=flags) + text = re.sub(r']*>', '', text, flags=flags) + + text = re.sub(r'\n\s*\n+', '\n\n', text) + cleaned = text.strip() + + # If we somehow end up empty, treat as an error instead of hiding it + if not cleaned: + raise RuntimeError("LLM returned empty text after cleaning thinking tokens.") + return cleaned + + # ----- dense rolling summary over long text ----- def generate_update(self, current_summary: str, new_text: str) -> str: """ - Updates the running summary with new information. + Update the running summary with new information, staying within SUMMARY_BUDGET. """ - # Compact prompt for high-intelligence models prompt = f"""<|im_start|>system You are an expert technical summarizer. You are maintaining a dense, running summary of a technical document. <|im_end|> @@ -63,125 +109,410 @@ def generate_update(self, current_summary: str, new_text: str) -> str: Constraints: 1. Keep the total length under {SUMMARY_BUDGET} tokens. 2. Do not lose previous key details. -3. Output ONLY the updated summary. +3. Output ONLY the updated summary. Do NOT include any thinking tags, reasoning blocks, or meta-commentary. <|im_end|> <|im_start|>assistant """ output = self.llm.create_completion( prompt, - max_tokens=SUMMARY_BUDGET + 100, # Buffer - temperature=0.3, # Low temp for factual consistency - stop=["<|im_end|>"] + max_tokens=SUMMARY_BUDGET + 100, + temperature=0.3, + stop=["<|im_end|>"], ) - result = output["choices"][0]["text"] - return self.clean_thinking_tokens(result) + raw = output["choices"][0]["text"] + return self.clean_thinking_tokens(raw) def summarize_recursive(self, text: str, current_summary: str = "") -> str: """ - Recursively processes text chunks. + Recursively process text chunks. 1. If text + summary fits context, update summary. 2. If not, split text and recurse. """ - # Heuristic token estimation (4 chars ~= 1 token) est_tokens = (len(text) + len(current_summary)) / 3.5 - # Base Case: Fits in processing window if est_tokens < (CTX_SIZE - SUMMARY_BUDGET - 1000): return self.generate_update(current_summary, text) - - # Recursive Case: Split text in half - # Find the nearest sentence boundary in the middle + + # Split roughly in half on a sentence boundary mid = len(text) // 2 split_match = re.search(r'[.!?]\s', text[mid:]) - if split_match: split_idx = mid + split_match.end() else: - split_idx = mid # Fallback hard split + split_idx = mid part1 = text[:split_idx] part2 = text[split_idx:] - # Process first half updated_summary = self.summarize_recursive(part1, current_summary) - - # Process second half using result of first final_summary = self.summarize_recursive(part2, updated_summary) - return final_summary def process_section(self, section_text: str) -> str: """ Reads a section paragraph-by-paragraph to build a rolling summary. """ - paragraphs = section_text.split('\n\n') + paragraphs = section_text.split("\n\n") running_summary = "" - - # Buffer paragraphs to reduce LLM calls (batching small paras) buffer = "" - + for i, para in enumerate(paragraphs): para = para.strip() if not para: continue buffer += "\n" + para - - # If buffer gets large enough, process it - if len(buffer) > 1500: # ~400 tokens - print(f" > Processing chunk {i+1}/{len(paragraphs)}...") + + if len(buffer) > 8000: # ~2200 tokens + print(f" [summ] Processing chunk {i + 1}/{len(paragraphs)}...") running_summary = self.summarize_recursive(buffer, running_summary) buffer = "" - # Process remaining buffer if buffer: running_summary = self.summarize_recursive(buffer, running_summary) - return running_summary + return self.clean_thinking_tokens(running_summary) -def main(): - input_path = _PROJECT_ROOT / "data" / "extracted_sections.json" - output_path = _PROJECT_ROOT / "data" / "section_summaries.json" - - if not input_path.exists(): - print(f"Error: {input_path} not found. Run extraction first.") - return + # ----- one-line summary ----- - # Initialize Summarizer + def one_line(self, text: str) -> str: + """ + Ask the model for a single-sentence one-liner. + + No fallback. If the call fails or yields nothing, an exception is raised. + """ + if not text.strip(): + raise ValueError("Cannot create one-line summary from empty text.") + + prompt = f"""<|im_start|>system +You write concise, single-sentence descriptions of technical text. +Output exactly one sentence, under {ONE_LINE_MAX_CHARS} characters. +No bullet points, no lists, no markdown, no explanations. +<|im_end|> +<|im_start|>user +Text: +{text} +<|im_end|> +<|im_start|>assistant +""" + output = self.llm.create_completion( + prompt, + max_tokens=ONE_LINE_MAX_TOKENS, + temperature=0.3, + stop=["<|im_end|>"], + ) + raw = output["choices"][0]["text"] + cleaned = self.clean_thinking_tokens(raw) + cleaned = " ".join(cleaned.split()) + + if len(cleaned) > ONE_LINE_MAX_CHARS: + cleaned = cleaned[: ONE_LINE_MAX_CHARS - 3] + "..." + + if not cleaned: + raise RuntimeError("LLM returned empty one-line summary.") + return cleaned + + +# ---------- JSON Helpers ---------- + +def load_sections() -> List[Dict]: + """Load raw sections with heading + content.""" + with open(EXTRACTED_SECTIONS_PATH, "r", encoding="utf-8") as f: + return json.load(f) + + +def split_paragraphs(text: str) -> List[str]: + """Simple paragraph splitter using double newlines.""" + paras = [p.strip() for p in text.split("\n\n")] + return [p for p in paras if p] + + +# ---------- SQLite Schema ---------- + +def init_db(conn: sqlite3.Connection) -> None: + """Create tables, dropping any existing ones.""" + cur = conn.cursor() + cur.executescript( + """ + DROP TABLE IF EXISTS paragraphs; + DROP TABLE IF EXISTS sections; + + CREATE TABLE sections ( + id INTEGER PRIMARY KEY, + heading TEXT NOT NULL, + section_summary TEXT, + one_line_summary TEXT, + prev_section_id INTEGER, + next_section_id INTEGER, + num_paragraphs INTEGER, + content_length INTEGER + ); + + CREATE TABLE paragraphs ( + id INTEGER PRIMARY KEY, + section_id INTEGER NOT NULL, + para_index INTEGER NOT NULL, + one_line_summary TEXT, + raw_text TEXT, + FOREIGN KEY(section_id) REFERENCES sections(id) + ); + + CREATE INDEX idx_paragraphs_section + ON paragraphs(section_id, para_index); + """ + ) + conn.commit() + +# ------------------ Reprocessing Single Section ------------------ + +def reprocess_single_section(section_id: int) -> None: + """ + Re-run summarization + navigation indexing for a single section id. + + Steps: + 1) Recompute dense summary for the section. + 2) Recompute one-line summary for the section. + 3) Recompute one-line summaries for all paragraphs in that section. + 4) Update nav_index.sqlite3 (sections + paragraphs for this section only). + 5) Update section_summaries.json for this section only. + + Assumes: + - extracted_sections.json is the source of truth. + - nav_index.sqlite3 and section_summaries.json already exist and cover all sections. + """ + # --- Load base sections (source of truth) --- + if not EXTRACTED_SECTIONS_PATH.exists(): + raise FileNotFoundError( + f"{EXTRACTED_SECTIONS_PATH} not found. Run extraction first." + ) + sections = load_sections() + total = len(sections) + + if section_id < 1 or section_id > total: + raise ValueError(f"section_id {section_id} out of range (1..{total})") + + sec = sections[section_id - 1] + heading = sec.get("heading", "Untitled") + content = sec.get("content", "") or "" + content_length = len(content) + + print(f"\n[reprocess] Section {section_id}/{total}: {heading} ({content_length} chars)") + + # --- Initialize summarizer (big model) --- summarizer = ThinkingSummarizer(MODEL_PATH, n_ctx=CTX_SIZE) - with open(input_path, "r", encoding="utf-8") as f: - sections = json.load(f) + if not content.strip(): + raise ValueError(f"Section {section_id} has no content; nothing to reprocess.") + + # --- 1) Dense section summary --- + section_summary = summarizer.process_section(content) + + # --- 2) One-line section summary (from dense summary) --- + one_line_section = summarizer.one_line(section_summary) + + # --- 3) Paragraphs + their one-liners --- + paras = split_paragraphs(content) + num_paras = len(paras) + + # prev / next IDs are purely sequential, as in the original pipeline + prev_id = section_id - 1 if section_id > 1 else None + next_id = section_id + 1 if section_id < total else None + + # --- 4) Update SQLite nav index (no schema reset) --- + if not NAV_DB_PATH.exists(): + raise FileNotFoundError( + f"{NAV_DB_PATH} not found. Run the full nav-index pipeline once first." + ) + + conn = sqlite3.connect(NAV_DB_PATH) + cur = conn.cursor() + + # Make sure the expected tables exist + cur.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='sections';" + ) + if cur.fetchone() is None: + conn.close() + raise RuntimeError("SQLite DB missing 'sections' table.") + + cur.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='paragraphs';" + ) + if cur.fetchone() is None: + conn.close() + raise RuntimeError("SQLite DB missing 'paragraphs' table.") + + # Delete existing rows for this section + cur.execute("DELETE FROM paragraphs WHERE section_id = ?", (section_id,)) + cur.execute("DELETE FROM sections WHERE id = ?", (section_id,)) + + # Insert updated section row + cur.execute( + """ + INSERT INTO sections ( + id, heading, section_summary, one_line_summary, + prev_section_id, next_section_id, num_paragraphs, content_length + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + section_id, + heading, + section_summary, + one_line_section, + prev_id, + next_id, + num_paras, + content_length, + ), + ) + + # Insert updated paragraphs + for p_idx, para_text in enumerate(paras): + para_one_line = summarizer.one_line(para_text) + cur.execute( + """ + INSERT INTO paragraphs ( + section_id, para_index, one_line_summary, raw_text + ) VALUES (?, ?, ?, ?) + """, + (section_id, p_idx, para_one_line, para_text), + ) + + conn.commit() + conn.close() + + # --- 5) Update section_summaries.json entry for this section --- + if not SECTION_SUMMARIES_PATH.exists(): + raise FileNotFoundError( + f"{SECTION_SUMMARIES_PATH} not found. Run the full summary pipeline once first." + ) + + with open(SECTION_SUMMARIES_PATH, "r", encoding="utf-8") as f: + summaries_data = json.load(f) + + # Basic sanity: expect at least 'total' entries and matching index + if not isinstance(summaries_data, list): + raise RuntimeError("section_summaries.json is not a list.") + + if len(summaries_data) < total: + raise RuntimeError( + f"section_summaries.json only has {len(summaries_data)} entries, " + f"but extracted_sections has {total}." + ) + + summaries_data[section_id - 1] = { + "heading": heading, + "summary": section_summary, + "content_length": content_length, + } + + with open(SECTION_SUMMARIES_PATH, "w", encoding="utf-8") as f: + json.dump(summaries_data, f, indent=2, ensure_ascii=False) + + print(f"[reprocess] Updated section {section_id} in nav_index.sqlite3 and section_summaries.json") + + +# ---------- Main Pipeline ---------- + +def main() -> None: + if not EXTRACTED_SECTIONS_PATH.exists(): + raise FileNotFoundError( + f"{EXTRACTED_SECTIONS_PATH} not found. Run extraction first." + ) - summaries = [] + NAV_DB_PATH.parent.mkdir(parents=True, exist_ok=True) + + sections = load_sections() total = len(sections) - for i, section in enumerate(sections, 1): - heading = section.get("heading", "Untitled") - content = section.get("content", "") - - print(f"\n[{i}/{total}] Summarizing: {heading} ({len(content)} chars)") - + summarizer = ThinkingSummarizer(MODEL_PATH, n_ctx=CTX_SIZE) + conn = sqlite3.connect(NAV_DB_PATH) + init_db(conn) + cur = conn.cursor() + + section_summaries_output: List[Dict] = [] + + for idx, sec in enumerate(sections, start=1): + heading = sec.get("heading", "Untitled") + content = sec.get("content", "") or "" + content_length = len(content) + + print(f"\n[{idx}/{total}] Section: {heading} ({content_length} chars)") + if not content.strip(): - summary_text = "No content." + section_summary = "" + one_line_section = "" + num_paras = 0 else: - try: - summary_text = summarizer.process_section(content) - except Exception as e: - print(f"Error processing section '{heading}': {e}") - summary_text = "Error generating summary." - - summaries.append({ - "heading": heading, - "summary": summary_text, - "content_length": len(content) - }) - - # Incremental save - with open(output_path, "w", encoding="utf-8") as f: - json.dump(summaries, f, indent=2, ensure_ascii=False) - - print(f"\nDone! Summaries saved to {output_path}") + # 1) Dense section summary (rolling, recursive) + section_summary = summarizer.process_section(content) + + # 2) One-line section summary from the dense summary + one_line_section = summarizer.one_line(section_summary) + + # 3) Paragraphs + their one-liners + paras = split_paragraphs(content) + num_paras = len(paras) + + # prev / next IDs are purely sequential + prev_id = idx - 1 if idx > 1 else None + next_id = idx + 1 if idx < total else None + + # Insert section row + cur.execute( + """ + INSERT INTO sections ( + id, heading, section_summary, one_line_summary, + prev_section_id, next_section_id, num_paragraphs, content_length + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + idx, + heading, + section_summary, + one_line_section, + prev_id, + next_id, + num_paras, + content_length, + ), + ) + + # Insert paragraphs + if content.strip(): + paras = split_paragraphs(content) + for p_idx, para_text in enumerate(paras): + para_one_line = summarizer.one_line(para_text) + cur.execute( + """ + INSERT INTO paragraphs ( + section_id, para_index, one_line_summary, raw_text + ) VALUES (?, ?, ?, ?) + """, + (idx, p_idx, para_one_line, para_text), + ) + + conn.commit() + + # Collect for section_summaries.json + section_summaries_output.append( + { + "heading": heading, + "summary": section_summary, + "content_length": content_length, + } + ) + + # Incremental save of summaries file + with open(SECTION_SUMMARIES_PATH, "w", encoding="utf-8") as f: + json.dump(section_summaries_output, f, indent=2, ensure_ascii=False) + + conn.close() + + print(f"\n[done] Section summaries written to {SECTION_SUMMARIES_PATH}") + print(f"[done] Navigation index written to {NAV_DB_PATH}") + if __name__ == "__main__": - main() \ No newline at end of file + # reprocess_single_section(42) + main() diff --git a/src/retriever.py b/src/retriever.py index 01105916..e314b76d 100644 --- a/src/retriever.py +++ b/src/retriever.py @@ -213,8 +213,6 @@ def __init__(self, extracted_index_path: os.PathLike, page_to_chunk_map_path: os self.phrase_to_pages = {} self.token_to_phrases = {} - print(f"IndexKeywordRetriever initialized with {len(self.phrase_to_pages)} phrases and {len(self.token_to_phrases)} tokens") - if os.path.exists(page_to_chunk_map_path): with open(page_to_chunk_map_path, 'r') as f: self.page_to_chunk_map = json.load(f) @@ -235,7 +233,6 @@ def get_scores(self, query: str, pool_size: int, chunks: List[str]) -> Dict[int, # Get all phrases containing this keyword token matching_phrases = self.token_to_phrases[keyword] - print(f'Found {len(matching_phrases)} matching phrases for keyword: {keyword}') for phrase in matching_phrases: page_numbers = self.phrase_to_pages[phrase] From 81a3fad36da53c6b3b1df0f47e33504f963bdc4c Mon Sep 17 00:00:00 2001 From: RajShah-1 Date: Fri, 5 Dec 2025 23:45:29 -0500 Subject: [PATCH 09/13] change default n_ctx to a higher limit --- src/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/generator.py b/src/generator.py index 7fcfea38..e6f933fb 100644 --- a/src/generator.py +++ b/src/generator.py @@ -113,7 +113,7 @@ def format_prompt(chunks, query, max_chunk_chars=400, system_prompt_mode="tutor" _LLM_CACHE = {} -def get_llama_model(model_path: str, n_ctx: int = 4096): +def get_llama_model(model_path: str, n_ctx: int = 16384): if model_path not in _LLM_CACHE: _LLM_CACHE[model_path] = Llama(model_path=model_path, n_ctx=n_ctx, From 9e1ecca83788d5948be72b5fb64f86dd03eea7b0 Mon Sep 17 00:00:00 2001 From: RajShah-1 Date: Fri, 5 Dec 2025 23:56:19 -0500 Subject: [PATCH 10/13] integrate agent-mode in test framework --- tests/conftest.py | 27 +++++++++ tests/test_benchmarks.py | 127 ++++++++++++++++++++++++++------------- 2 files changed, 113 insertions(+), 41 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 57a98140..86839a59 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -89,6 +89,18 @@ def pytest_addoption(parser): default=None, help="System prompt mode (overrides config)" ) + group.addoption( + "--use-agent", + action="store_true", + default=None, + help="Enable agent mode (overrides config)" + ) + group.addoption( + "--no-agent", + action="store_true", + default=None, + help="Disable agent mode (overrides config)" + ) # === Testing Options === group.addoption( @@ -176,8 +188,23 @@ def config(pytestconfig): # Query Enhancement (HyDE) "use_hyde": cfg.get("use_hyde", False), "hyde_max_tokens": cfg.get("hyde_max_tokens", 100), + + # Agent Mode + "agent_reasoning_limit": cfg.get("agent_reasoning_limit", 5), + "agent_tool_limit": cfg.get("agent_tool_limit", 20), } + # Handle agent mode + use_agent_cli = pytestconfig.getoption("--use-agent") + no_agent_cli = pytestconfig.getoption("--no-agent") + + if use_agent_cli: + merged_config["use_agent"] = True + elif no_agent_cli: + merged_config["use_agent"] = False + else: + merged_config["use_agent"] = cfg.get("use_agent", False) + # Handle enable/disable chunks disable_chunks_cli = pytestconfig.getoption("--disable-chunks") diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index e3eefc0b..4abbd55c 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -58,6 +58,10 @@ def print_test_config(config, scorer): print(f" Chunks Enabled: {not config['disable_chunks']}") print(f" Golden Chunks: {config['use_golden_chunks']}") print(f" HyDE Enabled: {config.get('use_hyde', False)}") + print(f" Agent Mode: {config.get('use_agent', False)}") + if config.get('use_agent', False): + print(f" • Reasoning Limit: {config.get('agent_reasoning_limit', 5)}") + print(f" • Tool Limit: {config.get('agent_tool_limit', 20)}") print(f" Output Mode: {config['output_mode']}") print(f" Metrics: {', '.join(active_metrics)}") print(f"{'='*60}\n") @@ -147,6 +151,9 @@ def run_benchmark(benchmark, config, results_dir, scorer): "system_prompt_mode": config["system_prompt_mode"], "disable_chunks": config["disable_chunks"], "use_golden_chunks": config["use_golden_chunks"], + "use_agent": config.get("use_agent", False), + "agent_reasoning_limit": config.get("agent_reasoning_limit", 5), + "agent_tool_limit": config.get("agent_tool_limit", 20), } } @@ -173,7 +180,6 @@ def get_tokensmith_answer(question, config, golden_chunks=None): Returns: tuple: (Generated answer, chunks_info list, hyde_query) """ - from src.main import get_answer from src.instrumentation.logging import init_logger, get_logger from src.config import RAGConfig from src.retriever import BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts @@ -216,7 +222,9 @@ def get_tokensmith_answer(question, config, golden_chunks=None): ) # Print status - if golden_chunks and config["use_golden_chunks"]: + if config.get("use_agent", False): + print(f" 🤖 Agent mode enabled") + elif golden_chunks and config["use_golden_chunks"]: print(f" 📌 Using {len(golden_chunks)} golden chunks") elif config["disable_chunks"]: print(f" 📭 No chunks (baseline mode)") @@ -235,47 +243,84 @@ def get_tokensmith_answer(question, config, golden_chunks=None): index_prefix=config["index_prefix"] ) - retrievers = [ - FAISSRetriever(faiss_index, cfg.embed_model), - BM25Retriever(bm25_index) - ] - - # Add index keyword retriever if weight > 0 - if cfg.ranker_weights.get("index_keywords", 0) > 0: - retrievers.append( - IndexKeywordRetriever(cfg.extracted_index_path, cfg.page_to_chunk_map_path) + # Check if agent mode is enabled + if config.get("use_agent", False): + # Use agent orchestrator path + from src.agent.tools import AgentToolkit + from src.agent.orchestrator import AgentOrchestrator, AgentConfig + + toolkit = AgentToolkit( + faiss_index=faiss_index, + chunks=chunks, + sources=sources, + embed_model=cfg.embed_model, + markdown_path="data/book_with_pages.md", + summaries_path="data/section_summaries.json", ) - - ranker = EnsembleRanker( - ensemble_method=cfg.ensemble_method, - weights=cfg.ranker_weights, - rrf_k=int(cfg.rrf_k) - ) - - # Package artifacts for reuse - artifacts = { - "chunks": chunks, - "sources": sources, - "retrievers": retrievers, - "ranker": ranker - } - - result = get_answer( - question=question, - cfg=cfg, - args=args, - logger=logger, - artifacts=artifacts, - console=None, - golden_chunks=golden_chunks, - is_test_mode=True - ) - - # Handle return value (answer, chunks_info, hyde_query) or just answer - if isinstance(result, tuple): - generated, chunks_info, hyde_query = result + + agent_config = AgentConfig( + reasoning_limit=cfg.agent_reasoning_limit, + tool_limit=cfg.agent_tool_limit, + max_generation_tokens=cfg.max_gen_tokens, + ) + + orchestrator = AgentOrchestrator( + toolkit=toolkit, + model_path=args.model_path or cfg.model_path, + config=agent_config, + ) + + # Run agent and get answer + result = orchestrator.run(question) + generated = result["answer"] + chunks_info = None # Agent mode doesn't provide chunks_info in the same format + hyde_query = None + else: - generated, chunks_info, hyde_query = result, None, None + # Use standard get_answer path + from src.main import get_answer + + retrievers = [ + FAISSRetriever(faiss_index, cfg.embed_model), + BM25Retriever(bm25_index) + ] + + # Add index keyword retriever if weight > 0 + if cfg.ranker_weights.get("index_keywords", 0) > 0: + retrievers.append( + IndexKeywordRetriever(cfg.extracted_index_path, cfg.page_to_chunk_map_path) + ) + + ranker = EnsembleRanker( + ensemble_method=cfg.ensemble_method, + weights=cfg.ranker_weights, + rrf_k=int(cfg.rrf_k) + ) + + # Package artifacts for reuse + artifacts = { + "chunks": chunks, + "sources": sources, + "retrievers": retrievers, + "ranker": ranker + } + + result = get_answer( + question=question, + cfg=cfg, + args=args, + logger=logger, + artifacts=artifacts, + console=None, + golden_chunks=golden_chunks, + is_test_mode=True + ) + + # Handle return value (answer, chunks_info, hyde_query) or just answer + if isinstance(result, tuple): + generated, chunks_info, hyde_query = result + else: + generated, chunks_info, hyde_query = result, None, None # Clean answer - extract up to end token if present generated = clean_answer(generated) From 8cd2e9fa09398e2541681bad32459f1e01c9607b Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Mon, 19 Jan 2026 15:50:09 -0500 Subject: [PATCH 11/13] bug fix: regressions for test-suite --- config/config.yaml | 2 +- tests/benchmarks.yaml | 70 -------------------------------- tests/metrics/async_llm_judge.py | 19 +++++++-- tests/metrics/chunk_retrieval.py | 4 ++ tests/test_benchmarks.py | 20 +++++---- 5 files changed, 31 insertions(+), 84 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 578ea5a5..41c20466 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -15,6 +15,6 @@ use_indexed_chunks: false rerank_mode: "cross_encoder" # Agent mode settings -use_agent: false +use_agent: true agent_reasoning_limit: 5 agent_tool_limit: 20 diff --git a/tests/benchmarks.yaml b/tests/benchmarks.yaml index ae106cb8..152dd291 100644 --- a/tests/benchmarks.yaml +++ b/tests/benchmarks.yaml @@ -1,74 +1,4 @@ benchmarks: - - id: "aggregation_grouping" - question: "How do aggregation with grouping work?" - expected_answer: "Aggregation partitions tuples by grouping attributes and applies functions like sum, avg, min, max to each group, producing one result per group." - keywords: ["aggregation", "grouping", "generalized projection", "ignore nulls", "attribute renaming", "duplicates"] - similarity_threshold: 0.8 - ideal_retrieved_chunks: [1039, 1040, 1496, 1497, 714] - - - id: "acid_properties" - question: "What are the ACID properties of transactions?" - expected_answer: "Atomicity ensures a transaction's actions are all-or-nothing, enforced by abort/rollback and recovery that can undo partial effects; consistency requires each transaction to preserve database integrity when run alone and relies on the scheduler to admit serializable, recoverable, and preferably cascadeless schedules; isolation makes concurrent executions equivalent to some serial order, commonly achieved with two-phase locking variants that prevent reads of uncommitted data; durability guarantees committed effects persist across crashes via logging to stable storage and redo on restart;" - keywords: ["atomicity", "consistency", "isolation", "durability", "recoverable", "checkpoints", "two-phase locking", "two-phase commit"] - similarity_threshold: 0.82 - ideal_retrieved_chunks: [1143, 1142, 1145, 1146, 1148] - - - id: "bptree" - question: "How does a B+ tree index organize keys and support search, insert, and delete, and why is it preferred over binary trees for disk-based access" - expected_answer: "B+-trees match node size to a disk page, giving very high fan-out and a shallow, height-balanced tree, so searches/updates require few page I/Os." - keywords: ["fan-out", "leaf linkage", "merge", "balanced height", "reduced height"] - similarity_threshold: 0.78 - ideal_retrieved_chunks: [908, 909, 940, 937, 938] - - - id: "fd_normalization" - question: "What are functional dependencies?" - expected_answer: "A functional dependency X -> Y asserts that tuples agreeing on X must agree on Y." - keywords: ["common key", "lossless join", "dependency preservation", "normalization", "superkey"] - similarity_threshold: 0.7 - ideal_retrieved_chunks: [438, 439, 463, 451, 484] - - - id: "sql_isolation" - question: "What isolation guarantees does SQL provide by default?" - expected_answer: "Serializable. In the SQL standard, the default isolation level is Serializable, which guarantees that the outcome of concurrently executing transactions is equivalent to some serial (one-at-a-time) order of those transactions—thereby preventing dirty reads, nonrepeatable reads, and phantoms." - keywords: ["serializable", "read committed", "repeatable read", "dirty read", "nonrepeatable read", "phantom", "two-phase locking", "predicate locking"] - similarity_threshold: 0.7 - ideal_retrieved_chunks: [1173, 1174, 1142, 1143, 1172] - - - id: "primary_foreign_keys" - question: "Explain primary keys and foreign keys" - expected_answer: "A primary key is a set of one or more attributes that uniquely identifies each tuple in a relation, chosen from candidate keys which are minimal superkeys; primary key attributes are underlined in schema diagrams and cannot have null values. A foreign key is a set of attributes in one relation (the referencing relation) that references the primary key of another relation (the referenced relation), establishing a referential integrity constraint that requires values in the foreign key to match values in the referenced primary key, thereby linking related data across tables." - keywords: ["primary key", "foreign key", "unique identifier", "referential integrity", "candidate key", "superkey"] - similarity_threshold: 0.72 - ideal_retrieved_chunks: [90, 91, 119, 93, 94] - - - id: "database_schema" - question: "What is a database schema" - expected_answer: "A database schema is the overall logical design and structure of the database, analogous to variable declarations in a program, defining the relations, their attributes, data types, and constraints including primary keys and foreign keys. The schema remains relatively stable over time, while a database instance represents the actual collection of data stored at a particular moment, with values that change as information is inserted, deleted, or modified." - keywords: ["database schema", "logical design", "structure", "database instance", "relations", "attributes", "constraints"] - similarity_threshold: 0.70 - ideal_retrieved_chunks: [49, 50, 51, 250, 60] - - - id: "book_authors" - question: "Tell me about the authors of the book" - expected_answer: "The authors of Database System Concepts Seventh Edition are Abraham Silberschatz, a Professor at Yale University and former Bell Labs vice president who is an ACM and IEEE fellow; Henry F. Korth, a Professor at Lehigh University who previously worked at Bell Labs and is also an ACM and IEEE fellow; and S. Sudarshan, the Subrao M. Nilekani Chair Professor at the Indian Institute of Technology Bombay who received his Ph.D. from the University of Wisconsin and is an ACM fellow, with research focusing on query processing and optimization." - keywords: ["Abraham Silberschatz", "Henry Korth", "Sudarshan", "Yale", "Lehigh", "IIT Bombay", "Bell Labs"] - similarity_threshold: 0.65 - ideal_retrieved_chunks: [1, 2, 0, 3, 22] - - - id: "aries_atomicity" - question: "How does the recovery manager use ARIES to ensure atomicity" - expected_answer: "The ARIES recovery algorithm ensures atomicity by maintaining a write-ahead log where all updates are recorded before being applied to the database, with each log record containing transaction ID, data item, old value, and new value. During normal operation, log records are written to stable storage before the transaction commits; if a transaction aborts or the system crashes, the recovery manager uses the log to undo uncommitted transactions by applying the old values in reverse order, ensuring that partial effects of incomplete transactions are completely rolled back and the all-or-nothing property of atomicity is preserved." - keywords: ["ARIES", "write-ahead log", "log records", "undo", "rollback", "stable storage", "atomicity", "recovery"] - similarity_threshold: 0.75 - ideal_retrieved_chunks: [1355, 1356, 1358, 1353, 1359] - - - id: "oltp_vs_analytics" - question: "Contrast the goals of Online Transaction Processing and data analytics" - expected_answer: "Online Transaction Processing (OLTP) supports a large number of concurrent users performing small, fast transactions that retrieve and update relatively small amounts of data with requirements for high throughput, low latency, and immediate consistency, typically using normalized schemas optimized for transactional integrity. Data analytics, in contrast, processes large volumes of historical data to draw conclusions and infer patterns for business intelligence and decision support, involving complex queries that scan and aggregate data across many records, often using denormalized schemas like star schemas in data warehouses optimized for read-heavy analytical workloads rather than transactional updates." - keywords: ["OLTP", "online transaction processing", "data analytics", "business intelligence", "decision support", "throughput", "data warehouse", "transactional", "analytical"] - similarity_threshold: 0.73 - ideal_retrieved_chunks: [33, 34, 738, 739, 741] - - id: "lossy_decomposition" question: "Show me what happens during a lossy decomposition" expected_answer: "A lossy decomposition occurs when a relation R is decomposed into smaller relations R1 and R2 such that joining them back together produces spurious tuples not present in the original relation, resulting in loss of information about which attribute combinations actually existed. This happens when the intersection of R1 and R2 does not form a superkey for either relation, violating the lossless-join condition; the natural join of the decomposed relations generates extra tuples from invalid combinations, making it impossible to reconstruct the original data accurately, which is why database design insists that all decompositions must be lossless." diff --git a/tests/metrics/async_llm_judge.py b/tests/metrics/async_llm_judge.py index 4a819ef4..f72fca20 100644 --- a/tests/metrics/async_llm_judge.py +++ b/tests/metrics/async_llm_judge.py @@ -2,6 +2,7 @@ from pathlib import Path from datetime import datetime import json +import os import threading import time from concurrent.futures import ThreadPoolExecutor @@ -11,6 +12,12 @@ from tests.metrics.base import MetricBase from tests.metrics.llm_judge import GradingResult + +def _has_api_key() -> bool: + """Check if Google API key is available.""" + return bool(os.environ.get("GOOGLE_API_KEY") or os.environ.get("GOOGLE_GENAI_API_KEY")) + + # Shared state for async grading _results_lock = threading.Lock() _grading_results: Dict[str, Dict] = {} @@ -39,10 +46,14 @@ def __init__(self, log_dir: Optional[Path] = None): self.log_dir = Path(log_dir) self.log_dir.mkdir(parents=True, exist_ok=True) self.results_file = self.log_dir / "async_llm_results.json" + self._available = _has_api_key() - # Initialize client once - if not _initialized: - _lazy_init() + # Initialize client once only if API key is available + if self._available and not _initialized: + try: + _lazy_init() + except Exception: + self._available = False @property def name(self) -> str: @@ -53,7 +64,7 @@ def weight(self) -> float: return 0.0 def is_available(self) -> bool: - return True + return self._available def calculate(self, answer: str, expected: str, keywords: Optional[List[str]] = None) -> float: """ diff --git a/tests/metrics/chunk_retrieval.py b/tests/metrics/chunk_retrieval.py index 8ae7f53c..240dbe4f 100644 --- a/tests/metrics/chunk_retrieval.py +++ b/tests/metrics/chunk_retrieval.py @@ -16,6 +16,10 @@ def name(self) -> str: def calculate(self, ideal_retrieved_chunks: List[int], retrieved_chunks) -> float: + # Handle None/empty inputs (e.g. agent mode doesn't provide chunk info) + if not retrieved_chunks or not ideal_retrieved_chunks: + return 0.0 + print("ideal_retrieved_chunks: ", ideal_retrieved_chunks) print("retrieved_chunks: ", [chunk["chunk_id"] for chunk in retrieved_chunks]) found_chunks = [chunk["chunk_id"] for chunk in retrieved_chunks if chunk["chunk_id"] in ideal_retrieved_chunks] diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 4abbd55c..c84b3735 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -194,19 +194,21 @@ def get_tokensmith_answer(question, config, golden_chunks=None): ) # Create RAGConfig from our test config + # Note: chunk_config is computed in __post_init__ from chunk_mode/chunk_size/chunk_overlap cfg = RAGConfig( - chunk_config=RAGConfig.get_chunk_config(config), + chunk_mode=config.get("chunk_mode", "recursive_sections"), + chunk_size=config.get("recursive_chunk_size", 2000), + chunk_overlap=config.get("recursive_overlap", 200), top_k=config.get("top_k", 5), - pool_size=config.get("pool_size", 60), + num_candidates=config.get("pool_size", 60), embed_model=config.get("embed_model"), - ensemble_method=config.get("retrieval_method", "rrf"), - rrf_k=60, + ensemble_method=config.get("ensemble_method", "rrf"), + rrf_k=config.get("rrf_k", 60), ranker_weights=config.get("ranker_weights", {"faiss": 0.6, "bm25": 0.4}), - rerank_mode=config.get("rerank_mode", "none"), - seg_filter=config.get("seg_filter", None), + rerank_mode=config.get("rerank_mode", ""), + gen_model=config.get("model_path"), system_prompt_mode=config.get("system_prompt_mode", "baseline"), max_gen_tokens=config.get("max_gen_tokens", 400), - model_path=config.get("model_path"), disable_chunks=config.get("disable_chunks", False), use_golden_chunks=config.get("use_golden_chunks", False), output_mode=config.get("output_mode", "html"), @@ -237,8 +239,8 @@ def get_tokensmith_answer(question, config, golden_chunks=None): logger = get_logger() # Run the query through the main pipeline - artifacts_dir = cfg.make_artifacts_directory() - faiss_index, bm25_index, chunks, sources = load_artifacts( + artifacts_dir = cfg.get_artifacts_directory() + faiss_index, bm25_index, chunks, sources, metadata = load_artifacts( artifacts_dir=artifacts_dir, index_prefix=config["index_prefix"] ) From 813beb6ecae3be5fa59261bb474f382c14935903 Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Wed, 28 Jan 2026 21:51:21 -0500 Subject: [PATCH 12/13] refactor to multi-file --- src/agent/__init__.py | 5 +- src/agent/context.py | 117 +++++++ src/agent/context_manager.py | 50 --- src/agent/generate_summaries.py | 549 +++----------------------------- src/agent/llm.py | 19 ++ src/agent/logger.py | 83 +---- src/agent/orchestrator.py | 503 +++++++++++------------------ src/agent/prompts.py | 71 +++++ src/agent/summarizer.py | 70 ++++ src/agent/summary_db.py | 35 ++ src/agent/toolkit.py | 94 ++++++ src/agent/tools.py | 337 -------------------- src/agent/tools/__init__.py | 0 src/agent/tools/read.py | 28 ++ src/agent/tools/search.py | 42 +++ src/agent/tools/sections.py | 46 +++ src/agent/tools/text.py | 60 ++++ src/agent/types.py | 35 ++ src/main.py | 40 ++- tests/test_benchmarks.py | 2 +- 20 files changed, 897 insertions(+), 1289 deletions(-) create mode 100644 src/agent/context.py delete mode 100644 src/agent/context_manager.py create mode 100644 src/agent/llm.py create mode 100644 src/agent/prompts.py create mode 100644 src/agent/summarizer.py create mode 100644 src/agent/summary_db.py create mode 100644 src/agent/toolkit.py delete mode 100644 src/agent/tools.py create mode 100644 src/agent/tools/__init__.py create mode 100644 src/agent/tools/read.py create mode 100644 src/agent/tools/search.py create mode 100644 src/agent/tools/sections.py create mode 100644 src/agent/tools/text.py create mode 100644 src/agent/types.py diff --git a/src/agent/__init__.py b/src/agent/__init__.py index e2b1f96c..f3e6a216 100644 --- a/src/agent/__init__.py +++ b/src/agent/__init__.py @@ -1,6 +1,7 @@ -from src.agent.context_manager import ContextRegistry +from src.agent.context import ContextRegistry from src.agent.orchestrator import AgentOrchestrator from src.agent.logger import AgentLogger +from src.agent.toolkit import AgentToolkit -__all__ = ["ContextRegistry", "AgentOrchestrator", "AgentLogger"] +__all__ = ["ContextRegistry", "AgentOrchestrator", "AgentLogger", "AgentToolkit"] diff --git a/src/agent/context.py b/src/agent/context.py new file mode 100644 index 00000000..e61a508a --- /dev/null +++ b/src/agent/context.py @@ -0,0 +1,117 @@ +from typing import Dict, List, Optional, Any +from src.agent.types import ObservationMetadata + +class ContextBudgetExceeded(Exception): + """Raised when adding an observation would exceed the max context budget.""" + pass + +class ContextRegistry: + """ + Keyed registry for agent observations. + Enforces a rough token budget (approx 3.5 chars per token). + """ + + def __init__(self, max_tokens: int = 8000): + self._observations: Dict[str, str] = {} + self._metadata: Dict[str, ObservationMetadata] = {} + self._counter: int = 0 + self._max_tokens = max_tokens + self._current_chars = 0 + + @property + def current_tokens(self) -> int: + return int(self._current_chars / 3.5) + + @property + def status(self) -> Dict[str, Any]: + used = self.current_tokens + return { + "used": used, + "total": self._max_tokens, + "usage_percent": (used / self._max_tokens) * 100 if self._max_tokens > 0 else 0.0, + "count": len(self._observations) + } + + def _check_budget(self, text: str): + new_tokens = int(len(text) / 3.5) + if (self.current_tokens + new_tokens) > self._max_tokens: + raise ContextBudgetExceeded( + f"Cannot add observation ({new_tokens} tokens). " + f"Registry full: {self.current_tokens}/{self._max_tokens} tokens used." + ) + + def add(self, text: str, step: Optional[int] = None) -> str: + self._check_budget(text) + self._counter += 1 + ref_id = f"obs_{self._counter}" + self._observations[ref_id] = text + self._metadata[ref_id] = ObservationMetadata(added_in_step=step) + self._current_chars += len(text) + return ref_id + + def remove(self, ref_id: str, step: Optional[int] = None) -> bool: + if ref_id in self._observations: + text = self._observations.pop(ref_id) + if ref_id in self._metadata: + self._metadata[ref_id].removed_in_step = step + else: + self._metadata[ref_id] = ObservationMetadata(removed_in_step=step) + self._current_chars -= len(text) + return True + return False + + def replace(self, ref_id: str, new_text: str, step: Optional[int] = None) -> None: + if ref_id not in self._observations: + raise KeyError(f"Observation {ref_id} not found.") + + old_text = self._observations[ref_id] + diff_chars = len(new_text) - len(old_text) + new_total_tokens = int((self._current_chars + diff_chars) / 3.5) + + if new_total_tokens > self._max_tokens: + raise ContextBudgetExceeded( + f"Replacement exceeds budget. Total would be {new_total_tokens}/{self._max_tokens}." + ) + + self._observations[ref_id] = new_text + if ref_id in self._metadata: + self._metadata[ref_id].replaced_in_step = step + self._metadata[ref_id].replaced_with = ref_id + self._current_chars += diff_chars + + def get(self, ref_id: str) -> Optional[str]: + return self._observations.get(ref_id) + + def list_ids(self) -> List[str]: + return list(self._observations.keys()) + + def clear(self) -> None: + self._observations.clear() + self._metadata.clear() + self._counter = 0 + self._current_chars = 0 + + def __len__(self) -> int: + return len(self._observations) + + def get_all_metadata(self) -> Dict[str, Dict[str, Any]]: + result = {} + all_ref_ids = set(self._observations.keys()) | set(self._metadata.keys()) + + for ref_id in all_ref_ids: + meta = self._metadata.get(ref_id, ObservationMetadata()) + lifecycle = [] + if meta.added_in_step is not None: lifecycle.append(f"added-in-step-{meta.added_in_step}") + if meta.removed_in_step is not None: lifecycle.append(f"removed-in-step-{meta.removed_in_step}") + if meta.replaced_in_step is not None: lifecycle.append(f"replaced-in-step-{meta.replaced_in_step}") + if meta.kept_in_final: lifecycle.append("kept-in-final-content") + + result[ref_id] = { + "content": self._observations.get(ref_id), + "lifecycle": "; ".join(lifecycle) if lifecycle else "no-events", + "added_in_step": meta.added_in_step, + "removed_in_step": meta.removed_in_step, + "replaced_in_step": meta.replaced_in_step, + "kept_in_final": meta.kept_in_final, + } + return result diff --git a/src/agent/context_manager.py b/src/agent/context_manager.py deleted file mode 100644 index 0271115a..00000000 --- a/src/agent/context_manager.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -Context registry for managing observations during agent investigation. -""" - -from typing import Dict, List, Optional - - -class ContextRegistry: - """Keyed registry for agent observations. Each tool execution returns a ref_id.""" - - def __init__(self): - self._observations: Dict[str, str] = {} - self._counter: int = 0 - - def add_observation(self, text: str) -> str: - """Add an observation and return its ref_id.""" - self._counter += 1 - ref_id = f"obs_{self._counter}" - self._observations[ref_id] = text - return ref_id - - def get(self, ref_id: str) -> Optional[str]: - """Get a single observation by ref_id.""" - return self._observations.get(ref_id) - - def get_context(self, keep_ids: List[str]) -> str: - """Return concatenated context for the specified ref_ids.""" - parts = [] - for ref_id in keep_ids: - if ref_id in self._observations: - parts.append(f"[{ref_id}]\n{self._observations[ref_id]}") - return "\n\n".join(parts) - - def prune(self, discard_ids: List[str]) -> None: - """Remove observations by ref_id.""" - for ref_id in discard_ids: - self._observations.pop(ref_id, None) - - def list_ids(self) -> List[str]: - """Return all current observation ref_ids.""" - return list(self._observations.keys()) - - def clear(self) -> None: - """Clear all observations.""" - self._observations.clear() - self._counter = 0 - - def __len__(self) -> int: - return len(self._observations) - diff --git a/src/agent/generate_summaries.py b/src/agent/generate_summaries.py index 5f2e7bce..50bbc6ea 100644 --- a/src/agent/generate_summaries.py +++ b/src/agent/generate_summaries.py @@ -1,518 +1,47 @@ -""" -Offline pipeline to: -1. Generate dense section summaries using a large "Thinking" model. -2. Build a SQLite navigation index with section/paragraph one-line summaries. - -Artifacts: -- Input: data/extracted_sections.json -- Output: data/section_summaries.json - data/nav_index.sqlite3 -""" - import json -import re -import sqlite3 from pathlib import Path -from typing import List, Dict - -from llama_cpp import Llama - -# ---------- Paths / Config ---------- +from src.agent.summarizer import ThinkingSummarizer +from src.agent.summary_db import init_db, save_section_to_db -_SCRIPT_DIR = Path(__file__).parent -_PROJECT_ROOT = _SCRIPT_DIR.parent.parent - -EXTRACTED_SECTIONS_PATH = _PROJECT_ROOT / "data" / "extracted_sections.json" -SECTION_SUMMARIES_PATH = _PROJECT_ROOT / "data" / "section_summaries.json" +_PROJECT_ROOT = Path(__file__).parent.parent.parent +EXTRACTED_PATH = _PROJECT_ROOT / "data" / "extracted_sections.json" NAV_DB_PATH = _PROJECT_ROOT / "data" / "nav_index.sqlite3" - MODEL_PATH = str(_PROJECT_ROOT / "models" / "Qwen3-30B-A3B-Q6_K.gguf") -CTX_SIZE = 40960 # matches earlier summarizer -SUMMARY_BUDGET = 500 # tokens for dense section summary -ONE_LINE_MAX_TOKENS = 64 # for single-sentence summaries -ONE_LINE_MAX_CHARS = 200 # hard cap on character length - - -# ---------- Thinking Summarizer (integrated) ---------- - -class ThinkingSummarizer: - """ - Large-model summarizer. - - - summarize_recursive(): dense section summary using rolling recursion. - - process_section(): rolling paragraph summarization with buffer. - - one_line(): single-sentence description using a separate prompt. - """ - - def __init__(self, model_path: str, n_ctx: int): - print(f"[summ] Loading model: {model_path}") - self.llm = Llama( - model_path=model_path, - n_ctx=n_ctx, - n_gpu_layers=-1, - tensor_split=[24, 24], # adjust for your GPUs - verbose=False, - n_batch=512, - ) - - # ----- shared cleanup for thinking tags ----- - - def clean_thinking_tokens(self, text: str) -> str: - """Strip / blocks while keeping the final answer.""" - if not text: - return "" - - flags = re.DOTALL | re.IGNORECASE - - # If there is a closing , keep everything AFTER the last one - m = re.search(r'\s*(.*)$', text, flags=flags) - if m: - text = m.group(1) - - # Same idea for and - for tag in ("thinking", "reasoning"): - m = re.search(rf'\s*(.*)$', text, flags=flags) - if m: - text = m.group(1) - - # Remove remaining bare tags - text = re.sub(r']*>', '', text, flags=flags) - text = re.sub(r']*>', '', text, flags=flags) - text = re.sub(r']*>', '', text, flags=flags) - - text = re.sub(r'\n\s*\n+', '\n\n', text) - cleaned = text.strip() - - # If we somehow end up empty, treat as an error instead of hiding it - if not cleaned: - raise RuntimeError("LLM returned empty text after cleaning thinking tokens.") - return cleaned - - # ----- dense rolling summary over long text ----- - - def generate_update(self, current_summary: str, new_text: str) -> str: - """ - Update the running summary with new information, staying within SUMMARY_BUDGET. - """ - prompt = f"""<|im_start|>system -You are an expert technical summarizer. You are maintaining a dense, running summary of a technical document. -<|im_end|> -<|im_start|>user -Current Summary: -{current_summary or "(None)"} - -New Content to Incorporate: -{new_text} - -Task: Update the 'Current Summary' to include key information from 'New Content'. -Constraints: -1. Keep the total length under {SUMMARY_BUDGET} tokens. -2. Do not lose previous key details. -3. Output ONLY the updated summary. Do NOT include any thinking tags, reasoning blocks, or meta-commentary. -<|im_end|> -<|im_start|>assistant -""" - output = self.llm.create_completion( - prompt, - max_tokens=SUMMARY_BUDGET + 100, - temperature=0.3, - stop=["<|im_end|>"], - ) - raw = output["choices"][0]["text"] - return self.clean_thinking_tokens(raw) - - def summarize_recursive(self, text: str, current_summary: str = "") -> str: - """ - Recursively process text chunks. - 1. If text + summary fits context, update summary. - 2. If not, split text and recurse. - """ - est_tokens = (len(text) + len(current_summary)) / 3.5 - - if est_tokens < (CTX_SIZE - SUMMARY_BUDGET - 1000): - return self.generate_update(current_summary, text) - - # Split roughly in half on a sentence boundary - mid = len(text) // 2 - split_match = re.search(r'[.!?]\s', text[mid:]) - if split_match: - split_idx = mid + split_match.end() - else: - split_idx = mid - - part1 = text[:split_idx] - part2 = text[split_idx:] - - updated_summary = self.summarize_recursive(part1, current_summary) - final_summary = self.summarize_recursive(part2, updated_summary) - return final_summary - - def process_section(self, section_text: str) -> str: - """ - Reads a section paragraph-by-paragraph to build a rolling summary. - """ - paragraphs = section_text.split("\n\n") - running_summary = "" - buffer = "" - - for i, para in enumerate(paragraphs): - para = para.strip() - if not para: - continue - - buffer += "\n" + para - - if len(buffer) > 8000: # ~2200 tokens - print(f" [summ] Processing chunk {i + 1}/{len(paragraphs)}...") - running_summary = self.summarize_recursive(buffer, running_summary) - buffer = "" - - if buffer: - running_summary = self.summarize_recursive(buffer, running_summary) - - return self.clean_thinking_tokens(running_summary) - - # ----- one-line summary ----- - - def one_line(self, text: str) -> str: - """ - Ask the model for a single-sentence one-liner. - - No fallback. If the call fails or yields nothing, an exception is raised. - """ - if not text.strip(): - raise ValueError("Cannot create one-line summary from empty text.") - - prompt = f"""<|im_start|>system -You write concise, single-sentence descriptions of technical text. -Output exactly one sentence, under {ONE_LINE_MAX_CHARS} characters. -No bullet points, no lists, no markdown, no explanations. -<|im_end|> -<|im_start|>user -Text: -{text} -<|im_end|> -<|im_start|>assistant -""" - output = self.llm.create_completion( - prompt, - max_tokens=ONE_LINE_MAX_TOKENS, - temperature=0.3, - stop=["<|im_end|>"], - ) - raw = output["choices"][0]["text"] - cleaned = self.clean_thinking_tokens(raw) - cleaned = " ".join(cleaned.split()) - - if len(cleaned) > ONE_LINE_MAX_CHARS: - cleaned = cleaned[: ONE_LINE_MAX_CHARS - 3] + "..." - - if not cleaned: - raise RuntimeError("LLM returned empty one-line summary.") - return cleaned - - -# ---------- JSON Helpers ---------- - -def load_sections() -> List[Dict]: - """Load raw sections with heading + content.""" - with open(EXTRACTED_SECTIONS_PATH, "r", encoding="utf-8") as f: - return json.load(f) - - -def split_paragraphs(text: str) -> List[str]: - """Simple paragraph splitter using double newlines.""" - paras = [p.strip() for p in text.split("\n\n")] - return [p for p in paras if p] - - -# ---------- SQLite Schema ---------- - -def init_db(conn: sqlite3.Connection) -> None: - """Create tables, dropping any existing ones.""" - cur = conn.cursor() - cur.executescript( - """ - DROP TABLE IF EXISTS paragraphs; - DROP TABLE IF EXISTS sections; - - CREATE TABLE sections ( - id INTEGER PRIMARY KEY, - heading TEXT NOT NULL, - section_summary TEXT, - one_line_summary TEXT, - prev_section_id INTEGER, - next_section_id INTEGER, - num_paragraphs INTEGER, - content_length INTEGER - ); - - CREATE TABLE paragraphs ( - id INTEGER PRIMARY KEY, - section_id INTEGER NOT NULL, - para_index INTEGER NOT NULL, - one_line_summary TEXT, - raw_text TEXT, - FOREIGN KEY(section_id) REFERENCES sections(id) - ); - - CREATE INDEX idx_paragraphs_section - ON paragraphs(section_id, para_index); - """ - ) - conn.commit() - -# ------------------ Reprocessing Single Section ------------------ - -def reprocess_single_section(section_id: int) -> None: - """ - Re-run summarization + navigation indexing for a single section id. - - Steps: - 1) Recompute dense summary for the section. - 2) Recompute one-line summary for the section. - 3) Recompute one-line summaries for all paragraphs in that section. - 4) Update nav_index.sqlite3 (sections + paragraphs for this section only). - 5) Update section_summaries.json for this section only. - - Assumes: - - extracted_sections.json is the source of truth. - - nav_index.sqlite3 and section_summaries.json already exist and cover all sections. - """ - # --- Load base sections (source of truth) --- - if not EXTRACTED_SECTIONS_PATH.exists(): - raise FileNotFoundError( - f"{EXTRACTED_SECTIONS_PATH} not found. Run extraction first." - ) - sections = load_sections() - total = len(sections) - - if section_id < 1 or section_id > total: - raise ValueError(f"section_id {section_id} out of range (1..{total})") - - sec = sections[section_id - 1] - heading = sec.get("heading", "Untitled") - content = sec.get("content", "") or "" - content_length = len(content) - - print(f"\n[reprocess] Section {section_id}/{total}: {heading} ({content_length} chars)") - - # --- Initialize summarizer (big model) --- - summarizer = ThinkingSummarizer(MODEL_PATH, n_ctx=CTX_SIZE) - - if not content.strip(): - raise ValueError(f"Section {section_id} has no content; nothing to reprocess.") - - # --- 1) Dense section summary --- - section_summary = summarizer.process_section(content) - - # --- 2) One-line section summary (from dense summary) --- - one_line_section = summarizer.one_line(section_summary) - - # --- 3) Paragraphs + their one-liners --- - paras = split_paragraphs(content) - num_paras = len(paras) - - # prev / next IDs are purely sequential, as in the original pipeline - prev_id = section_id - 1 if section_id > 1 else None - next_id = section_id + 1 if section_id < total else None - - # --- 4) Update SQLite nav index (no schema reset) --- - if not NAV_DB_PATH.exists(): - raise FileNotFoundError( - f"{NAV_DB_PATH} not found. Run the full nav-index pipeline once first." - ) - - conn = sqlite3.connect(NAV_DB_PATH) - cur = conn.cursor() - - # Make sure the expected tables exist - cur.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='sections';" - ) - if cur.fetchone() is None: - conn.close() - raise RuntimeError("SQLite DB missing 'sections' table.") - - cur.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='paragraphs';" - ) - if cur.fetchone() is None: - conn.close() - raise RuntimeError("SQLite DB missing 'paragraphs' table.") - - # Delete existing rows for this section - cur.execute("DELETE FROM paragraphs WHERE section_id = ?", (section_id,)) - cur.execute("DELETE FROM sections WHERE id = ?", (section_id,)) - - # Insert updated section row - cur.execute( - """ - INSERT INTO sections ( - id, heading, section_summary, one_line_summary, - prev_section_id, next_section_id, num_paragraphs, content_length - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - section_id, - heading, - section_summary, - one_line_section, - prev_id, - next_id, - num_paras, - content_length, - ), - ) - - # Insert updated paragraphs - for p_idx, para_text in enumerate(paras): - para_one_line = summarizer.one_line(para_text) - cur.execute( - """ - INSERT INTO paragraphs ( - section_id, para_index, one_line_summary, raw_text - ) VALUES (?, ?, ?, ?) - """, - (section_id, p_idx, para_one_line, para_text), - ) - - conn.commit() - conn.close() - - # --- 5) Update section_summaries.json entry for this section --- - if not SECTION_SUMMARIES_PATH.exists(): - raise FileNotFoundError( - f"{SECTION_SUMMARIES_PATH} not found. Run the full summary pipeline once first." - ) - - with open(SECTION_SUMMARIES_PATH, "r", encoding="utf-8") as f: - summaries_data = json.load(f) - - # Basic sanity: expect at least 'total' entries and matching index - if not isinstance(summaries_data, list): - raise RuntimeError("section_summaries.json is not a list.") - - if len(summaries_data) < total: - raise RuntimeError( - f"section_summaries.json only has {len(summaries_data)} entries, " - f"but extracted_sections has {total}." - ) - - summaries_data[section_id - 1] = { - "heading": heading, - "summary": section_summary, - "content_length": content_length, - } - - with open(SECTION_SUMMARIES_PATH, "w", encoding="utf-8") as f: - json.dump(summaries_data, f, indent=2, ensure_ascii=False) - - print(f"[reprocess] Updated section {section_id} in nav_index.sqlite3 and section_summaries.json") - - -# ---------- Main Pipeline ---------- - -def main() -> None: - if not EXTRACTED_SECTIONS_PATH.exists(): - raise FileNotFoundError( - f"{EXTRACTED_SECTIONS_PATH} not found. Run extraction first." - ) - - NAV_DB_PATH.parent.mkdir(parents=True, exist_ok=True) - - sections = load_sections() - total = len(sections) - - summarizer = ThinkingSummarizer(MODEL_PATH, n_ctx=CTX_SIZE) - conn = sqlite3.connect(NAV_DB_PATH) - init_db(conn) - cur = conn.cursor() - - section_summaries_output: List[Dict] = [] - - for idx, sec in enumerate(sections, start=1): - heading = sec.get("heading", "Untitled") - content = sec.get("content", "") or "" - content_length = len(content) - - print(f"\n[{idx}/{total}] Section: {heading} ({content_length} chars)") - - if not content.strip(): - section_summary = "" - one_line_section = "" - num_paras = 0 - else: - # 1) Dense section summary (rolling, recursive) - section_summary = summarizer.process_section(content) - - # 2) One-line section summary from the dense summary - one_line_section = summarizer.one_line(section_summary) - - # 3) Paragraphs + their one-liners - paras = split_paragraphs(content) - num_paras = len(paras) - - # prev / next IDs are purely sequential - prev_id = idx - 1 if idx > 1 else None - next_id = idx + 1 if idx < total else None - - # Insert section row - cur.execute( - """ - INSERT INTO sections ( - id, heading, section_summary, one_line_summary, - prev_section_id, next_section_id, num_paragraphs, content_length - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - idx, - heading, - section_summary, - one_line_section, - prev_id, - next_id, - num_paras, - content_length, - ), - ) - - # Insert paragraphs - if content.strip(): - paras = split_paragraphs(content) - for p_idx, para_text in enumerate(paras): - para_one_line = summarizer.one_line(para_text) - cur.execute( - """ - INSERT INTO paragraphs ( - section_id, para_index, one_line_summary, raw_text - ) VALUES (?, ?, ?, ?) - """, - (idx, p_idx, para_one_line, para_text), - ) - - conn.commit() - - # Collect for section_summaries.json - section_summaries_output.append( - { - "heading": heading, - "summary": section_summary, - "content_length": content_length, - } - ) - - # Incremental save of summaries file - with open(SECTION_SUMMARIES_PATH, "w", encoding="utf-8") as f: - json.dump(section_summaries_output, f, indent=2, ensure_ascii=False) - - conn.close() - - print(f"\n[done] Section summaries written to {SECTION_SUMMARIES_PATH}") - print(f"[done] Navigation index written to {NAV_DB_PATH}") - +def process_all(): + if not EXTRACTED_PATH.exists(): + print("No extracted sections found.") + return + + with open(EXTRACTED_PATH) as f: + sections = json.load(f) + + summarizer = ThinkingSummarizer(MODEL_PATH) + init_db(NAV_DB_PATH) + + for i, sec in enumerate(sections): + print(f"Processing section {i+1}/{len(sections)}...") + text = sec.get("content", "") + if not text: continue + + # Dense summary + dense = summarizer.summarize_recursive(text) + + # One liner + one_line = summarizer.one_line(text) + + # Paragraphs + paras = [] + for p_idx, para in enumerate(text.split("\n\n")): + if not para.strip(): continue + p_summary = summarizer.one_line(para) + paras.append({ + "index": p_idx, + "text": para, + "summary": p_summary + }) + + save_section_to_db(NAV_DB_PATH, i+1, sec.get("heading", ""), dense, one_line, paras) if __name__ == "__main__": - # reprocess_single_section(42) - main() + process_all() diff --git a/src/agent/llm.py b/src/agent/llm.py new file mode 100644 index 00000000..be0502a4 --- /dev/null +++ b/src/agent/llm.py @@ -0,0 +1,19 @@ +from typing import List, Optional +from src.generator import get_llama_model + +class AgentLLM: + def __init__(self, model_path: str): + self.model_path = model_path + + def completion(self, prompt: str, max_tokens: int = 500, temperature: float = 0.1, stop: Optional[List[str]] = None) -> str: + model = get_llama_model(self.model_path) + if stop is None: + stop = ["<|im_end|>"] + + result = model.create_completion( + prompt, + max_tokens=max_tokens, + temperature=temperature, + stop=stop, + ) + return result["choices"][0]["text"] diff --git a/src/agent/logger.py b/src/agent/logger.py index 3b640889..2f2523e8 100644 --- a/src/agent/logger.py +++ b/src/agent/logger.py @@ -1,6 +1,4 @@ -""" -Logging for agent pipeline - captures all LLM inputs/outputs. -""" +"""Minimal logging for agent pipeline.""" import json from datetime import datetime @@ -9,91 +7,36 @@ class AgentLogger: - """Logs all LLM interactions in the agent pipeline.""" + """Logs agent interactions to JSONL file.""" def __init__(self, session_id: Optional[str] = None): self.session_id = session_id or datetime.now().strftime("%Y%m%d_%H%M%S") self.logs_dir = Path("logs") / "agent" self.logs_dir.mkdir(parents=True, exist_ok=True) self.log_file = self.logs_dir / f"agent_{self.session_id}.jsonl" - self.step_count = 0 - self.query_count = 0 def _write(self, data: Dict[str, Any]) -> None: - data["timestamp"] = datetime.now().isoformat() + data["ts"] = datetime.now().isoformat() with open(self.log_file, "a", encoding="utf-8") as f: f.write(json.dumps(data, ensure_ascii=False) + "\n") - def log_session_start(self, config: Dict[str, Any]) -> None: + def log_step(self, step: int, thought: str, tool_name: Optional[str], tool_args: Dict[str, Any], result: Optional[str], success: bool) -> None: + """Log a reasoning step with full thought (no truncation).""" self._write({ - "event": "session_start", - "session_id": self.session_id, - "config": config, - }) - - def log_query_start(self, question: str) -> None: - self.query_count += 1 - self.step_count = 0 - self._write({ - "event": "query_start", - "query_id": self.query_count, - "question": question, - }) - - def log_reasoning_step( - self, - prompt: str, - response: str, - parsed_step: Optional[Dict[str, Any]], - ) -> None: - self.step_count += 1 - self._write({ - "event": "reasoning_step", - "query_id": self.query_count, - "step": self.step_count, - "prompt": prompt, - "response": response, - "parsed": parsed_step, - "parse_success": parsed_step is not None, - }) - - def log_tool_execution( - self, - tool_name: str, - tool_args: Dict[str, Any], - result: str, - ref_id: str, - ) -> None: - self._write({ - "event": "tool_execution", - "query_id": self.query_count, - "step": self.step_count, + "event": "step", + "step": step, + "thought": thought, "tool_name": tool_name, "tool_args": tool_args, "result": result, - "ref_id": ref_id, + "success": success, }) - def log_synthesis( - self, - prompt: str, - response: str, - keep_ids: list, - ) -> None: - self._write({ - "event": "synthesis", - "query_id": self.query_count, - "prompt": prompt, - "response": response, - "keep_ids": keep_ids, - }) - - def log_query_complete(self, answer: str, metadata: Dict[str, Any]) -> None: + def log_query_complete(self, question: str, answer: str, registry_metadata: Dict[str, Dict[str, Any]]) -> None: + """Log query completion with full registry lifecycle.""" self._write({ "event": "query_complete", - "query_id": self.query_count, - "total_steps": self.step_count, + "question": question, "answer": answer, - "metadata": metadata, + "registry_entries": registry_metadata, }) - diff --git a/src/agent/orchestrator.py b/src/agent/orchestrator.py index 512d29e3..da4de68f 100644 --- a/src/agent/orchestrator.py +++ b/src/agent/orchestrator.py @@ -1,347 +1,218 @@ -""" -Agent orchestrator for the dynamic context budgeted agent. - -Manages the investigation loop: -1. Investigation phase: SLM queries tools and manages context registry -2. Synthesis phase: Generate final answer from curated context -""" - import json import re -from dataclasses import dataclass -from typing import Optional, Dict, List, Any - -from src.agent.context_manager import ContextRegistry -from src.agent.tools import AgentToolkit -from src.agent.logger import AgentLogger -from src.generator import get_llama_model, ANSWER_END - - -@dataclass -class AgentConfig: - reasoning_limit: int = 5 - tool_limit: int = 20 - max_reasoning_tokens: int = 500 - max_generation_tokens: int = 400 - - -@dataclass -class AgentStep: - thought: str - tool_name: Optional[str] - tool_args: Dict[str, Any] - context_action: Dict[str, Any] - signal: str - - -AGENT_SYSTEM_PROMPT = """You are an investigative agent that retrieves information to answer questions. - -You work in a loop: think → use a tool → observe → repeat until ready. - -{tool_descriptions} - -## Output Format (strict JSON) -```json -{{ - "thought": "Your reasoning about current state and next steps", - "tool_name": "name_of_tool or null if done", - "tool_args": {{"arg1": "value1"}}, - "context_action": {{ - "keep": ["obs_1", "obs_3"], - "discard": ["obs_2"], - "notes": "Why keeping these" - }}, - "signal": "continue or finish" -}} -``` - -## Rules -- Use search_index first to find relevant chunks -- Use read_content to get full text of promising chunks -- Use grep_text for exact matches (code, variables, specific terms) -- Signal "finish" when you have enough information -- Keep only observations needed for the final answer -- Discard observations that are irrelevant or redundant - -Current observations in registry: {observation_ids} -""" - -SYNTHESIS_PROMPT = """Based on the following curated context, answer the question concisely. - -Context: -{context} - -Question: {question} - -Answer:""" - - -def parse_agent_response(text: str) -> Optional[AgentStep]: - """Extract JSON from agent response.""" - text = text.strip() - - # Try to find JSON in markdown code block - json_match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL) - if json_match: - json_str = json_match.group(1).strip() - else: - # Try to extract raw JSON object - json_match = re.search(r"\{.*\}", text, re.DOTALL) - if json_match: - json_str = json_match.group(0) - else: - return None - - try: - data = json.loads(json_str) - except json.JSONDecodeError: - return None - - return AgentStep( - thought=data.get("thought", ""), - tool_name=data.get("tool_name"), - tool_args=data.get("tool_args", {}), - context_action=data.get("context_action", {}), - signal=data.get("signal", "continue"), - ) +from typing import Optional, List, Any, Generator, Dict +from src.agent.types import AgentConfig, AgentStep +from src.agent.context import ContextRegistry, ContextBudgetExceeded +from src.agent.toolkit import AgentToolkit +from src.agent.llm import AgentLLM +from src.agent.prompts import INVESTIGATION_PROMPT_TEMPLATE, SYNTHESIS_PROMPT, AGENT_SYSTEM_PROMPT class AgentOrchestrator: - """Main agent loop coordinating tools, context, and LLM calls.""" - - def __init__( - self, - toolkit: AgentToolkit, - model_path: str, - config: Optional[AgentConfig] = None, - ): + def __init__(self, toolkit: AgentToolkit, model_path: str, config: Optional[AgentConfig] = None, logger: Optional[Any] = None): self.toolkit = toolkit - self.model_path = model_path - self.config = config or AgentConfig() - self.registry = ContextRegistry() - self.logger = AgentLogger() - self.logger.log_session_start({ - "model_path": model_path, - "reasoning_limit": self.config.reasoning_limit, - "tool_limit": self.config.tool_limit, - }) + self.config = config or AgentConfig(model_path=model_path) + self.llm = AgentLLM(model_path) + self.registry = ContextRegistry(max_tokens=self.config.max_context_tokens) + self.logger = logger + + def _parse_step(self, text: str) -> Optional[AgentStep]: + text = text.strip() + match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL) or re.search(r"\{.*\}", text, re.DOTALL) + if not match: return None + + try: + data = json.loads(match.group(1 if match.lastindex else 0)) + + # Handle hallucinated "next_step": "tool(args)" format + if "next_step" in data and not data.get("tool_name"): + call_str = data["next_step"] + # Parse tool(arg=val) + m_call = re.match(r"(\w+)\((.*)\)", call_str) + if m_call: + tool_name = m_call.group(1) + args_str = m_call.group(2) + args = {} + # Simple arg parser for key=value or key="value" + # This is brittle but handles the example seen + for arg_pair in args_str.split(","): + if "=" in arg_pair: + k, v = arg_pair.split("=", 1) + k = k.strip() + v = v.strip().strip("'").strip('"') + # Try to convert to int/float + try: + if "." in v: v = float(v) + else: v = int(v) + except ValueError: + pass + args[k] = v + + return AgentStep( + thought=data.get("thought", f"Decided to call {tool_name}"), + tool_name=tool_name, + tool_args=args, + context_action={}, + signal="continue" + ) + + return AgentStep( + thought=data.get("thought", "Recovering from simplified output..."), + tool_name=data.get("tool_name"), + tool_args=data.get("tool_args", {}), + context_action=data.get("context_action", {}), + signal=data.get("signal", "continue"), + ) + except Exception as e: + print(f"JSON Parse Error: {e}") + return None - def _build_investigation_prompt(self, question: str, history: List[str]) -> str: - """Build prompt for investigation step.""" + def _build_prompt(self, question: str, history: List[str]) -> str: obs_ids = self.registry.list_ids() - current_obs = self.registry.get_context(obs_ids) if obs_ids else "None yet." - - system = AGENT_SYSTEM_PROMPT.format( - tool_descriptions=AgentToolkit.get_tool_descriptions(), - observation_ids=obs_ids or "[]", - ) - - history_text = "\n".join(history) if history else "No history yet." - - return f"""<|im_start|>system -{system} - -Current observations: -{current_obs} -<|im_end|> -<|im_start|>user -Question: {question} - -Investigation history: -{history_text} - -What's your next step? -<|im_end|> -<|im_start|>assistant -```json -""" - - def _run_reasoning_step(self, prompt: str) -> str: - """Run a single LLM call for reasoning.""" - model = get_llama_model(self.model_path) - result = model.create_completion( - prompt, - max_tokens=self.config.max_reasoning_tokens, - temperature=0.1, - stop=["```\n", "<|im_end|>"], - ) - return result["choices"][0]["text"] - - def _apply_context_action(self, action: Dict[str, Any]) -> None: - """Apply keep/discard actions to the registry.""" - discard_ids = action.get("discard", []) - if discard_ids: - self.registry.prune(discard_ids) - - def investigate(self, question: str) -> List[str]: - """ - Run investigation phase. - Returns list of observation IDs to use for synthesis. - """ - history: List[str] = [] - reasoning_count = 0 - tool_count = 0 - - while reasoning_count < self.config.reasoning_limit: - reasoning_count += 1 - - prompt = self._build_investigation_prompt(question, history) - response = self._run_reasoning_step(prompt) - - step = parse_agent_response(response) - if step is None: - history.append(f"Step {reasoning_count}: [Parse error] {response[:200]}") - continue - - history.append(f"Step {reasoning_count}: {step.thought}") - - if step.signal == "finish" or step.tool_name is None: - keep_ids = step.context_action.get("keep", self.registry.list_ids()) - return keep_ids - - if tool_count >= self.config.tool_limit: - history.append(f"Tool limit ({self.config.tool_limit}) reached.") - return step.context_action.get("keep", self.registry.list_ids()) - - tool_count += 1 - observation = self.toolkit.execute(step.tool_name, step.tool_args) - ref_id = self.registry.add_observation(observation) - history.append(f" Tool: {step.tool_name}({step.tool_args}) → {ref_id}") - - self._apply_context_action(step.context_action) - - return self.registry.list_ids() - - def _build_synthesis_prompt(self, question: str, keep_ids: List[str]) -> str: - """Build prompt for synthesis step.""" - context = self.registry.get_context(keep_ids) - return f"""<|im_start|>system -You are a helpful assistant. Answer questions based on the provided context. -<|im_end|> -<|im_start|>user -{SYNTHESIS_PROMPT.format(context=context, question=question)} -<|im_end|> -<|im_start|>assistant -""" - - def synthesize(self, question: str, keep_ids: List[str]) -> str: - """Generate final answer from curated context.""" - prompt = self._build_synthesis_prompt(question, keep_ids) - model = get_llama_model(self.model_path) - result = model.create_completion( - prompt, - max_tokens=self.config.max_generation_tokens, - temperature=0.2, - stop=[ANSWER_END, "<|im_end|>"], + + # Format active context + active_ctx = [] + read_chunk_ids = set() + for ref_id in obs_ids: + content = self.registry.get(ref_id) + if content: + active_ctx.append(f"[{ref_id}]\n{content}") + if "Content from chunks" in content or "Chunk " in content: + # Simple heuristic to track seen chunks + matches = re.findall(r"Chunk (\d+)", content) + read_chunk_ids.update(int(c) for c in matches) + + full_context = "\n\n".join(active_ctx) if active_ctx else "No observations yet." + + # Lifecycle metadata + meta = self.registry.get_all_metadata() + lifecycle = [f"{k}: {v['lifecycle']}" for k, v in meta.items() if v['lifecycle'] != "no-events"] + + status = self.registry.status + system_text = AGENT_SYSTEM_PROMPT.format( + tool_descriptions=self.toolkit.get_tool_descriptions(), + observation_ids=str(obs_ids), + budget_status=f"{status['used']}/{status['total']} tokens" ) - return result["choices"][0]["text"].strip() - - def _synthesize_with_logging(self, question: str, keep_ids: List[str]) -> str: - """Synthesize with logging.""" - prompt = self._build_synthesis_prompt(question, keep_ids) - model = get_llama_model(self.model_path) - result = model.create_completion( - prompt, - max_tokens=self.config.max_generation_tokens, - temperature=0.2, - stop=[ANSWER_END, "<|im_end|>"], + + return INVESTIGATION_PROMPT_TEMPLATE.format( + system=system_text, + question=question, + full_context=full_context, + read_chunks_str=str(sorted(list(read_chunk_ids))), + lifecycle_str="\n".join(lifecycle) if lifecycle else "None", + history_text="\n".join(history[-10:]) or "None" ) - response = result["choices"][0]["text"].strip() - self.logger.log_synthesis(prompt, response, keep_ids) - return response - - def run(self, question: str) -> Dict[str, Any]: - """ - Full agent run: investigate → synthesize. - Returns dict with answer, observations, and metadata. - """ - self.registry.clear() - - keep_ids = self.investigate(question) - answer = self.synthesize(question, keep_ids) - return { - "answer": answer, - "kept_observations": keep_ids, - "total_observations": len(self.registry), - "context_used": self.registry.get_context(keep_ids), - } - - def stream_run(self, question: str): + def stream_run(self, question: str) -> Generator[Dict[str, Any], None, None]: """ - Generator version of run for streaming output. - Yields status updates during investigation, then final answer. + Run investigation and synthesis, yielding events. + Events: + - {"type": "thought", "step": int, "thought": str} + - {"type": "tool", "tool_name": str, "tool_args": dict} + - {"type": "answer", "answer": str, "kept_observations": List[str]} """ self.registry.clear() + + # Seed initial search + res, _ = self.toolkit.get_initial_context(question) + self.registry.add(f"Initial search: {res}", step=0) - yield {"type": "status", "message": "Starting investigation..."} - - history: List[str] = [] - reasoning_count = 0 - tool_count = 0 + history = [] + steps = 0 keep_ids = [] - - while reasoning_count < self.config.reasoning_limit: - reasoning_count += 1 - - prompt = self._build_investigation_prompt(question, history) - response = self._run_reasoning_step(prompt) - - step = parse_agent_response(response) - parsed_dict = { - "thought": step.thought, - "tool_name": step.tool_name, - "tool_args": step.tool_args, - "context_action": step.context_action, - "signal": step.signal, - } if step else None - self.logger.log_reasoning_step(prompt, response, parsed_dict) - - if step is None: - history.append(f"Step {reasoning_count}: [Parse error] {response[:100]}") - yield {"type": "status", "message": f"Parse error at step {reasoning_count}, retrying..."} + + # --- Investigation Phase --- + while steps < self.config.reasoning_limit: + steps += 1 + # print(f"--- Step {steps} ---") + prompt = self._build_prompt(question, history) + + response = self.llm.completion(prompt, max_tokens=self.config.max_reasoning_tokens) + print(f"\n[DEBUG RAW RESPONSE]\n{response}\n[END DEBUG]\n") + step = self._parse_step(response) + + if not step: + # print("Failed to parse response") + history.append(f"Step {steps}: [Use stricter JSON format]") + continue + + # Yield thought event + yield { + "type": "thought", + "step": steps, + "thought": step.thought + } + + # print(f"Thought: {step.thought}") + # print(f"Tool: {step.tool_name}") + + history.append(f"Step {steps}: {step.thought} (Tool: {step.tool_name})") + if self.logger: + self.logger.log_step(steps, step.thought, step.tool_name, step.tool_args, None, True) + + # Handle context actions + if step.context_action: + for ref_id in step.context_action.get("discard", []): + self.registry.remove(ref_id, step=steps) + + # Finish logic + if step.signal == "finish" or not step.tool_name: + # Basic check: do we have any content? + has_content = any("Chunk" in (self.registry.get(oid) or "") for oid in self.registry.list_ids()) + if has_content or steps >= self.config.reasoning_limit: + keep_ids = step.context_action.get("keep", self.registry.list_ids()) + break + + # If trying to finish without content, force one more step + history.append("System: You have search results but no full content. Read chunks before finishing.") continue - yield {"type": "thought", "step": reasoning_count, "thought": step.thought} - history.append(f"Step {reasoning_count}: {step.thought}") - - if step.signal == "finish" or step.tool_name is None: - keep_ids = step.context_action.get("keep", self.registry.list_ids()) - break - - if tool_count >= self.config.tool_limit: - keep_ids = step.context_action.get("keep", self.registry.list_ids()) - break - - tool_count += 1 + # Yield tool event yield { "type": "tool", "tool_name": step.tool_name, - "tool_args": step.tool_args, + "tool_args": step.tool_args } - observation = self.toolkit.execute(step.tool_name, step.tool_args) - ref_id = self.registry.add_observation(observation) - self.logger.log_tool_execution(step.tool_name, step.tool_args, observation, ref_id) - history.append(f" Tool: {step.tool_name} → {ref_id}") - - self._apply_context_action(step.context_action) + # Execute tool + res, success = self.toolkit.execute(step.tool_name, step.tool_args) + try: + self.registry.add(f"Tool {step.tool_name} result:\n{res}", step=steps) + except ContextBudgetExceeded: + history.append("System: Context full. Discard irrelevant observations.") if not keep_ids: keep_ids = self.registry.list_ids() - yield {"type": "status", "message": "Generating answer..."} - - answer = self._synthesize_with_logging(question, keep_ids) - - self.logger.log_query_complete(answer, { - "kept_observations": keep_ids, - "total_observations": len(self.registry), - }) - + # --- Synthesis Phase --- + parts = [] + for ref_id in keep_ids: + c = self.registry.get(ref_id) + if c: parts.append(f"[{ref_id}]\n{c}") + final_context = "\n\n".join(parts) + + prompt = SYNTHESIS_PROMPT.format(context=final_context, question=question) + answer_text = self.llm.completion(prompt, max_tokens=self.config.max_generation_tokens) + yield { "type": "answer", - "answer": answer, - "kept_observations": keep_ids, + "answer": answer_text, + "kept_observations": keep_ids } + def investigate(self, question: str) -> List[str]: + # Legacy/Support method wrapping stream_run + last_ids = [] + for event in self.stream_run(question): + if event["type"] == "answer": + last_ids = event["kept_observations"] + return last_ids + + def answer(self, question: str) -> str: + # Legacy/Support method wrapping stream_run + ans = "" + for event in self.stream_run(question): + if event["type"] == "answer": + ans = event["answer"] + return ans diff --git a/src/agent/prompts.py b/src/agent/prompts.py new file mode 100644 index 00000000..befcafbc --- /dev/null +++ b/src/agent/prompts.py @@ -0,0 +1,71 @@ +AGENT_SYSTEM_PROMPT = """You are an investigative agent that retrieves information to answer questions about databases and SQL. + +You work in a loop: think → use a tool → observe → repeat until ready. + +{tool_descriptions} + +## Output Format (strict JSON) +```json +{{ + "thought": "Your reasoning about what information you need", + "tool_name": "name_of_tool or null if done", + "tool_args": {{"arg1": "value1"}}, + "context_action": {{ + "keep": ["obs_1", "obs_3"], + "discard": ["obs_2"] + }}, + "signal": "continue or finish" +}} +``` + +## Critical Rules +1. **Investigate Deeply**: Search results are just previews. You MUST read actual chunk content using `read_content`. +2. **Read Content**: Never finish with only search results. Read at least 2-3 chunks. +3. **Explore**: + - Check chunks from different search results. + - Use `read_content` with relative_start/end to see surrounding text. + - Use `grep_text` if you need exact keyword matches in the full text. +4. **Finish Conditions**: + - You have read actual chunk content (not just search previews). + - The information is sufficient to answer the question. + +Current observations: {observation_ids} +Budget: {budget_status} +""" + +INVESTIGATION_PROMPT_TEMPLATE = """<|im_start|>system +{system} +Question: {question} + +=== FULL ACTIVE CONTEXT (all information you currently have) === +{full_context} + +=== SUMMARY === +{read_chunks_str} +Observations lifecycle: {lifecycle_str} + +=== RECENT STEPS === +{history_text} + +What's your next step? +- You must read actual chunk content using read_content before finishing. + +RESPONSE FORMAT (JSON ONLY): +{{ + "thought": "reasoning...", + "tool_name": "tool_name", + "tool_args": {{ "arg": "value" }}, + "signal": "continue" +}} +<|im_end|> +<|im_start|>assistant +""" + +SYNTHESIS_PROMPT = """Answer the question based on the following context. + +Context: +{context} + +Question: {question} + +Answer:""" diff --git a/src/agent/summarizer.py b/src/agent/summarizer.py new file mode 100644 index 00000000..52a35cc4 --- /dev/null +++ b/src/agent/summarizer.py @@ -0,0 +1,70 @@ +import re +from typing import Optional +from llama_cpp import Llama + +class ThinkingSummarizer: + """Large-model summarizer with thinking tag cleaning.""" + + def __init__(self, model_path: str, n_ctx: int = 40960): + self.llm = Llama( + model_path=model_path, + n_ctx=n_ctx, + n_gpu_layers=-1, + tensor_split=[24, 24], + verbose=False, + n_batch=512, + ) + + def clean_thinking_tokens(self, text: str) -> str: + """Strip / blocks while keeping the final answer.""" + if not text: return "" + flags = re.DOTALL | re.IGNORECASE + + # Keep everything after last closing tag + for tag in ["think", "thinking", "reasoning"]: + m = re.search(rf'\s*(.*)$', text, flags=flags) + if m: text = m.group(1) + text = re.sub(rf']*>', '', text, flags=flags) + + text = re.sub(r'\n\s*\n+', '\n\n', text) + return text.strip() + + def generate_update(self, current_summary: str, new_text: str, budget: int = 500) -> str: + prompt = f"""<|im_start|>system +You are an expert technical summarizer maintaining a dense running summary. +<|im_end|> +<|im_start|>user +Current Summary: {current_summary or "(None)"} +New Content: {new_text} +Task: Update summary. +Constraints: length < {budget} tokens. No thinking tags. +<|im_end|> +<|im_start|>assistant +""" + output = self.llm.create_completion(prompt, max_tokens=budget + 100, temperature=0.3, stop=["<|im_end|>"]) + return self.clean_thinking_tokens(output["choices"][0]["text"]) + + def summarize_recursive(self, text: str, current_summary: str = "", budget: int = 500) -> str: + # Simple recursion based on length + est_tokens = (len(text) + len(current_summary)) / 3.5 + if est_tokens < (40960 - budget - 1000): + return self.generate_update(current_summary, text, budget) + + mid = len(text) // 2 + # Simple split + part1, part2 = text[:mid], text[mid:] + updated = self.summarize_recursive(part1, current_summary, budget) + return self.summarize_recursive(part2, updated, budget) + + def one_line(self, text: str) -> str: + if not text.strip(): raise ValueError("Empty text") + prompt = f"""<|im_start|>system +Write one concise sentence description (< 200 chars). +<|im_end|> +<|im_start|>user +Text: {text} +<|im_end|> +<|im_start|>assistant +""" + output = self.llm.create_completion(prompt, max_tokens=64, temperature=0.3, stop=["<|im_end|>"]) + return self.clean_thinking_tokens(output["choices"][0]["text"]) diff --git a/src/agent/summary_db.py b/src/agent/summary_db.py new file mode 100644 index 00000000..1cc99239 --- /dev/null +++ b/src/agent/summary_db.py @@ -0,0 +1,35 @@ +import sqlite3 +from pathlib import Path +from typing import List, Dict + +def init_db(db_path: Path) -> None: + with sqlite3.connect(db_path) as conn: + conn.executescript(""" + DROP TABLE IF EXISTS paragraphs; + DROP TABLE IF EXISTS sections; + CREATE TABLE sections ( + id INTEGER PRIMARY KEY, heading TEXT, section_summary TEXT, + one_line_summary TEXT, prev_section_id INTEGER, next_section_id INTEGER, + num_paragraphs INTEGER, content_length INTEGER + ); + CREATE TABLE paragraphs ( + id INTEGER PRIMARY KEY, section_id INTEGER, para_index INTEGER, + one_line_summary TEXT, raw_text TEXT, + FOREIGN KEY(section_id) REFERENCES sections(id) + ); + CREATE INDEX idx_paragraphs_section ON paragraphs(section_id, para_index); + """) + +def save_section_to_db(db_path: Path, section_id: int, heading: str, summary: str, one_line: str, paras: List[Dict]): + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + cur.execute( + "INSERT INTO sections (id, heading, section_summary, one_line_summary, content_length) VALUES (?, ?, ?, ?, ?)", + (section_id, heading, summary, one_line, 0) # simplified + ) + for p in paras: + cur.execute( + "INSERT INTO paragraphs (section_id, para_index, one_line_summary, raw_text) VALUES (?, ?, ?, ?)", + (section_id, p['index'], p['summary'], p['text']) + ) + conn.commit() diff --git a/src/agent/toolkit.py b/src/agent/toolkit.py new file mode 100644 index 00000000..fbc14c75 --- /dev/null +++ b/src/agent/toolkit.py @@ -0,0 +1,94 @@ +from typing import List, Dict, Tuple, Optional +from pathlib import Path +import faiss + +from src.agent.tools.search import IndexScout +from src.agent.tools.read import NavigationalReader +from src.agent.tools.text import GrepSearch +from src.agent.tools.sections import SectionSummarizer + +class AgentToolkit: + """Container for agent tools.""" + + def __init__( + self, + faiss_index: faiss.Index, + chunks: List[str], + sources: List[str], + embed_model: str, + markdown_path: Optional[str] = None, + summaries_path: Optional[str] = None, + ): + self.index_scout = IndexScout(faiss_index, chunks, sources, embed_model) + self.reader = NavigationalReader(chunks, sources) + self.grep: Optional[GrepSearch] = None + self.summarizer: Optional[SectionSummarizer] = None + + if markdown_path and Path(markdown_path).exists(): + self.grep = GrepSearch(Path(markdown_path)) + + if summaries_path and Path(summaries_path).exists(): + self.summarizer = SectionSummarizer(Path(summaries_path)) + + self._available_tools = ["search_index", "read_content"] + if self.grep: self._available_tools.append("grep_text") + if self.summarizer: self._available_tools.extend(["get_section_summary", "list_sections"]) + + @property + def available_tools(self) -> List[str]: + return self._available_tools + + def get_tool_descriptions(self) -> str: + # Construct dynamic tool descriptions based on availability + base = """ +Available Tools: +- `search_index(query="...")`: Semantic search. Returns relevant chunk IDs vs preview text. +- `read_content(target_chunk_id=123, relative_start=-1, relative_end=1)`: Read full text of chunks. Use relative_start/end to read surrounding context. +""" + if self.grep: + base += '- `grep_text(pattern="regex")`: Search for exact patterns in the full text.\n' + if self.summarizer: + base += '- `list_sections(limit=30)`: List available document sections.\n' + base += '- `get_section_summary(section_name="...")`: Get summary of a specific section.\n' + return base + + def get_initial_context(self, question: str, top_k: int = 5) -> Tuple[str, bool]: + """Helper to get initial search results.""" + results = self.index_scout.search(question, top_k=top_k) + return self.index_scout.format_result(results), True + + def execute(self, tool_name: str, tool_args: Dict) -> Tuple[str, bool]: + if tool_name not in self._available_tools: + return f"Unknown tool '{tool_name}'. Available: {', '.join(self._available_tools)}", False + + try: + if tool_name == "search_index": + return self.index_scout.format_result( + self.index_scout.search(tool_args.get("query"), top_k=tool_args.get("top_k", 10)) + ), True + + elif tool_name == "read_content": + text, chunk_ids = self.reader.read( + tool_args.get("target_chunk_id"), + tool_args.get("relative_start", 0), + tool_args.get("relative_end", 0), + ) + return self.reader.format_result(text, chunk_ids), True + + elif tool_name == "grep_text" and self.grep: + return self.grep.format_result( + self.grep.search(tool_args.get("pattern"), tool_args.get("context_lines", 2)) + ), True + + elif tool_name == "list_sections" and self.summarizer: + return "\n".join(self.summarizer.list_sections(tool_args.get("limit", 30))), True + + elif tool_name == "get_section_summary" and self.summarizer: + return self.summarizer.format_result( + self.summarizer.get_summary(tool_args.get("section_name", "")) + ), True + + except Exception as e: + return f"Tool execution failed: {str(e)}", False + + return "Tool not available or arguments invalid.", False diff --git a/src/agent/tools.py b/src/agent/tools.py deleted file mode 100644 index 02b730d3..00000000 --- a/src/agent/tools.py +++ /dev/null @@ -1,337 +0,0 @@ -""" -Agent tools for dynamic context retrieval. - -Tools: -- IndexScout: Semantic search returning metadata (chunk IDs, relevance) -- NavigationalReader: Read chunk slices with relative offsets -- GrepSearch: Regex search across raw markdown -- SectionSummarizer: Get section content by name -""" - -import json -import re -from dataclasses import dataclass -from pathlib import Path -from typing import List, Dict, Optional, Tuple - -import faiss -import numpy as np - -from src.embedder import SentenceTransformer -from src.retriever import _get_embedder - - -@dataclass -class ChunkMetadata: - chunk_id: int - score: float - source: str - preview: str - - -@dataclass -class GrepMatch: - line_number: int - content: str - context_before: List[str] - context_after: List[str] - - -class IndexScout: - """Semantic search that returns structured metadata.""" - - def __init__( - self, - faiss_index: faiss.Index, - chunks: List[str], - sources: List[str], - embed_model: str, - ): - self.faiss_index = faiss_index - self.chunks = chunks - self.sources = sources - self.embedder = _get_embedder(embed_model) - - def search_index(self, query: str, top_k: int = 10) -> List[ChunkMetadata]: - """Search index and return metadata list without full chunk text.""" - q_vec = self.embedder.encode([query]).astype("float32") - - if q_vec.shape[1] != self.faiss_index.d: - raise ValueError( - f"Embedding dim mismatch: index={self.faiss_index.d} vs query={q_vec.shape[1]}" - ) - - distances, indices = self.faiss_index.search(q_vec, top_k) - - results = [] - for idx, dist in zip(indices[0], distances[0]): - if idx < 0 or idx >= len(self.chunks): - continue - score = 1.0 / (1.0 + float(dist)) - preview = self.chunks[idx][:150].replace("\n", " ") - results.append( - ChunkMetadata( - chunk_id=int(idx), - score=score, - source=self.sources[idx] if idx < len(self.sources) else "unknown", - preview=preview, - ) - ) - return results - - def format_result(self, results: List[ChunkMetadata]) -> str: - """Format as machine-readable structured output.""" - if not results: - return "No results found." - lines = ["Search results (use chunk_id for read_content):"] - for i, r in enumerate(results): - lines.append(f" [{i}] chunk_id={r.chunk_id} score={r.score:.3f} source={r.source}") - lines.append(f" preview: {r.preview}") - return "\n".join(lines) - - -class NavigationalReader: - """Read chunks with relative offset navigation.""" - - def __init__(self, chunks: List[str], sources: List[str]): - self.chunks = chunks - self.sources = sources - - def read_content( - self, - target_chunk_id: int, - relative_start: int = 0, - relative_end: int = 0, - ) -> Tuple[str, List[int]]: - """ - Fetch chunks[target + start : target + end + 1]. - Returns (concatenated_text, list_of_chunk_ids). - """ - start_idx = max(0, target_chunk_id + relative_start) - end_idx = min(len(self.chunks), target_chunk_id + relative_end + 1) - - if start_idx >= len(self.chunks) or end_idx <= start_idx: - return "", [] - - chunk_ids = list(range(start_idx, end_idx)) - texts = [] - for cid in chunk_ids: - src = self.sources[cid] if cid < len(self.sources) else "unknown" - texts.append(f"--- Chunk {cid} (source: {src}) ---\n{self.chunks[cid]}") - - return "\n\n".join(texts), chunk_ids - - def format_result(self, text: str, chunk_ids: List[int]) -> str: - """Format for agent consumption.""" - if not text: - return "ERROR: No content found for specified range." - return f"Content from chunks {chunk_ids}:\n\n{text}" - - -class GrepSearch: - """Regex search across raw markdown content.""" - - def __init__(self, markdown_path: str): - self.markdown_path = Path(markdown_path) - self._lines: Optional[List[str]] = None - - def _load_lines(self) -> List[str]: - if self._lines is None: - with open(self.markdown_path, "r", encoding="utf-8") as f: - self._lines = f.readlines() - return self._lines - - def grep_text( - self, - pattern: str, - context_lines: int = 2, - max_matches: int = 10, - ) -> List[GrepMatch]: - """ - Search for pattern in markdown file. - Returns matches with surrounding context. - """ - lines = self._load_lines() - try: - compiled = re.compile(pattern, re.IGNORECASE) - except re.error as e: - raise ValueError(f"Invalid regex pattern: {pattern} ({e})") - - matches = [] - - for i, line in enumerate(lines): - if compiled.search(line): - start = max(0, i - context_lines) - end = min(len(lines), i + context_lines + 1) - matches.append( - GrepMatch( - line_number=i + 1, - content=line.rstrip(), - context_before=[l.rstrip() for l in lines[start:i]], - context_after=[l.rstrip() for l in lines[i + 1 : end]], - ) - ) - if len(matches) >= max_matches: - break - - return matches - - def format_result(self, matches: List[GrepMatch]) -> str: - """Format grep results for agent.""" - if not matches: - return f"No matches found for pattern." - lines = [f"Found {len(matches)} matches:"] - for m in matches: - lines.append(f"\n Line {m.line_number}: {m.content}") - if m.context_before: - for ctx in m.context_before[-2:]: - lines.append(f" (before) {ctx}") - if m.context_after: - for ctx in m.context_after[:2]: - lines.append(f" (after) {ctx}") - return "\n".join(lines) - - -class SectionSummarizer: - """Retrieve section summaries from generated summaries file.""" - - def __init__(self, summaries_path: str): - self.summaries_path = Path(summaries_path) - self._summaries: Optional[List[Dict]] = None - - def _load_summaries(self) -> List[Dict]: - if self._summaries is None: - if not self.summaries_path.exists(): - raise FileNotFoundError( - f"Summaries file not found: {self.summaries_path}\n" - "Run: python -m src.agent.generate_summaries" - ) - with open(self.summaries_path, "r", encoding="utf-8") as f: - self._summaries = json.load(f) - return self._summaries - - def get_section_summary(self, section_name: str) -> Optional[Dict]: - """Find section by name and return its summary.""" - summaries = self._load_summaries() - section_name_lower = section_name.lower() - - for summ in summaries: - heading = summ.get("heading", "") - if section_name_lower in heading.lower(): - return { - "heading": heading, - "summary": summ.get("summary", ""), - "content_length": summ.get("content_length", 0), - } - return None - - def list_sections(self, limit: int = 30) -> List[str]: - """List available section headings with summaries.""" - summaries = self._load_summaries() - results = [] - for s in summaries[:limit]: - heading = s.get("heading", "Untitled") - summary = s.get("summary", "") - if summary: - results.append(f"{heading}: {summary[:100]}") - else: - results.append(heading) - return results - - def format_result(self, result: Optional[Dict]) -> str: - """Format section summary for agent.""" - if result is None: - return "ERROR: Section not found. Use list_sections to see available sections." - return f"Section: {result['heading']}\nSummary: {result['summary']}\nFull length: {result['content_length']} chars" - - -class AgentToolkit: - """Container for all agent tools, initialized from artifacts.""" - - def __init__( - self, - faiss_index: faiss.Index, - chunks: List[str], - sources: List[str], - embed_model: str, - markdown_path: str, - summaries_path: str, - ): - self.index_scout = IndexScout(faiss_index, chunks, sources, embed_model) - self.reader = NavigationalReader(chunks, sources) - self.grep = GrepSearch(markdown_path) - self.summarizer = SectionSummarizer(summaries_path) - - def execute(self, tool_name: str, tool_args: Dict) -> str: - """Execute a tool by name with given arguments.""" - try: - if tool_name == "search_index": - results = self.index_scout.search_index( - query=tool_args["query"], - top_k=tool_args.get("top_k", 10), - ) - return self.index_scout.format_result(results) - - elif tool_name == "read_content": - text, chunk_ids = self.reader.read_content( - target_chunk_id=tool_args["target_chunk_id"], - relative_start=tool_args.get("relative_start", 0), - relative_end=tool_args.get("relative_end", 0), - ) - return self.reader.format_result(text, chunk_ids) - - elif tool_name == "grep_text": - matches = self.grep.grep_text( - pattern=tool_args["pattern"], - context_lines=tool_args.get("context_lines", 2), - max_matches=tool_args.get("max_matches", 10), - ) - return self.grep.format_result(matches) - - elif tool_name == "get_section_summary": - result = self.summarizer.get_section_summary( - section_name=tool_args["section_name"] - ) - return self.summarizer.format_result(result) - - elif tool_name == "list_sections": - sections = self.summarizer.list_sections( - limit=tool_args.get("limit", 30) - ) - return "Available sections:\n" + "\n".join(sections) - - else: - return f"ERROR: Unknown tool '{tool_name}'. Available: search_index, read_content, grep_text, get_section_summary, list_sections" - - except Exception as e: - return f"ERROR executing {tool_name}: {type(e).__name__}: {str(e)}" - - @staticmethod - def get_tool_descriptions() -> str: - """Return tool descriptions for the agent prompt.""" - return """Available tools: - -1. search_index(query: str, top_k: int = 10) - - Returns: Structured list with chunk_id, score, source, preview - - Use chunk_id from results for read_content - - Best for: Finding relevant content by semantic similarity - -2. read_content(target_chunk_id: int, relative_start: int = 0, relative_end: int = 0) - - Returns: Full text of specified chunk range - - relative_start=-1, relative_end=1 reads 3 chunks (before, target, after) - - Use chunk_id from search_index results - -3. grep_text(pattern: str, context_lines: int = 2, max_matches: int = 10) - - Returns: Line numbers and matches with context - - Use for: Exact terms, code snippets, specific phrases - - Pattern is case-insensitive regex - -4. get_section_summary(section_name: str) - - Returns: AI-generated summary of section - - Use for: Quick overview before reading full content - - Partial match on section heading - -5. list_sections(limit: int = 30) - - Returns: Available section headings with brief summaries - - Use for: Understanding document structure""" - diff --git a/src/agent/tools/__init__.py b/src/agent/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/agent/tools/read.py b/src/agent/tools/read.py new file mode 100644 index 00000000..afd2d295 --- /dev/null +++ b/src/agent/tools/read.py @@ -0,0 +1,28 @@ +from typing import List, Tuple + +class NavigationalReader: + """Read chunks with relative offset navigation.""" + + def __init__(self, chunks: List[str], sources: List[str]): + self.chunks = chunks + self.sources = sources + + def read(self, target_chunk_id: int, relative_start: int = 0, relative_end: int = 0) -> Tuple[str, List[int]]: + start_idx = max(0, target_chunk_id + relative_start) + end_idx = min(len(self.chunks), target_chunk_id + relative_end + 1) + + if start_idx >= len(self.chunks) or end_idx <= start_idx: + return "", [] + + chunk_ids = list(range(start_idx, end_idx)) + texts = [] + for cid in chunk_ids: + src = self.sources[cid] if cid < len(self.sources) else "unknown" + texts.append(f"--- Chunk {cid} (source: {src}) ---\n{self.chunks[cid]}") + + return "\n\n".join(texts), chunk_ids + + def format_result(self, text: str, chunk_ids: List[int]) -> str: + if not text: + return "No content found for specified range." + return f"Content from chunks {chunk_ids}:\n\n{text}" diff --git a/src/agent/tools/search.py b/src/agent/tools/search.py new file mode 100644 index 00000000..9d0e36d0 --- /dev/null +++ b/src/agent/tools/search.py @@ -0,0 +1,42 @@ +from typing import List +import faiss +from src.agent.types import ChunkMetadata +from src.retriever import _get_embedder + +class IndexScout: + """Semantic search returning structured metadata.""" + + def __init__(self, faiss_index: faiss.Index, chunks: List[str], sources: List[str], embed_model: str): + self.faiss_index = faiss_index + self.chunks = chunks + self.sources = sources + self.embedder = _get_embedder(embed_model) + + def search(self, query: str, top_k: int = 10) -> List[ChunkMetadata]: + q_vec = self.embedder.encode([query]).astype("float32") + distances, indices = self.faiss_index.search(q_vec, top_k) + + results = [] + for idx, dist in zip(indices[0], distances[0]): + if idx < 0 or idx >= len(self.chunks): + continue + score = 1.0 / (1.0 + float(dist)) + preview = self.chunks[idx] + results.append( + ChunkMetadata( + chunk_id=int(idx), + score=score, + source=self.sources[idx] if idx < len(self.sources) else "unknown", + full_text=preview, + ) + ) + return results + + def format_result(self, results: List[ChunkMetadata]) -> str: + if not results: + return "No results found." + lines = ["Search results (use chunk_id with read_content):"] + for i, r in enumerate(results): + lines.append(f" [{i}] chunk_id={r.chunk_id} score={r.score:.3f} source={r.source}") + lines.append(f" preview: {r.full_text}") + return "\n".join(lines) diff --git a/src/agent/tools/sections.py b/src/agent/tools/sections.py new file mode 100644 index 00000000..03adf4ca --- /dev/null +++ b/src/agent/tools/sections.py @@ -0,0 +1,46 @@ +import json +from pathlib import Path +from typing import List, Dict, Optional + +class SectionSummarizer: + """Retrieve section summaries from pre-generated file.""" + + def __init__(self, summaries_path: Path): + self.summaries_path = summaries_path + self._summaries: Optional[List[Dict]] = None + + def _load_summaries(self) -> List[Dict]: + if self._summaries is None: + with open(self.summaries_path, "r", encoding="utf-8") as f: + self._summaries = json.load(f) + return self._summaries + + def get_summary(self, section_name: str) -> Optional[Dict]: + summaries = self._load_summaries() + section_lower = section_name.lower() + for s in summaries: + heading = s.get("heading", "") + if section_lower in heading.lower(): + return { + "heading": heading, + "summary": s.get("summary", ""), + "content_length": s.get("content_length", 0), + } + return None + + def list_sections(self, limit: int = 30) -> List[str]: + summaries = self._load_summaries() + results = [] + for s in summaries[:limit]: + heading = s.get("heading", "Untitled") + summary = s.get("summary", "") + if summary: + results.append(f"{heading}: {summary[:100]}") + else: + results.append(heading) + return results + + def format_result(self, result: Optional[Dict]) -> str: + if result is None: + return "Section not found." + return f"Section: {result['heading']}\nSummary: {result['summary']}\nLength: {result['content_length']} chars" diff --git a/src/agent/tools/text.py b/src/agent/tools/text.py new file mode 100644 index 00000000..0ebc66d1 --- /dev/null +++ b/src/agent/tools/text.py @@ -0,0 +1,60 @@ +import re +from pathlib import Path +from typing import List, Optional +from dataclasses import dataclass + +@dataclass +class GrepMatch: + line_number: int + content: str + context_before: List[str] + context_after: List[str] + +class GrepSearch: + """Regex search across raw markdown.""" + + def __init__(self, markdown_path: Path): + self.markdown_path = markdown_path + self._lines: Optional[List[str]] = None + + def _load_lines(self) -> List[str]: + if self._lines is None: + with open(self.markdown_path, "r", encoding="utf-8") as f: + self._lines = f.readlines() + return self._lines + + def search(self, pattern: str, context_lines: int = 2, max_matches: int = 10) -> List[GrepMatch]: + lines = self._load_lines() + try: + compiled = re.compile(pattern, re.IGNORECASE) + except re.error as e: + raise ValueError(f"Invalid regex: {pattern} ({e})") + + matches = [] + for i, line in enumerate(lines): + if compiled.search(line): + start = max(0, i - context_lines) + end = min(len(lines), i + context_lines + 1) + matches.append( + GrepMatch( + line_number=i + 1, + content=line.rstrip(), + context_before=[l.rstrip() for l in lines[start:i]], + context_after=[l.rstrip() for l in lines[i + 1 : end]], + ) + ) + if len(matches) >= max_matches: + break + return matches + + def format_result(self, matches: List[GrepMatch]) -> str: + if not matches: + return "No matches found." + lines = [f"Found {len(matches)} matches:"] + for m in matches: + lines.append(f"\n Line {m.line_number}: {m.content}") + for ctx in m.context_before[-2:]: + lines.append(f" (before) {ctx}") + for ctx in m.context_after[:2]: + lines.append(f" (after) {ctx}") + return "\n".join(lines) diff --git a/src/agent/types.py b/src/agent/types.py new file mode 100644 index 00000000..af197b81 --- /dev/null +++ b/src/agent/types.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass +from typing import Dict, Any, Optional, List + +@dataclass +class AgentConfig: + reasoning_limit: int = 5 + tool_limit: int = 20 + max_reasoning_tokens: int = 500 + max_generation_tokens: int = 400 + max_context_tokens: int = 8000 + model_path: str = "" + +@dataclass +class AgentStep: + thought: str + tool_name: Optional[str] + tool_args: Dict[str, Any] + context_action: Dict[str, Any] + signal: str + +@dataclass +class ChunkMetadata: + chunk_id: int + score: float + source: str + full_text: str + +@dataclass +class ObservationMetadata: + """Metadata tracking lifecycle of an observation.""" + added_in_step: Optional[int] = None + removed_in_step: Optional[int] = None + replaced_in_step: Optional[int] = None + replaced_with: Optional[str] = None + kept_in_final: bool = False diff --git a/src/main.py b/src/main.py index ceb14df4..74138c33 100644 --- a/src/main.py +++ b/src/main.py @@ -21,8 +21,9 @@ from rich.console import Console from rich.markdown import Markdown -from src.agent.tools import AgentToolkit +from src.agent import AgentToolkit from src.agent.orchestrator import AgentOrchestrator, AgentConfig +from src.agent.logger import AgentLogger ANSWER_NOT_FOUND = "I'm sorry, but I don't have enough information to answer that question." @@ -38,6 +39,13 @@ def parse_args() -> argparse.Namespace: choices=["index", "chat"], help="operation mode: 'index' to build index, 'chat' to query" ) + + # Optional query for chat mode (runs query and exits) + parser.add_argument( + "query", + nargs="?", + help="optional query string for chat mode (runs query and exits)" + ) # Common arguments parser.add_argument( @@ -298,8 +306,8 @@ def run_agent_chat_session(args: argparse.Namespace, cfg: RAGConfig): chunks=chunks, sources=sources, embed_model=cfg.embed_model, - markdown_path="data/book_with_pages.md", - summaries_path="data/section_summaries.json", + markdown_path="data/silberschatz.md", + summaries_path="data/section_summaries.json", # Optional: will be disabled if missing ) agent_config = AgentConfig( @@ -308,13 +316,33 @@ def run_agent_chat_session(args: argparse.Namespace, cfg: RAGConfig): max_generation_tokens=cfg.max_gen_tokens, ) + logger = AgentLogger() + orchestrator = AgentOrchestrator( toolkit=toolkit, model_path=model_path, config=agent_config, + logger=logger, ) + # If query provided, run it once and exit + if args.query: + console.print(f"\n[bold]Available tools:[/bold] {', '.join(toolkit.available_tools)}") + console.print("\n[dim]Investigating...[/dim]") + for event in orchestrator.stream_run(args.query): + if event["type"] == "thought": + console.print(f"[dim]Step {event['step']}: {event['thought']}[/dim]") + elif event["type"] == "tool": + console.print(f"[cyan] → {event['tool_name']}[/cyan]") + elif event["type"] == "answer": + console.print("\n[bold cyan]==================== ANSWER ====================[/bold cyan]\n") + console.print(Markdown(event["answer"])) + console.print("\n[bold cyan]=================================================[/bold cyan]\n") + console.print(f"[dim]Used observations: {event['kept_observations']}[/dim]") + return + print("Initialization complete. Agent mode active.") + print(f"Available tools: {', '.join(toolkit.available_tools)}") print("Type 'exit' or 'quit' to end the session.") while True: @@ -393,6 +421,12 @@ def run_chat_session(args: argparse.Namespace, cfg: RAGConfig): print("Please ensure you have run 'index' mode first.") sys.exit(1) + # If query provided, run it once and exit + if args.query: + ans = get_answer(args.query, cfg, args, logger, console, artifacts=artifacts) + logger.log_generation(ans, {"max_tokens": cfg.max_gen_tokens, "model_path": cfg.gen_model}) + return + print("Initialization complete. You can start asking questions!") print("Type 'exit' or 'quit' to end the session.") while True: diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index c84b3735..fd5305bb 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -248,7 +248,7 @@ def get_tokensmith_answer(question, config, golden_chunks=None): # Check if agent mode is enabled if config.get("use_agent", False): # Use agent orchestrator path - from src.agent.tools import AgentToolkit + from src.agent import AgentToolkit from src.agent.orchestrator import AgentOrchestrator, AgentConfig toolkit = AgentToolkit( From f331dd646abf1e83855ee9c308dbd818ccffc984 Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Thu, 29 Jan 2026 11:41:19 -0500 Subject: [PATCH 13/13] update prompts --- src/agent/orchestrator.py | 13 +++++-- src/agent/prompts.py | 78 ++++++++++++++++++++------------------- 2 files changed, 50 insertions(+), 41 deletions(-) diff --git a/src/agent/orchestrator.py b/src/agent/orchestrator.py index da4de68f..33a79afb 100644 --- a/src/agent/orchestrator.py +++ b/src/agent/orchestrator.py @@ -77,10 +77,9 @@ def _build_prompt(self, question: str, history: List[str]) -> str: content = self.registry.get(ref_id) if content: active_ctx.append(f"[{ref_id}]\n{content}") - if "Content from chunks" in content or "Chunk " in content: - # Simple heuristic to track seen chunks - matches = re.findall(r"Chunk (\d+)", content) - read_chunk_ids.update(int(c) for c in matches) + # Track seen chunks (matches "Chunk 123" or "chunk_id=123") + matches = re.findall(r"(?:Chunk|chunk_id=)\s*(\d+)", content) + read_chunk_ids.update(int(c) for c in matches) full_context = "\n\n".join(active_ctx) if active_ctx else "No observations yet." @@ -182,6 +181,12 @@ def stream_run(self, question: str) -> Generator[Dict[str, Any], None, None]: except ContextBudgetExceeded: history.append("System: Context full. Discard irrelevant observations.") + # Force Read check + if step.tool_name == "search_index": + search_count = sum(1 for h in history if "Tool: search_index" in h) + if search_count >= 2: + history.append("System: You have enough search results. You MUST now use `read_content` to read a chunk. Do not search again.") + if not keep_ids: keep_ids = self.registry.list_ids() diff --git a/src/agent/prompts.py b/src/agent/prompts.py index befcafbc..40e175b8 100644 --- a/src/agent/prompts.py +++ b/src/agent/prompts.py @@ -1,33 +1,41 @@ -AGENT_SYSTEM_PROMPT = """You are an investigative agent that retrieves information to answer questions about databases and SQL. +AGENT_SYSTEM_PROMPT = """You are an investigative agent. -You work in a loop: think → use a tool → observe → repeat until ready. +## WORKFLOW +1. **SEARCH**: Find relevant chunk IDs. +2. **READ**: You **MUST** read the content of relevant chunks found in search. Search previews are truncated. +3. **ANSWER**: When you have read sufficient content. {tool_descriptions} -## Output Format (strict JSON) +## EXAMPLE +Thought: I found chunk 42 which looks relevant. I need to read it. ```json {{ - "thought": "Your reasoning about what information you need", - "tool_name": "name_of_tool or null if done", - "tool_args": {{"arg1": "value1"}}, - "context_action": {{ - "keep": ["obs_1", "obs_3"], - "discard": ["obs_2"] - }}, - "signal": "continue or finish" + "thought": "Reading chunk 42 to get details on decomposition.", + "tool_name": "read_content", + "tool_args": {{"target_chunk_id": 42}}, + "context_action": {{"keep": ["obs_1"], "discard": []}}, + "signal": "continue" }} ``` -## Critical Rules -1. **Investigate Deeply**: Search results are just previews. You MUST read actual chunk content using `read_content`. -2. **Read Content**: Never finish with only search results. Read at least 2-3 chunks. -3. **Explore**: - - Check chunks from different search results. - - Use `read_content` with relative_start/end to see surrounding text. - - Use `grep_text` if you need exact keyword matches in the full text. -4. **Finish Conditions**: - - You have read actual chunk content (not just search previews). - - The information is sufficient to answer the question. +## OUTPUT FORMAT (Strict JSON) +```json +{{ + "thought": "Reasoning here...", + "tool_name": "tool_name_here", + "tool_args": {{ "query": "..." }}, + "context_action": {{ "keep": ["obs_id"], "discard": [] }}, + "signal": "continue" | "finish" +}} +``` + +## RULES +1. **Read After Search**: If you just searched, your NEXT step MUST be `read_content`. +2. **One Chunk at a Time**: The `read_content` tool accepts **ONLY ONE** integer for `target_chunk_id`. Do NOT try to read multiple chunks in one step. +3. **Manage Context**: Keep only relevant observations. Discard search results once you've read the chunks. +4. **No Loops**: Do not search for the same terms repeatedly. +5. **Finish**: Use `signal: "finish"` only when you have **read** the answer in the text. Current observations: {observation_ids} Budget: {budget_status} @@ -35,37 +43,33 @@ INVESTIGATION_PROMPT_TEMPLATE = """<|im_start|>system {system} +<|im_end|> +<|im_start|>user Question: {question} === FULL ACTIVE CONTEXT (all information you currently have) === {full_context} === SUMMARY === -{read_chunks_str} +Read Chunks: {read_chunks_str} Observations lifecycle: {lifecycle_str} -=== RECENT STEPS === +=== RECENT STEPS (Do not repeat these) === {history_text} - -What's your next step? -- You must read actual chunk content using read_content before finishing. - -RESPONSE FORMAT (JSON ONLY): -{{ - "thought": "reasoning...", - "tool_name": "tool_name", - "tool_args": {{ "arg": "value" }}, - "signal": "continue" -}} <|im_end|> <|im_start|>assistant """ -SYNTHESIS_PROMPT = """Answer the question based on the following context. - +SYNTHESIS_PROMPT = """<|im_start|>system +You are a helpful assistant. Synthesize the provided context to answer the user's question. +If the context contains only search previews, do your best but mention that you might need to read more. +Write a clear, cohesive answer. +<|im_end|> +<|im_start|>user Context: {context} Question: {question} - +<|im_end|> +<|im_start|>assistant Answer:"""