diff --git a/.gitignore b/.gitignore index 96a52ddb..5369814a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,9 @@ # RajShah-1 gitignore code-dump.txt +nohup.out +out.log +ideas.md +index/ # --- Python --- __pycache__/ diff --git a/config/config.yaml b/config/config.yaml index 0905f546..41c20466 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -12,4 +12,9 @@ chunk_overlap: 200 use_hyde: false hyde_max_tokens: 300 use_indexed_chunks: false -rerank_mode: "cross_encoder" \ No newline at end of file +rerank_mode: "cross_encoder" + +# Agent mode settings +use_agent: true +agent_reasoning_limit: 5 +agent_tool_limit: 20 diff --git a/src/agent/__init__.py b/src/agent/__init__.py new file mode 100644 index 00000000..f3e6a216 --- /dev/null +++ b/src/agent/__init__.py @@ -0,0 +1,7 @@ +from src.agent.context import ContextRegistry +from src.agent.orchestrator import AgentOrchestrator +from src.agent.logger import AgentLogger +from src.agent.toolkit import AgentToolkit + +__all__ = ["ContextRegistry", "AgentOrchestrator", "AgentLogger", "AgentToolkit"] + diff --git a/src/agent/context.py b/src/agent/context.py new file mode 100644 index 00000000..e61a508a --- /dev/null +++ b/src/agent/context.py @@ -0,0 +1,117 @@ +from typing import Dict, List, Optional, Any +from src.agent.types import ObservationMetadata + +class ContextBudgetExceeded(Exception): + """Raised when adding an observation would exceed the max context budget.""" + pass + +class ContextRegistry: + """ + Keyed registry for agent observations. + Enforces a rough token budget (approx 3.5 chars per token). + """ + + def __init__(self, max_tokens: int = 8000): + self._observations: Dict[str, str] = {} + self._metadata: Dict[str, ObservationMetadata] = {} + self._counter: int = 0 + self._max_tokens = max_tokens + self._current_chars = 0 + + @property + def current_tokens(self) -> int: + return int(self._current_chars / 3.5) + + @property + def status(self) -> Dict[str, Any]: + used = self.current_tokens + return { + "used": used, + "total": self._max_tokens, + "usage_percent": (used / self._max_tokens) * 100 if self._max_tokens > 0 else 0.0, + "count": len(self._observations) + } + + def _check_budget(self, text: str): + new_tokens = int(len(text) / 3.5) + if (self.current_tokens + new_tokens) > self._max_tokens: + raise ContextBudgetExceeded( + f"Cannot add observation ({new_tokens} tokens). " + f"Registry full: {self.current_tokens}/{self._max_tokens} tokens used." + ) + + def add(self, text: str, step: Optional[int] = None) -> str: + self._check_budget(text) + self._counter += 1 + ref_id = f"obs_{self._counter}" + self._observations[ref_id] = text + self._metadata[ref_id] = ObservationMetadata(added_in_step=step) + self._current_chars += len(text) + return ref_id + + def remove(self, ref_id: str, step: Optional[int] = None) -> bool: + if ref_id in self._observations: + text = self._observations.pop(ref_id) + if ref_id in self._metadata: + self._metadata[ref_id].removed_in_step = step + else: + self._metadata[ref_id] = ObservationMetadata(removed_in_step=step) + self._current_chars -= len(text) + return True + return False + + def replace(self, ref_id: str, new_text: str, step: Optional[int] = None) -> None: + if ref_id not in self._observations: + raise KeyError(f"Observation {ref_id} not found.") + + old_text = self._observations[ref_id] + diff_chars = len(new_text) - len(old_text) + new_total_tokens = int((self._current_chars + diff_chars) / 3.5) + + if new_total_tokens > self._max_tokens: + raise ContextBudgetExceeded( + f"Replacement exceeds budget. Total would be {new_total_tokens}/{self._max_tokens}." + ) + + self._observations[ref_id] = new_text + if ref_id in self._metadata: + self._metadata[ref_id].replaced_in_step = step + self._metadata[ref_id].replaced_with = ref_id + self._current_chars += diff_chars + + def get(self, ref_id: str) -> Optional[str]: + return self._observations.get(ref_id) + + def list_ids(self) -> List[str]: + return list(self._observations.keys()) + + def clear(self) -> None: + self._observations.clear() + self._metadata.clear() + self._counter = 0 + self._current_chars = 0 + + def __len__(self) -> int: + return len(self._observations) + + def get_all_metadata(self) -> Dict[str, Dict[str, Any]]: + result = {} + all_ref_ids = set(self._observations.keys()) | set(self._metadata.keys()) + + for ref_id in all_ref_ids: + meta = self._metadata.get(ref_id, ObservationMetadata()) + lifecycle = [] + if meta.added_in_step is not None: lifecycle.append(f"added-in-step-{meta.added_in_step}") + if meta.removed_in_step is not None: lifecycle.append(f"removed-in-step-{meta.removed_in_step}") + if meta.replaced_in_step is not None: lifecycle.append(f"replaced-in-step-{meta.replaced_in_step}") + if meta.kept_in_final: lifecycle.append("kept-in-final-content") + + result[ref_id] = { + "content": self._observations.get(ref_id), + "lifecycle": "; ".join(lifecycle) if lifecycle else "no-events", + "added_in_step": meta.added_in_step, + "removed_in_step": meta.removed_in_step, + "replaced_in_step": meta.replaced_in_step, + "kept_in_final": meta.kept_in_final, + } + return result diff --git a/src/agent/generate_summaries.py b/src/agent/generate_summaries.py new file mode 100644 index 00000000..50bbc6ea --- /dev/null +++ b/src/agent/generate_summaries.py @@ -0,0 +1,47 @@ +import json +from pathlib import Path +from src.agent.summarizer import ThinkingSummarizer +from src.agent.summary_db import init_db, save_section_to_db + +_PROJECT_ROOT = Path(__file__).parent.parent.parent +EXTRACTED_PATH = _PROJECT_ROOT / "data" / "extracted_sections.json" +NAV_DB_PATH = _PROJECT_ROOT / "data" / "nav_index.sqlite3" +MODEL_PATH = str(_PROJECT_ROOT / "models" / "Qwen3-30B-A3B-Q6_K.gguf") + +def process_all(): + if not EXTRACTED_PATH.exists(): + print("No extracted sections found.") + return + + with open(EXTRACTED_PATH) as f: + sections = json.load(f) + + summarizer = ThinkingSummarizer(MODEL_PATH) + init_db(NAV_DB_PATH) + + for i, sec in enumerate(sections): + print(f"Processing section {i+1}/{len(sections)}...") + text = sec.get("content", "") + if not text: continue + + # Dense summary + dense = summarizer.summarize_recursive(text) + + # One liner + one_line = summarizer.one_line(text) + + # Paragraphs + paras = [] + for p_idx, para in enumerate(text.split("\n\n")): + if not para.strip(): continue + p_summary = summarizer.one_line(para) + paras.append({ + "index": p_idx, + "text": para, + "summary": p_summary + }) + + save_section_to_db(NAV_DB_PATH, i+1, sec.get("heading", ""), dense, one_line, paras) + +if __name__ == "__main__": + process_all() diff --git a/src/agent/llm.py b/src/agent/llm.py new file mode 100644 index 00000000..be0502a4 --- /dev/null +++ b/src/agent/llm.py @@ -0,0 +1,19 @@ +from typing import List, Optional +from src.generator import get_llama_model + +class AgentLLM: + def __init__(self, model_path: str): + self.model_path = model_path + + def completion(self, prompt: str, max_tokens: int = 500, temperature: float = 0.1, stop: Optional[List[str]] = None) -> str: + model = get_llama_model(self.model_path) + if stop is None: + stop = ["<|im_end|>"] + + result = model.create_completion( + prompt, + max_tokens=max_tokens, + temperature=temperature, + stop=stop, + ) + return result["choices"][0]["text"] diff --git a/src/agent/logger.py b/src/agent/logger.py new file mode 100644 index 00000000..2f2523e8 --- /dev/null +++ b/src/agent/logger.py @@ -0,0 +1,42 @@ +"""Minimal logging for agent pipeline.""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Optional + + +class AgentLogger: + """Logs agent interactions to JSONL file.""" + + def __init__(self, session_id: Optional[str] = None): + self.session_id = session_id or datetime.now().strftime("%Y%m%d_%H%M%S") + self.logs_dir = Path("logs") / "agent" + self.logs_dir.mkdir(parents=True, exist_ok=True) + self.log_file = self.logs_dir / f"agent_{self.session_id}.jsonl" + + def _write(self, data: Dict[str, Any]) -> None: + data["ts"] = datetime.now().isoformat() + with open(self.log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(data, ensure_ascii=False) + "\n") + + def log_step(self, step: int, thought: str, tool_name: Optional[str], tool_args: Dict[str, Any], result: Optional[str], success: bool) -> None: + """Log a reasoning step with full thought (no truncation).""" + self._write({ + "event": "step", + "step": step, + "thought": thought, + "tool_name": tool_name, + "tool_args": tool_args, + "result": result, + "success": success, + }) + + def log_query_complete(self, question: str, answer: str, registry_metadata: Dict[str, Dict[str, Any]]) -> None: + """Log query completion with full registry lifecycle.""" + self._write({ + "event": "query_complete", + "question": question, + "answer": answer, + "registry_entries": registry_metadata, + }) diff --git a/src/agent/orchestrator.py b/src/agent/orchestrator.py new file mode 100644 index 00000000..33a79afb --- /dev/null +++ b/src/agent/orchestrator.py @@ -0,0 +1,223 @@ +import json +import re +from typing import Optional, List, Any, Generator, Dict + +from src.agent.types import AgentConfig, AgentStep +from src.agent.context import ContextRegistry, ContextBudgetExceeded +from src.agent.toolkit import AgentToolkit +from src.agent.llm import AgentLLM +from src.agent.prompts import INVESTIGATION_PROMPT_TEMPLATE, SYNTHESIS_PROMPT, AGENT_SYSTEM_PROMPT + +class AgentOrchestrator: + def __init__(self, toolkit: AgentToolkit, model_path: str, config: Optional[AgentConfig] = None, logger: Optional[Any] = None): + self.toolkit = toolkit + self.config = config or AgentConfig(model_path=model_path) + self.llm = AgentLLM(model_path) + self.registry = ContextRegistry(max_tokens=self.config.max_context_tokens) + self.logger = logger + + def _parse_step(self, text: str) -> Optional[AgentStep]: + text = text.strip() + match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL) or re.search(r"\{.*\}", text, re.DOTALL) + if not match: return None + + try: + data = json.loads(match.group(1 if match.lastindex else 0)) + + # Handle hallucinated "next_step": "tool(args)" format + if "next_step" in data and not data.get("tool_name"): + call_str = data["next_step"] + # Parse tool(arg=val) + m_call = re.match(r"(\w+)\((.*)\)", call_str) + if m_call: + tool_name = m_call.group(1) + args_str = m_call.group(2) + args = {} + # Simple arg parser for key=value or key="value" + # This is brittle but handles the example seen + for arg_pair in args_str.split(","): + if "=" in arg_pair: + k, v = arg_pair.split("=", 1) + k = k.strip() + v = v.strip().strip("'").strip('"') + # Try to convert to int/float + try: + if "." in v: v = float(v) + else: v = int(v) + except ValueError: + pass + args[k] = v + + return AgentStep( + thought=data.get("thought", f"Decided to call {tool_name}"), + tool_name=tool_name, + tool_args=args, + context_action={}, + signal="continue" + ) + + return AgentStep( + thought=data.get("thought", "Recovering from simplified output..."), + tool_name=data.get("tool_name"), + tool_args=data.get("tool_args", {}), + context_action=data.get("context_action", {}), + signal=data.get("signal", "continue"), + ) + except Exception as e: + print(f"JSON Parse Error: {e}") + return None + + def _build_prompt(self, question: str, history: List[str]) -> str: + obs_ids = self.registry.list_ids() + + # Format active context + active_ctx = [] + read_chunk_ids = set() + for ref_id in obs_ids: + content = self.registry.get(ref_id) + if content: + active_ctx.append(f"[{ref_id}]\n{content}") + # Track seen chunks (matches "Chunk 123" or "chunk_id=123") + matches = re.findall(r"(?:Chunk|chunk_id=)\s*(\d+)", content) + read_chunk_ids.update(int(c) for c in matches) + + full_context = "\n\n".join(active_ctx) if active_ctx else "No observations yet." + + # Lifecycle metadata + meta = self.registry.get_all_metadata() + lifecycle = [f"{k}: {v['lifecycle']}" for k, v in meta.items() if v['lifecycle'] != "no-events"] + + status = self.registry.status + system_text = AGENT_SYSTEM_PROMPT.format( + tool_descriptions=self.toolkit.get_tool_descriptions(), + observation_ids=str(obs_ids), + budget_status=f"{status['used']}/{status['total']} tokens" + ) + + return INVESTIGATION_PROMPT_TEMPLATE.format( + system=system_text, + question=question, + full_context=full_context, + read_chunks_str=str(sorted(list(read_chunk_ids))), + lifecycle_str="\n".join(lifecycle) if lifecycle else "None", + history_text="\n".join(history[-10:]) or "None" + ) + + def stream_run(self, question: str) -> Generator[Dict[str, Any], None, None]: + """ + Run investigation and synthesis, yielding events. + Events: + - {"type": "thought", "step": int, "thought": str} + - {"type": "tool", "tool_name": str, "tool_args": dict} + - {"type": "answer", "answer": str, "kept_observations": List[str]} + """ + self.registry.clear() + + # Seed initial search + res, _ = self.toolkit.get_initial_context(question) + self.registry.add(f"Initial search: {res}", step=0) + + history = [] + steps = 0 + keep_ids = [] + + # --- Investigation Phase --- + while steps < self.config.reasoning_limit: + steps += 1 + # print(f"--- Step {steps} ---") + prompt = self._build_prompt(question, history) + + response = self.llm.completion(prompt, max_tokens=self.config.max_reasoning_tokens) + print(f"\n[DEBUG RAW RESPONSE]\n{response}\n[END DEBUG]\n") + step = self._parse_step(response) + + if not step: + # print("Failed to parse response") + history.append(f"Step {steps}: [Use stricter JSON format]") + continue + + # Yield thought event + yield { + "type": "thought", + "step": steps, + "thought": step.thought + } + + # print(f"Thought: {step.thought}") + # print(f"Tool: {step.tool_name}") + + history.append(f"Step {steps}: {step.thought} (Tool: {step.tool_name})") + if self.logger: + self.logger.log_step(steps, step.thought, step.tool_name, step.tool_args, None, True) + + # Handle context actions + if step.context_action: + for ref_id in step.context_action.get("discard", []): + self.registry.remove(ref_id, step=steps) + + # Finish logic + if step.signal == "finish" or not step.tool_name: + # Basic check: do we have any content? + has_content = any("Chunk" in (self.registry.get(oid) or "") for oid in self.registry.list_ids()) + if has_content or steps >= self.config.reasoning_limit: + keep_ids = step.context_action.get("keep", self.registry.list_ids()) + break + + # If trying to finish without content, force one more step + history.append("System: You have search results but no full content. Read chunks before finishing.") + continue + + # Yield tool event + yield { + "type": "tool", + "tool_name": step.tool_name, + "tool_args": step.tool_args + } + + # Execute tool + res, success = self.toolkit.execute(step.tool_name, step.tool_args) + try: + self.registry.add(f"Tool {step.tool_name} result:\n{res}", step=steps) + except ContextBudgetExceeded: + history.append("System: Context full. Discard irrelevant observations.") + + # Force Read check + if step.tool_name == "search_index": + search_count = sum(1 for h in history if "Tool: search_index" in h) + if search_count >= 2: + history.append("System: You have enough search results. You MUST now use `read_content` to read a chunk. Do not search again.") + + if not keep_ids: + keep_ids = self.registry.list_ids() + + # --- Synthesis Phase --- + parts = [] + for ref_id in keep_ids: + c = self.registry.get(ref_id) + if c: parts.append(f"[{ref_id}]\n{c}") + final_context = "\n\n".join(parts) + + prompt = SYNTHESIS_PROMPT.format(context=final_context, question=question) + answer_text = self.llm.completion(prompt, max_tokens=self.config.max_generation_tokens) + + yield { + "type": "answer", + "answer": answer_text, + "kept_observations": keep_ids + } + + def investigate(self, question: str) -> List[str]: + # Legacy/Support method wrapping stream_run + last_ids = [] + for event in self.stream_run(question): + if event["type"] == "answer": + last_ids = event["kept_observations"] + return last_ids + + def answer(self, question: str) -> str: + # Legacy/Support method wrapping stream_run + ans = "" + for event in self.stream_run(question): + if event["type"] == "answer": + ans = event["answer"] + return ans diff --git a/src/agent/prompts.py b/src/agent/prompts.py new file mode 100644 index 00000000..40e175b8 --- /dev/null +++ b/src/agent/prompts.py @@ -0,0 +1,75 @@ +AGENT_SYSTEM_PROMPT = """You are an investigative agent. + +## WORKFLOW +1. **SEARCH**: Find relevant chunk IDs. +2. **READ**: You **MUST** read the content of relevant chunks found in search. Search previews are truncated. +3. **ANSWER**: When you have read sufficient content. + +{tool_descriptions} + +## EXAMPLE +Thought: I found chunk 42 which looks relevant. I need to read it. +```json +{{ + "thought": "Reading chunk 42 to get details on decomposition.", + "tool_name": "read_content", + "tool_args": {{"target_chunk_id": 42}}, + "context_action": {{"keep": ["obs_1"], "discard": []}}, + "signal": "continue" +}} +``` + +## OUTPUT FORMAT (Strict JSON) +```json +{{ + "thought": "Reasoning here...", + "tool_name": "tool_name_here", + "tool_args": {{ "query": "..." }}, + "context_action": {{ "keep": ["obs_id"], "discard": [] }}, + "signal": "continue" | "finish" +}} +``` + +## RULES +1. **Read After Search**: If you just searched, your NEXT step MUST be `read_content`. +2. **One Chunk at a Time**: The `read_content` tool accepts **ONLY ONE** integer for `target_chunk_id`. Do NOT try to read multiple chunks in one step. +3. **Manage Context**: Keep only relevant observations. Discard search results once you've read the chunks. +4. **No Loops**: Do not search for the same terms repeatedly. +5. **Finish**: Use `signal: "finish"` only when you have **read** the answer in the text. + +Current observations: {observation_ids} +Budget: {budget_status} +""" + +INVESTIGATION_PROMPT_TEMPLATE = """<|im_start|>system +{system} +<|im_end|> +<|im_start|>user +Question: {question} + +=== FULL ACTIVE CONTEXT (all information you currently have) === +{full_context} + +=== SUMMARY === +Read Chunks: {read_chunks_str} +Observations lifecycle: {lifecycle_str} + +=== RECENT STEPS (Do not repeat these) === +{history_text} +<|im_end|> +<|im_start|>assistant +""" + +SYNTHESIS_PROMPT = """<|im_start|>system +You are a helpful assistant. Synthesize the provided context to answer the user's question. +If the context contains only search previews, do your best but mention that you might need to read more. +Write a clear, cohesive answer. +<|im_end|> +<|im_start|>user +Context: +{context} + +Question: {question} +<|im_end|> +<|im_start|>assistant +Answer:""" diff --git a/src/agent/summarizer.py b/src/agent/summarizer.py new file mode 100644 index 00000000..52a35cc4 --- /dev/null +++ b/src/agent/summarizer.py @@ -0,0 +1,70 @@ +import re +from typing import Optional +from llama_cpp import Llama + +class ThinkingSummarizer: + """Large-model summarizer with thinking tag cleaning.""" + + def __init__(self, model_path: str, n_ctx: int = 40960): + self.llm = Llama( + model_path=model_path, + n_ctx=n_ctx, + n_gpu_layers=-1, + tensor_split=[24, 24], + verbose=False, + n_batch=512, + ) + + def clean_thinking_tokens(self, text: str) -> str: + """Strip / blocks while keeping the final answer.""" + if not text: return "" + flags = re.DOTALL | re.IGNORECASE + + # Keep everything after last closing tag + for tag in ["think", "thinking", "reasoning"]: + m = re.search(rf'\s*(.*)$', text, flags=flags) + if m: text = m.group(1) + text = re.sub(rf']*>', '', text, flags=flags) + + text = re.sub(r'\n\s*\n+', '\n\n', text) + return text.strip() + + def generate_update(self, current_summary: str, new_text: str, budget: int = 500) -> str: + prompt = f"""<|im_start|>system +You are an expert technical summarizer maintaining a dense running summary. +<|im_end|> +<|im_start|>user +Current Summary: {current_summary or "(None)"} +New Content: {new_text} +Task: Update summary. +Constraints: length < {budget} tokens. No thinking tags. +<|im_end|> +<|im_start|>assistant +""" + output = self.llm.create_completion(prompt, max_tokens=budget + 100, temperature=0.3, stop=["<|im_end|>"]) + return self.clean_thinking_tokens(output["choices"][0]["text"]) + + def summarize_recursive(self, text: str, current_summary: str = "", budget: int = 500) -> str: + # Simple recursion based on length + est_tokens = (len(text) + len(current_summary)) / 3.5 + if est_tokens < (40960 - budget - 1000): + return self.generate_update(current_summary, text, budget) + + mid = len(text) // 2 + # Simple split + part1, part2 = text[:mid], text[mid:] + updated = self.summarize_recursive(part1, current_summary, budget) + return self.summarize_recursive(part2, updated, budget) + + def one_line(self, text: str) -> str: + if not text.strip(): raise ValueError("Empty text") + prompt = f"""<|im_start|>system +Write one concise sentence description (< 200 chars). +<|im_end|> +<|im_start|>user +Text: {text} +<|im_end|> +<|im_start|>assistant +""" + output = self.llm.create_completion(prompt, max_tokens=64, temperature=0.3, stop=["<|im_end|>"]) + return self.clean_thinking_tokens(output["choices"][0]["text"]) diff --git a/src/agent/summary_db.py b/src/agent/summary_db.py new file mode 100644 index 00000000..1cc99239 --- /dev/null +++ b/src/agent/summary_db.py @@ -0,0 +1,35 @@ +import sqlite3 +from pathlib import Path +from typing import List, Dict + +def init_db(db_path: Path) -> None: + with sqlite3.connect(db_path) as conn: + conn.executescript(""" + DROP TABLE IF EXISTS paragraphs; + DROP TABLE IF EXISTS sections; + CREATE TABLE sections ( + id INTEGER PRIMARY KEY, heading TEXT, section_summary TEXT, + one_line_summary TEXT, prev_section_id INTEGER, next_section_id INTEGER, + num_paragraphs INTEGER, content_length INTEGER + ); + CREATE TABLE paragraphs ( + id INTEGER PRIMARY KEY, section_id INTEGER, para_index INTEGER, + one_line_summary TEXT, raw_text TEXT, + FOREIGN KEY(section_id) REFERENCES sections(id) + ); + CREATE INDEX idx_paragraphs_section ON paragraphs(section_id, para_index); + """) + +def save_section_to_db(db_path: Path, section_id: int, heading: str, summary: str, one_line: str, paras: List[Dict]): + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + cur.execute( + "INSERT INTO sections (id, heading, section_summary, one_line_summary, content_length) VALUES (?, ?, ?, ?, ?)", + (section_id, heading, summary, one_line, 0) # simplified + ) + for p in paras: + cur.execute( + "INSERT INTO paragraphs (section_id, para_index, one_line_summary, raw_text) VALUES (?, ?, ?, ?)", + (section_id, p['index'], p['summary'], p['text']) + ) + conn.commit() diff --git a/src/agent/toolkit.py b/src/agent/toolkit.py new file mode 100644 index 00000000..fbc14c75 --- /dev/null +++ b/src/agent/toolkit.py @@ -0,0 +1,94 @@ +from typing import List, Dict, Tuple, Optional +from pathlib import Path +import faiss + +from src.agent.tools.search import IndexScout +from src.agent.tools.read import NavigationalReader +from src.agent.tools.text import GrepSearch +from src.agent.tools.sections import SectionSummarizer + +class AgentToolkit: + """Container for agent tools.""" + + def __init__( + self, + faiss_index: faiss.Index, + chunks: List[str], + sources: List[str], + embed_model: str, + markdown_path: Optional[str] = None, + summaries_path: Optional[str] = None, + ): + self.index_scout = IndexScout(faiss_index, chunks, sources, embed_model) + self.reader = NavigationalReader(chunks, sources) + self.grep: Optional[GrepSearch] = None + self.summarizer: Optional[SectionSummarizer] = None + + if markdown_path and Path(markdown_path).exists(): + self.grep = GrepSearch(Path(markdown_path)) + + if summaries_path and Path(summaries_path).exists(): + self.summarizer = SectionSummarizer(Path(summaries_path)) + + self._available_tools = ["search_index", "read_content"] + if self.grep: self._available_tools.append("grep_text") + if self.summarizer: self._available_tools.extend(["get_section_summary", "list_sections"]) + + @property + def available_tools(self) -> List[str]: + return self._available_tools + + def get_tool_descriptions(self) -> str: + # Construct dynamic tool descriptions based on availability + base = """ +Available Tools: +- `search_index(query="...")`: Semantic search. Returns relevant chunk IDs vs preview text. +- `read_content(target_chunk_id=123, relative_start=-1, relative_end=1)`: Read full text of chunks. Use relative_start/end to read surrounding context. +""" + if self.grep: + base += '- `grep_text(pattern="regex")`: Search for exact patterns in the full text.\n' + if self.summarizer: + base += '- `list_sections(limit=30)`: List available document sections.\n' + base += '- `get_section_summary(section_name="...")`: Get summary of a specific section.\n' + return base + + def get_initial_context(self, question: str, top_k: int = 5) -> Tuple[str, bool]: + """Helper to get initial search results.""" + results = self.index_scout.search(question, top_k=top_k) + return self.index_scout.format_result(results), True + + def execute(self, tool_name: str, tool_args: Dict) -> Tuple[str, bool]: + if tool_name not in self._available_tools: + return f"Unknown tool '{tool_name}'. Available: {', '.join(self._available_tools)}", False + + try: + if tool_name == "search_index": + return self.index_scout.format_result( + self.index_scout.search(tool_args.get("query"), top_k=tool_args.get("top_k", 10)) + ), True + + elif tool_name == "read_content": + text, chunk_ids = self.reader.read( + tool_args.get("target_chunk_id"), + tool_args.get("relative_start", 0), + tool_args.get("relative_end", 0), + ) + return self.reader.format_result(text, chunk_ids), True + + elif tool_name == "grep_text" and self.grep: + return self.grep.format_result( + self.grep.search(tool_args.get("pattern"), tool_args.get("context_lines", 2)) + ), True + + elif tool_name == "list_sections" and self.summarizer: + return "\n".join(self.summarizer.list_sections(tool_args.get("limit", 30))), True + + elif tool_name == "get_section_summary" and self.summarizer: + return self.summarizer.format_result( + self.summarizer.get_summary(tool_args.get("section_name", "")) + ), True + + except Exception as e: + return f"Tool execution failed: {str(e)}", False + + return "Tool not available or arguments invalid.", False diff --git a/src/agent/tools/__init__.py b/src/agent/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/agent/tools/read.py b/src/agent/tools/read.py new file mode 100644 index 00000000..afd2d295 --- /dev/null +++ b/src/agent/tools/read.py @@ -0,0 +1,28 @@ +from typing import List, Tuple + +class NavigationalReader: + """Read chunks with relative offset navigation.""" + + def __init__(self, chunks: List[str], sources: List[str]): + self.chunks = chunks + self.sources = sources + + def read(self, target_chunk_id: int, relative_start: int = 0, relative_end: int = 0) -> Tuple[str, List[int]]: + start_idx = max(0, target_chunk_id + relative_start) + end_idx = min(len(self.chunks), target_chunk_id + relative_end + 1) + + if start_idx >= len(self.chunks) or end_idx <= start_idx: + return "", [] + + chunk_ids = list(range(start_idx, end_idx)) + texts = [] + for cid in chunk_ids: + src = self.sources[cid] if cid < len(self.sources) else "unknown" + texts.append(f"--- Chunk {cid} (source: {src}) ---\n{self.chunks[cid]}") + + return "\n\n".join(texts), chunk_ids + + def format_result(self, text: str, chunk_ids: List[int]) -> str: + if not text: + return "No content found for specified range." + return f"Content from chunks {chunk_ids}:\n\n{text}" diff --git a/src/agent/tools/search.py b/src/agent/tools/search.py new file mode 100644 index 00000000..9d0e36d0 --- /dev/null +++ b/src/agent/tools/search.py @@ -0,0 +1,42 @@ +from typing import List +import faiss +from src.agent.types import ChunkMetadata +from src.retriever import _get_embedder + +class IndexScout: + """Semantic search returning structured metadata.""" + + def __init__(self, faiss_index: faiss.Index, chunks: List[str], sources: List[str], embed_model: str): + self.faiss_index = faiss_index + self.chunks = chunks + self.sources = sources + self.embedder = _get_embedder(embed_model) + + def search(self, query: str, top_k: int = 10) -> List[ChunkMetadata]: + q_vec = self.embedder.encode([query]).astype("float32") + distances, indices = self.faiss_index.search(q_vec, top_k) + + results = [] + for idx, dist in zip(indices[0], distances[0]): + if idx < 0 or idx >= len(self.chunks): + continue + score = 1.0 / (1.0 + float(dist)) + preview = self.chunks[idx] + results.append( + ChunkMetadata( + chunk_id=int(idx), + score=score, + source=self.sources[idx] if idx < len(self.sources) else "unknown", + full_text=preview, + ) + ) + return results + + def format_result(self, results: List[ChunkMetadata]) -> str: + if not results: + return "No results found." + lines = ["Search results (use chunk_id with read_content):"] + for i, r in enumerate(results): + lines.append(f" [{i}] chunk_id={r.chunk_id} score={r.score:.3f} source={r.source}") + lines.append(f" preview: {r.full_text}") + return "\n".join(lines) diff --git a/src/agent/tools/sections.py b/src/agent/tools/sections.py new file mode 100644 index 00000000..03adf4ca --- /dev/null +++ b/src/agent/tools/sections.py @@ -0,0 +1,46 @@ +import json +from pathlib import Path +from typing import List, Dict, Optional + +class SectionSummarizer: + """Retrieve section summaries from pre-generated file.""" + + def __init__(self, summaries_path: Path): + self.summaries_path = summaries_path + self._summaries: Optional[List[Dict]] = None + + def _load_summaries(self) -> List[Dict]: + if self._summaries is None: + with open(self.summaries_path, "r", encoding="utf-8") as f: + self._summaries = json.load(f) + return self._summaries + + def get_summary(self, section_name: str) -> Optional[Dict]: + summaries = self._load_summaries() + section_lower = section_name.lower() + for s in summaries: + heading = s.get("heading", "") + if section_lower in heading.lower(): + return { + "heading": heading, + "summary": s.get("summary", ""), + "content_length": s.get("content_length", 0), + } + return None + + def list_sections(self, limit: int = 30) -> List[str]: + summaries = self._load_summaries() + results = [] + for s in summaries[:limit]: + heading = s.get("heading", "Untitled") + summary = s.get("summary", "") + if summary: + results.append(f"{heading}: {summary[:100]}") + else: + results.append(heading) + return results + + def format_result(self, result: Optional[Dict]) -> str: + if result is None: + return "Section not found." + return f"Section: {result['heading']}\nSummary: {result['summary']}\nLength: {result['content_length']} chars" diff --git a/src/agent/tools/text.py b/src/agent/tools/text.py new file mode 100644 index 00000000..0ebc66d1 --- /dev/null +++ b/src/agent/tools/text.py @@ -0,0 +1,60 @@ +import re +from pathlib import Path +from typing import List, Optional +from dataclasses import dataclass + +@dataclass +class GrepMatch: + line_number: int + content: str + context_before: List[str] + context_after: List[str] + +class GrepSearch: + """Regex search across raw markdown.""" + + def __init__(self, markdown_path: Path): + self.markdown_path = markdown_path + self._lines: Optional[List[str]] = None + + def _load_lines(self) -> List[str]: + if self._lines is None: + with open(self.markdown_path, "r", encoding="utf-8") as f: + self._lines = f.readlines() + return self._lines + + def search(self, pattern: str, context_lines: int = 2, max_matches: int = 10) -> List[GrepMatch]: + lines = self._load_lines() + try: + compiled = re.compile(pattern, re.IGNORECASE) + except re.error as e: + raise ValueError(f"Invalid regex: {pattern} ({e})") + + matches = [] + for i, line in enumerate(lines): + if compiled.search(line): + start = max(0, i - context_lines) + end = min(len(lines), i + context_lines + 1) + matches.append( + GrepMatch( + line_number=i + 1, + content=line.rstrip(), + context_before=[l.rstrip() for l in lines[start:i]], + context_after=[l.rstrip() for l in lines[i + 1 : end]], + ) + ) + if len(matches) >= max_matches: + break + return matches + + def format_result(self, matches: List[GrepMatch]) -> str: + if not matches: + return "No matches found." + lines = [f"Found {len(matches)} matches:"] + for m in matches: + lines.append(f"\n Line {m.line_number}: {m.content}") + for ctx in m.context_before[-2:]: + lines.append(f" (before) {ctx}") + for ctx in m.context_after[:2]: + lines.append(f" (after) {ctx}") + return "\n".join(lines) diff --git a/src/agent/types.py b/src/agent/types.py new file mode 100644 index 00000000..af197b81 --- /dev/null +++ b/src/agent/types.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass +from typing import Dict, Any, Optional, List + +@dataclass +class AgentConfig: + reasoning_limit: int = 5 + tool_limit: int = 20 + max_reasoning_tokens: int = 500 + max_generation_tokens: int = 400 + max_context_tokens: int = 8000 + model_path: str = "" + +@dataclass +class AgentStep: + thought: str + tool_name: Optional[str] + tool_args: Dict[str, Any] + context_action: Dict[str, Any] + signal: str + +@dataclass +class ChunkMetadata: + chunk_id: int + score: float + source: str + full_text: str + +@dataclass +class ObservationMetadata: + """Metadata tracking lifecycle of an observation.""" + added_in_step: Optional[int] = None + removed_in_step: Optional[int] = None + replaced_in_step: Optional[int] = None + replaced_with: Optional[str] = None + kept_in_final: bool = False diff --git a/src/config.py b/src/config.py index af173cd1..e0976e6f 100644 --- a/src/config.py +++ b/src/config.py @@ -31,6 +31,7 @@ class RAGConfig: # generation max_gen_tokens: int = 400 gen_model: str = "models/qwen2.5-1.5b-instruct-q5_k_m.gguf" + temperature: float = 0.7 # testing system_prompt_mode: str = "baseline" @@ -48,6 +49,11 @@ class RAGConfig: extracted_index_path: os.PathLike = "data/extracted_index.json" page_to_chunk_map_path: os.PathLike = "index/sections/textbook_index_page_to_chunk_map.json" + # agent mode + use_agent: bool = False + agent_reasoning_limit: int = 5 + agent_tool_limit: int = 20 + # ---------- factory + validation ---------- @classmethod def from_yaml(cls, path: os.PathLike) -> RAGConfig: diff --git a/src/generator.py b/src/generator.py index 7fcfea38..e6f933fb 100644 --- a/src/generator.py +++ b/src/generator.py @@ -113,7 +113,7 @@ def format_prompt(chunks, query, max_chunk_chars=400, system_prompt_mode="tutor" _LLM_CACHE = {} -def get_llama_model(model_path: str, n_ctx: int = 4096): +def get_llama_model(model_path: str, n_ctx: int = 16384): if model_path not in _LLM_CACHE: _LLM_CACHE[model_path] = Llama(model_path=model_path, n_ctx=n_ctx, diff --git a/src/main.py b/src/main.py index 998d3a11..74138c33 100644 --- a/src/main.py +++ b/src/main.py @@ -21,6 +21,10 @@ from rich.console import Console from rich.markdown import Markdown +from src.agent import AgentToolkit +from src.agent.orchestrator import AgentOrchestrator, AgentConfig +from src.agent.logger import AgentLogger + ANSWER_NOT_FOUND = "I'm sorry, but I don't have enough information to answer that question." def parse_args() -> argparse.Namespace: @@ -35,6 +39,13 @@ def parse_args() -> argparse.Namespace: choices=["index", "chat"], help="operation mode: 'index' to build index, 'chat' to query" ) + + # Optional query for chat mode (runs query and exits) + parser.add_argument( + "query", + nargs="?", + help="optional query string for chat mode (runs query and exits)" + ) # Common arguments parser.add_argument( @@ -276,6 +287,90 @@ def get_keywords(question: str) -> list: keywords = [word.strip('.,!?()[]') for word in words if word not in stopwords] return keywords + +def run_agent_chat_session(args: argparse.Namespace, cfg: RAGConfig): + """Agent-based chat with dynamic context management.""" + console = Console() + + print("Welcome to Tokensmith (Agent Mode)! Initializing...") + artifacts_dir = cfg.get_artifacts_directory() + faiss_index, bm25_index, chunks, sources, meta = load_artifacts( + artifacts_dir=artifacts_dir, + index_prefix=args.index_prefix + ) + + model_path = args.model_path or cfg.gen_model + + toolkit = AgentToolkit( + faiss_index=faiss_index, + chunks=chunks, + sources=sources, + embed_model=cfg.embed_model, + markdown_path="data/silberschatz.md", + summaries_path="data/section_summaries.json", # Optional: will be disabled if missing + ) + + agent_config = AgentConfig( + reasoning_limit=cfg.agent_reasoning_limit, + tool_limit=cfg.agent_tool_limit, + max_generation_tokens=cfg.max_gen_tokens, + ) + + logger = AgentLogger() + + orchestrator = AgentOrchestrator( + toolkit=toolkit, + model_path=model_path, + config=agent_config, + logger=logger, + ) + + # If query provided, run it once and exit + if args.query: + console.print(f"\n[bold]Available tools:[/bold] {', '.join(toolkit.available_tools)}") + console.print("\n[dim]Investigating...[/dim]") + for event in orchestrator.stream_run(args.query): + if event["type"] == "thought": + console.print(f"[dim]Step {event['step']}: {event['thought']}[/dim]") + elif event["type"] == "tool": + console.print(f"[cyan] → {event['tool_name']}[/cyan]") + elif event["type"] == "answer": + console.print("\n[bold cyan]==================== ANSWER ====================[/bold cyan]\n") + console.print(Markdown(event["answer"])) + console.print("\n[bold cyan]=================================================[/bold cyan]\n") + console.print(f"[dim]Used observations: {event['kept_observations']}[/dim]") + return + + print("Initialization complete. Agent mode active.") + print(f"Available tools: {', '.join(toolkit.available_tools)}") + print("Type 'exit' or 'quit' to end the session.") + + while True: + try: + q = input("\nAsk > ").strip() + if not q: + continue + if q.lower() in {"exit", "quit"}: + print("Goodbye!") + break + + console.print("\n[dim]Investigating...[/dim]") + for event in orchestrator.stream_run(q): + if event["type"] == "thought": + console.print(f"[dim]Step {event['step']}: {event['thought']}[/dim]") + elif event["type"] == "tool": + console.print(f"[cyan] → {event['tool_name']}[/cyan]") + elif event["type"] == "answer": + console.print("\n[bold cyan]==================== ANSWER ====================[/bold cyan]\n") + console.print(Markdown(event["answer"])) + console.print("\n[bold cyan]=================================================[/bold cyan]\n") + console.print(f"[dim]Used observations: {event['kept_observations']}[/dim]") + + except KeyboardInterrupt: + print("\nGoodbye!") + break + + def run_chat_session(args: argparse.Namespace, cfg: RAGConfig): """ Initializes artifacts and runs the main interactive chat loop. @@ -326,6 +421,12 @@ def run_chat_session(args: argparse.Namespace, cfg: RAGConfig): print("Please ensure you have run 'index' mode first.") sys.exit(1) + # If query provided, run it once and exit + if args.query: + ans = get_answer(args.query, cfg, args, logger, console, artifacts=artifacts) + logger.log_generation(ans, {"max_tokens": cfg.max_gen_tokens, "model_path": cfg.gen_model}) + return + print("Initialization complete. You can start asking questions!") print("Type 'exit' or 'quit' to end the session.") while True: @@ -373,7 +474,10 @@ def main(): if args.mode == "index": run_index_mode(args, cfg) elif args.mode == "chat": - run_chat_session(args, cfg) + if cfg.use_agent: + run_agent_chat_session(args, cfg) + else: + run_chat_session(args, cfg) if __name__ == "__main__": diff --git a/src/retriever.py b/src/retriever.py index 32ac7a80..e314b76d 100644 --- a/src/retriever.py +++ b/src/retriever.py @@ -267,7 +267,6 @@ def _lemmatize_word(word: str, lemmatizer) -> str: @staticmethod def _extract_keywords(query: str) -> List[str]: """Extract keywords from query by removing stopwords and lemmatizing.""" - stopwords = { "the", "is", "at", "which", "on", "for", "a", "an", "and", "or", "in", "to", "of", "by", "with", "that", "this", "it", "as", "are", "was", @@ -282,4 +281,4 @@ def _extract_keywords(query: str) -> List[str]: if not cleaned or cleaned in stopwords: continue keywords.append(IndexKeywordRetriever._lemmatize_word(cleaned, lemmatizer)) - return keywords \ No newline at end of file + return keywords diff --git a/tests/benchmarks.yaml b/tests/benchmarks.yaml index ae106cb8..152dd291 100644 --- a/tests/benchmarks.yaml +++ b/tests/benchmarks.yaml @@ -1,74 +1,4 @@ benchmarks: - - id: "aggregation_grouping" - question: "How do aggregation with grouping work?" - expected_answer: "Aggregation partitions tuples by grouping attributes and applies functions like sum, avg, min, max to each group, producing one result per group." - keywords: ["aggregation", "grouping", "generalized projection", "ignore nulls", "attribute renaming", "duplicates"] - similarity_threshold: 0.8 - ideal_retrieved_chunks: [1039, 1040, 1496, 1497, 714] - - - id: "acid_properties" - question: "What are the ACID properties of transactions?" - expected_answer: "Atomicity ensures a transaction's actions are all-or-nothing, enforced by abort/rollback and recovery that can undo partial effects; consistency requires each transaction to preserve database integrity when run alone and relies on the scheduler to admit serializable, recoverable, and preferably cascadeless schedules; isolation makes concurrent executions equivalent to some serial order, commonly achieved with two-phase locking variants that prevent reads of uncommitted data; durability guarantees committed effects persist across crashes via logging to stable storage and redo on restart;" - keywords: ["atomicity", "consistency", "isolation", "durability", "recoverable", "checkpoints", "two-phase locking", "two-phase commit"] - similarity_threshold: 0.82 - ideal_retrieved_chunks: [1143, 1142, 1145, 1146, 1148] - - - id: "bptree" - question: "How does a B+ tree index organize keys and support search, insert, and delete, and why is it preferred over binary trees for disk-based access" - expected_answer: "B+-trees match node size to a disk page, giving very high fan-out and a shallow, height-balanced tree, so searches/updates require few page I/Os." - keywords: ["fan-out", "leaf linkage", "merge", "balanced height", "reduced height"] - similarity_threshold: 0.78 - ideal_retrieved_chunks: [908, 909, 940, 937, 938] - - - id: "fd_normalization" - question: "What are functional dependencies?" - expected_answer: "A functional dependency X -> Y asserts that tuples agreeing on X must agree on Y." - keywords: ["common key", "lossless join", "dependency preservation", "normalization", "superkey"] - similarity_threshold: 0.7 - ideal_retrieved_chunks: [438, 439, 463, 451, 484] - - - id: "sql_isolation" - question: "What isolation guarantees does SQL provide by default?" - expected_answer: "Serializable. In the SQL standard, the default isolation level is Serializable, which guarantees that the outcome of concurrently executing transactions is equivalent to some serial (one-at-a-time) order of those transactions—thereby preventing dirty reads, nonrepeatable reads, and phantoms." - keywords: ["serializable", "read committed", "repeatable read", "dirty read", "nonrepeatable read", "phantom", "two-phase locking", "predicate locking"] - similarity_threshold: 0.7 - ideal_retrieved_chunks: [1173, 1174, 1142, 1143, 1172] - - - id: "primary_foreign_keys" - question: "Explain primary keys and foreign keys" - expected_answer: "A primary key is a set of one or more attributes that uniquely identifies each tuple in a relation, chosen from candidate keys which are minimal superkeys; primary key attributes are underlined in schema diagrams and cannot have null values. A foreign key is a set of attributes in one relation (the referencing relation) that references the primary key of another relation (the referenced relation), establishing a referential integrity constraint that requires values in the foreign key to match values in the referenced primary key, thereby linking related data across tables." - keywords: ["primary key", "foreign key", "unique identifier", "referential integrity", "candidate key", "superkey"] - similarity_threshold: 0.72 - ideal_retrieved_chunks: [90, 91, 119, 93, 94] - - - id: "database_schema" - question: "What is a database schema" - expected_answer: "A database schema is the overall logical design and structure of the database, analogous to variable declarations in a program, defining the relations, their attributes, data types, and constraints including primary keys and foreign keys. The schema remains relatively stable over time, while a database instance represents the actual collection of data stored at a particular moment, with values that change as information is inserted, deleted, or modified." - keywords: ["database schema", "logical design", "structure", "database instance", "relations", "attributes", "constraints"] - similarity_threshold: 0.70 - ideal_retrieved_chunks: [49, 50, 51, 250, 60] - - - id: "book_authors" - question: "Tell me about the authors of the book" - expected_answer: "The authors of Database System Concepts Seventh Edition are Abraham Silberschatz, a Professor at Yale University and former Bell Labs vice president who is an ACM and IEEE fellow; Henry F. Korth, a Professor at Lehigh University who previously worked at Bell Labs and is also an ACM and IEEE fellow; and S. Sudarshan, the Subrao M. Nilekani Chair Professor at the Indian Institute of Technology Bombay who received his Ph.D. from the University of Wisconsin and is an ACM fellow, with research focusing on query processing and optimization." - keywords: ["Abraham Silberschatz", "Henry Korth", "Sudarshan", "Yale", "Lehigh", "IIT Bombay", "Bell Labs"] - similarity_threshold: 0.65 - ideal_retrieved_chunks: [1, 2, 0, 3, 22] - - - id: "aries_atomicity" - question: "How does the recovery manager use ARIES to ensure atomicity" - expected_answer: "The ARIES recovery algorithm ensures atomicity by maintaining a write-ahead log where all updates are recorded before being applied to the database, with each log record containing transaction ID, data item, old value, and new value. During normal operation, log records are written to stable storage before the transaction commits; if a transaction aborts or the system crashes, the recovery manager uses the log to undo uncommitted transactions by applying the old values in reverse order, ensuring that partial effects of incomplete transactions are completely rolled back and the all-or-nothing property of atomicity is preserved." - keywords: ["ARIES", "write-ahead log", "log records", "undo", "rollback", "stable storage", "atomicity", "recovery"] - similarity_threshold: 0.75 - ideal_retrieved_chunks: [1355, 1356, 1358, 1353, 1359] - - - id: "oltp_vs_analytics" - question: "Contrast the goals of Online Transaction Processing and data analytics" - expected_answer: "Online Transaction Processing (OLTP) supports a large number of concurrent users performing small, fast transactions that retrieve and update relatively small amounts of data with requirements for high throughput, low latency, and immediate consistency, typically using normalized schemas optimized for transactional integrity. Data analytics, in contrast, processes large volumes of historical data to draw conclusions and infer patterns for business intelligence and decision support, involving complex queries that scan and aggregate data across many records, often using denormalized schemas like star schemas in data warehouses optimized for read-heavy analytical workloads rather than transactional updates." - keywords: ["OLTP", "online transaction processing", "data analytics", "business intelligence", "decision support", "throughput", "data warehouse", "transactional", "analytical"] - similarity_threshold: 0.73 - ideal_retrieved_chunks: [33, 34, 738, 739, 741] - - id: "lossy_decomposition" question: "Show me what happens during a lossy decomposition" expected_answer: "A lossy decomposition occurs when a relation R is decomposed into smaller relations R1 and R2 such that joining them back together produces spurious tuples not present in the original relation, resulting in loss of information about which attribute combinations actually existed. This happens when the intersection of R1 and R2 does not form a superkey for either relation, violating the lossless-join condition; the natural join of the decomposed relations generates extra tuples from invalid combinations, making it impossible to reconstruct the original data accurately, which is why database design insists that all decompositions must be lossless." diff --git a/tests/conftest.py b/tests/conftest.py index 57a98140..86839a59 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -89,6 +89,18 @@ def pytest_addoption(parser): default=None, help="System prompt mode (overrides config)" ) + group.addoption( + "--use-agent", + action="store_true", + default=None, + help="Enable agent mode (overrides config)" + ) + group.addoption( + "--no-agent", + action="store_true", + default=None, + help="Disable agent mode (overrides config)" + ) # === Testing Options === group.addoption( @@ -176,8 +188,23 @@ def config(pytestconfig): # Query Enhancement (HyDE) "use_hyde": cfg.get("use_hyde", False), "hyde_max_tokens": cfg.get("hyde_max_tokens", 100), + + # Agent Mode + "agent_reasoning_limit": cfg.get("agent_reasoning_limit", 5), + "agent_tool_limit": cfg.get("agent_tool_limit", 20), } + # Handle agent mode + use_agent_cli = pytestconfig.getoption("--use-agent") + no_agent_cli = pytestconfig.getoption("--no-agent") + + if use_agent_cli: + merged_config["use_agent"] = True + elif no_agent_cli: + merged_config["use_agent"] = False + else: + merged_config["use_agent"] = cfg.get("use_agent", False) + # Handle enable/disable chunks disable_chunks_cli = pytestconfig.getoption("--disable-chunks") diff --git a/tests/metrics/async_llm_judge.py b/tests/metrics/async_llm_judge.py index 4a819ef4..f72fca20 100644 --- a/tests/metrics/async_llm_judge.py +++ b/tests/metrics/async_llm_judge.py @@ -2,6 +2,7 @@ from pathlib import Path from datetime import datetime import json +import os import threading import time from concurrent.futures import ThreadPoolExecutor @@ -11,6 +12,12 @@ from tests.metrics.base import MetricBase from tests.metrics.llm_judge import GradingResult + +def _has_api_key() -> bool: + """Check if Google API key is available.""" + return bool(os.environ.get("GOOGLE_API_KEY") or os.environ.get("GOOGLE_GENAI_API_KEY")) + + # Shared state for async grading _results_lock = threading.Lock() _grading_results: Dict[str, Dict] = {} @@ -39,10 +46,14 @@ def __init__(self, log_dir: Optional[Path] = None): self.log_dir = Path(log_dir) self.log_dir.mkdir(parents=True, exist_ok=True) self.results_file = self.log_dir / "async_llm_results.json" + self._available = _has_api_key() - # Initialize client once - if not _initialized: - _lazy_init() + # Initialize client once only if API key is available + if self._available and not _initialized: + try: + _lazy_init() + except Exception: + self._available = False @property def name(self) -> str: @@ -53,7 +64,7 @@ def weight(self) -> float: return 0.0 def is_available(self) -> bool: - return True + return self._available def calculate(self, answer: str, expected: str, keywords: Optional[List[str]] = None) -> float: """ diff --git a/tests/metrics/chunk_retrieval.py b/tests/metrics/chunk_retrieval.py index 8ae7f53c..240dbe4f 100644 --- a/tests/metrics/chunk_retrieval.py +++ b/tests/metrics/chunk_retrieval.py @@ -16,6 +16,10 @@ def name(self) -> str: def calculate(self, ideal_retrieved_chunks: List[int], retrieved_chunks) -> float: + # Handle None/empty inputs (e.g. agent mode doesn't provide chunk info) + if not retrieved_chunks or not ideal_retrieved_chunks: + return 0.0 + print("ideal_retrieved_chunks: ", ideal_retrieved_chunks) print("retrieved_chunks: ", [chunk["chunk_id"] for chunk in retrieved_chunks]) found_chunks = [chunk["chunk_id"] for chunk in retrieved_chunks if chunk["chunk_id"] in ideal_retrieved_chunks] diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 36f76c6a..fd5305bb 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -58,6 +58,10 @@ def print_test_config(config, scorer): print(f" Chunks Enabled: {not config['disable_chunks']}") print(f" Golden Chunks: {config['use_golden_chunks']}") print(f" HyDE Enabled: {config.get('use_hyde', False)}") + print(f" Agent Mode: {config.get('use_agent', False)}") + if config.get('use_agent', False): + print(f" • Reasoning Limit: {config.get('agent_reasoning_limit', 5)}") + print(f" • Tool Limit: {config.get('agent_tool_limit', 20)}") print(f" Output Mode: {config['output_mode']}") print(f" Metrics: {', '.join(active_metrics)}") print(f"{'='*60}\n") @@ -147,6 +151,9 @@ def run_benchmark(benchmark, config, results_dir, scorer): "system_prompt_mode": config["system_prompt_mode"], "disable_chunks": config["disable_chunks"], "use_golden_chunks": config["use_golden_chunks"], + "use_agent": config.get("use_agent", False), + "agent_reasoning_limit": config.get("agent_reasoning_limit", 5), + "agent_tool_limit": config.get("agent_tool_limit", 20), } } @@ -173,7 +180,6 @@ def get_tokensmith_answer(question, config, golden_chunks=None): Returns: tuple: (Generated answer, chunks_info list, hyde_query) """ - from src.main import get_answer from src.instrumentation.logging import init_logger, get_logger from src.config import RAGConfig from src.retriever import BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts @@ -188,19 +194,21 @@ def get_tokensmith_answer(question, config, golden_chunks=None): ) # Create RAGConfig from our test config + # Note: chunk_config is computed in __post_init__ from chunk_mode/chunk_size/chunk_overlap cfg = RAGConfig( - chunk_config=RAGConfig.get_chunk_config(config), + chunk_mode=config.get("chunk_mode", "recursive_sections"), + chunk_size=config.get("recursive_chunk_size", 2000), + chunk_overlap=config.get("recursive_overlap", 200), top_k=config.get("top_k", 5), - pool_size=config.get("pool_size", 60), + num_candidates=config.get("pool_size", 60), embed_model=config.get("embed_model"), - ensemble_method=config.get("retrieval_method", "rrf"), - rrf_k=60, + ensemble_method=config.get("ensemble_method", "rrf"), + rrf_k=config.get("rrf_k", 60), ranker_weights=config.get("ranker_weights", {"faiss": 0.6, "bm25": 0.4}), - rerank_mode=config.get("rerank_mode", "none"), - seg_filter=config.get("seg_filter", None), + rerank_mode=config.get("rerank_mode", ""), + gen_model=config.get("model_path"), system_prompt_mode=config.get("system_prompt_mode", "baseline"), max_gen_tokens=config.get("max_gen_tokens", 400), - model_path=config.get("model_path"), disable_chunks=config.get("disable_chunks", False), use_golden_chunks=config.get("use_golden_chunks", False), output_mode=config.get("output_mode", "html"), @@ -210,10 +218,15 @@ def get_tokensmith_answer(question, config, golden_chunks=None): use_indexed_chunks=config.get("use_indexed_chunks", False), extracted_index_path=config.get("extracted_index_path", "data/extracted_index.json"), page_to_chunk_map_path=config.get("page_to_chunk_map_path", "index/sections/textbook_index_page_to_chunk_map.json"), + use_agent=config.get("use_agent", False), + agent_reasoning_limit=config.get("agent_reasoning_limit", 5), + agent_tool_limit=config.get("agent_tool_limit", 20), ) # Print status - if golden_chunks and config["use_golden_chunks"]: + if config.get("use_agent", False): + print(f" 🤖 Agent mode enabled") + elif golden_chunks and config["use_golden_chunks"]: print(f" 📌 Using {len(golden_chunks)} golden chunks") elif config["disable_chunks"]: print(f" 📭 No chunks (baseline mode)") @@ -226,53 +239,90 @@ def get_tokensmith_answer(question, config, golden_chunks=None): logger = get_logger() # Run the query through the main pipeline - artifacts_dir = cfg.make_artifacts_directory() - faiss_index, bm25_index, chunks, sources = load_artifacts( + artifacts_dir = cfg.get_artifacts_directory() + faiss_index, bm25_index, chunks, sources, metadata = load_artifacts( artifacts_dir=artifacts_dir, index_prefix=config["index_prefix"] ) - retrievers = [ - FAISSRetriever(faiss_index, cfg.embed_model), - BM25Retriever(bm25_index) - ] - - # Add index keyword retriever if weight > 0 - if cfg.ranker_weights.get("index_keywords", 0) > 0: - retrievers.append( - IndexKeywordRetriever(cfg.extracted_index_path, cfg.page_to_chunk_map_path) + # Check if agent mode is enabled + if config.get("use_agent", False): + # Use agent orchestrator path + from src.agent import AgentToolkit + from src.agent.orchestrator import AgentOrchestrator, AgentConfig + + toolkit = AgentToolkit( + faiss_index=faiss_index, + chunks=chunks, + sources=sources, + embed_model=cfg.embed_model, + markdown_path="data/book_with_pages.md", + summaries_path="data/section_summaries.json", ) - - ranker = EnsembleRanker( - ensemble_method=cfg.ensemble_method, - weights=cfg.ranker_weights, - rrf_k=int(cfg.rrf_k) - ) - - # Package artifacts for reuse - artifacts = { - "chunks": chunks, - "sources": sources, - "retrievers": retrievers, - "ranker": ranker - } - - result = get_answer( - question=question, - cfg=cfg, - args=args, - logger=logger, - artifacts=artifacts, - console=None, - golden_chunks=golden_chunks, - is_test_mode=True - ) - - # Handle return value (answer, chunks_info, hyde_query) or just answer - if isinstance(result, tuple): - generated, chunks_info, hyde_query = result + + agent_config = AgentConfig( + reasoning_limit=cfg.agent_reasoning_limit, + tool_limit=cfg.agent_tool_limit, + max_generation_tokens=cfg.max_gen_tokens, + ) + + orchestrator = AgentOrchestrator( + toolkit=toolkit, + model_path=args.model_path or cfg.model_path, + config=agent_config, + ) + + # Run agent and get answer + result = orchestrator.run(question) + generated = result["answer"] + chunks_info = None # Agent mode doesn't provide chunks_info in the same format + hyde_query = None + else: - generated, chunks_info, hyde_query = result, None, None + # Use standard get_answer path + from src.main import get_answer + + retrievers = [ + FAISSRetriever(faiss_index, cfg.embed_model), + BM25Retriever(bm25_index) + ] + + # Add index keyword retriever if weight > 0 + if cfg.ranker_weights.get("index_keywords", 0) > 0: + retrievers.append( + IndexKeywordRetriever(cfg.extracted_index_path, cfg.page_to_chunk_map_path) + ) + + ranker = EnsembleRanker( + ensemble_method=cfg.ensemble_method, + weights=cfg.ranker_weights, + rrf_k=int(cfg.rrf_k) + ) + + # Package artifacts for reuse + artifacts = { + "chunks": chunks, + "sources": sources, + "retrievers": retrievers, + "ranker": ranker + } + + result = get_answer( + question=question, + cfg=cfg, + args=args, + logger=logger, + artifacts=artifacts, + console=None, + golden_chunks=golden_chunks, + is_test_mode=True + ) + + # Handle return value (answer, chunks_info, hyde_query) or just answer + if isinstance(result, tuple): + generated, chunks_info, hyde_query = result + else: + generated, chunks_info, hyde_query = result, None, None # Clean answer - extract up to end token if present generated = clean_answer(generated)