diff --git a/config/config.yaml b/config/config.yaml index 3c87545a..fa278243 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -28,4 +28,9 @@ kg_pipeline: corpus_description: "Database System Concepts, 7th edition by Silberschatz et al." min_cooccurrence: 0 top_n: 10 - + canonicalization: + llm_model: "openai/gpt-4o-mini" + similarity_threshold: 0.78 + max_group_size: 30 + batch_size: 15 + embed_model: "sentence-transformers/all-MiniLM-L6-v2" diff --git a/src/config.py b/src/config.py index 94de232a..e0c408bc 100644 --- a/src/config.py +++ b/src/config.py @@ -55,6 +55,13 @@ class RAGConfig: enable_history: bool = True max_history_turns: int = 3 + # knowledge graph retrieval + kg_graph_dir: str = "" + kg_beta: float = 0.5 # blend weight: 0 = node-only, 1 = section-tree-only + kg_heading_alpha: float = 0.5 # heading sim vs KG keyword blend: 1 = heading-only, 0 = KG-only + kg_inheritance_decay: float = 0.5 # parent→child score decay in top-down propagation + + # index parameters use_indexed_chunks: bool = False extracted_index_path: os.PathLike = "data/extracted_index.json" diff --git a/src/knowledge_graph/analysis.py b/src/knowledge_graph/analysis.py new file mode 100644 index 00000000..5949644e --- /dev/null +++ b/src/knowledge_graph/analysis.py @@ -0,0 +1,140 @@ +import logging +from itertools import combinations + +import networkx as nx + +from src.knowledge_graph.models import ( + DifficultyCategory, + DifficultyComponents, + DifficultyScore, + QueryAnalysisResult, + QueryFeatures, +) +from src.knowledge_graph.query import CanonicalLookup, extract_query_nodes + +logger = logging.getLogger(__name__) + +# Scoring thresholds: [easy_max, medium_max] → scores [0, 1, 2] +# Each dimension contributes 0–2; total 0–10 maps to EASY/MEDIUM/HARD. 
def extract_query_subgraph(query_nodes: list[str], graph: nx.Graph) -> nx.Graph:
    """Return the subgraph spanning *query_nodes* and the shortest paths between them.

    Args:
        query_nodes: Node identifiers matched for the query. Entries that are
            not present in *graph* are ignored.
        graph: The full knowledge graph.

    Returns:
        An independent copy of the induced subgraph over the query nodes plus
        every node lying on a shortest path between any pair of them.
    """
    # Drop duplicates (order-preserving) and nodes the graph does not contain:
    # nx.shortest_path raises NodeNotFound for unknown endpoints, and
    # graph.subgraph() would silently ignore them anyway.
    present = [node for node in dict.fromkeys(query_nodes) if node in graph]

    subgraph_nodes = set(present)
    for u, v in combinations(present, 2):
        # A single search per pair: the NetworkXNoPath handler already covers
        # disconnected pairs, so a separate has_path() pre-check would only
        # repeat the same traversal.
        try:
            subgraph_nodes.update(nx.shortest_path(graph, u, v))
        except nx.NetworkXNoPath:
            continue
    return graph.subgraph(subgraph_nodes).copy()
+ """ + query_nodes = extract_query_nodes(query, graph, canonical_lookup) + logger.debug("Query nodes: %s", query_nodes) + if not query_nodes: + return QueryFeatures() + + subgraph = extract_query_subgraph(query_nodes, graph) + + component_count = nx.number_connected_components(subgraph) + + path_lengths = [] + for u, v in combinations(query_nodes, 2): + if nx.has_path(graph, u, v): + try: + path_lengths.append(nx.shortest_path_length(graph, u, v)) + except nx.NetworkXNoPath: + pass + + max_path_length = max(path_lengths) if path_lengths else 0 + avg_path_length = sum(path_lengths) / len(path_lengths) if path_lengths else 0.0 + + degrees = dict(subgraph.degree()) + max_degree = max(degrees.values()) if degrees else 0 + avg_degree = sum(degrees.values()) / len(degrees) if degrees else 0.0 + + chunk_ids: set[int] = set() + for _, data in subgraph.nodes(data=True): + chunk_ids.update(data.get("chunk_ids", [])) + for _, _, data in subgraph.edges(data=True): + chunk_ids.update(data.get("chunk_ids", [])) + + return QueryFeatures( + query_node_count=len(query_nodes), + component_count=component_count, + max_path_length=max_path_length, + avg_path_length=avg_path_length, + avg_degree=avg_degree, + max_degree=max_degree, + subgraph_node_count=subgraph.number_of_nodes(), + subgraph_edge_count=subgraph.number_of_edges(), + doc_count=len(chunk_ids), + ) + + +def _map_to_score( + value: int | float, + thresholds: list[int | float], + scores: list[int | DifficultyCategory], +): + for threshold, score in zip(thresholds, scores): + if value <= threshold: + return score + return scores[-1] + + +def compute_difficulty_score(features: QueryFeatures) -> DifficultyScore: + multihop = _map_to_score(features.max_path_length, _MULTIHOP_THRESHOLDS, [0, 1, 2]) + fragmentation = _map_to_score(features.component_count, _FRAGMENTATION_THRESHOLDS, [0, 1, 2]) + subgraph_size = _map_to_score(features.subgraph_node_count, _SUBGRAPH_SIZE_THRESHOLDS, [0, 1, 2]) + branching = 
def analyze_query(
    query: str,
    graph: nx.Graph,
    canonical_lookup: CanonicalLookup | None = None,
) -> QueryAnalysisResult:
    """Analyze *query* end to end: graph features first, then the difficulty score.

    The features are computed once by ``compute_difficulty_features`` and fed
    to ``compute_difficulty_score``; both are returned together with the
    original query text.
    """
    query_features = compute_difficulty_features(query, graph, canonical_lookup)
    return QueryAnalysisResult(
        query=query,
        features=query_features,
        difficulty=compute_difficulty_score(query_features),
    )
class Canonicalizer:
    """Semantic canonicalization of KG keywords.

    Keywords are embedded, grouped with complete-linkage clustering, and every
    multi-keyword group is sent to an LLM that decides which members are true
    synonyms. Confirmed synonyms are rewritten to one canonical label.

    Args:
        corpus_description: Human-readable description of the corpus
            (e.g. Title of the textbook or main topic of the document).
            Injected into the LLM system prompt as domain context.
        api_key: OpenRouter API key for the LLM verification step.
        embedding_model: Sentence-transformer model name for keyword embedding.
        similarity_threshold: Cosine similarity threshold for complete-linkage
            clustering. A group forms only when ALL pairs in it exceed this value.
        max_group_size: Maximum keywords per LLM call. Oversized clusters are
            force-split into fixed-size chunks before the LLM step.
        llm_model: OpenRouter model identifier.
        batch_size: Number of small groups (≤5 keywords) to batch per LLM call.
        fallback_threshold: Cosine similarity threshold used at query time when a
            keyword is not in the synonym table (embedding-based fallback).
        retries: Retry count forwarded to ``OpenRouterClient``.
        normalizer: Normalizer applied to keywords returned by the LLM. A
            default ``Normalizer()`` is created when omitted.
    """

    def __init__(
        self,
        corpus_description: str,
        api_key: str,
        embedding_model: str,
        similarity_threshold: float = 0.78,
        max_group_size: int = 30,
        llm_model: str = "openai/gpt-4o-mini",
        batch_size: int = 15,
        fallback_threshold: float = 0.85,
        retries: int = 1,
        normalizer: Normalizer | None = None,
    ):
        self.corpus_description = corpus_description
        self.similarity_threshold = similarity_threshold
        self.max_group_size = max_group_size
        self.llm_model = llm_model
        self.batch_size = batch_size
        self.fallback_threshold = fallback_threshold
        self._normalizer = normalizer or Normalizer()
        self.retries = retries
        self._client = OpenRouterClient(api_key, retries=retries)

        logger.info("Loading embedding model: %s", embedding_model)
        self._model = SentenceTransformer(embedding_model)
        self._embedding_model_name = embedding_model
        # Number of successful LLM calls in the current canonicalize() run.
        self._llm_calls = 0

    def get_config(self) -> dict[str, Any]:
        """Return the constructor settings (without the API key) as a plain dict."""
        return {
            "class": self.__class__.__name__,
            "corpus_description": self.corpus_description,
            "embedding_model": self._embedding_model_name,
            "similarity_threshold": self.similarity_threshold,
            "max_group_size": self.max_group_size,
            "llm_model": self.llm_model,
            "batch_size": self.batch_size,
            "fallback_threshold": self.fallback_threshold,
            "retries": self.retries,
        }

    def canonicalize(
        self, extractions: list[ExtractionResult]
    ) -> tuple[list[ExtractionResult], CanonicalizationResult]:
        """Run canonicalization on a list of extraction results.

        Returns:
            Updated extractions (nodes replaced by canonical forms) and a
            CanonicalizationResult carrying the artifacts and run statistics.
        """
        all_keywords = self._collect_keywords(extractions)
        n = len(all_keywords)
        logger.info("Canonicalizing %d unique keywords…", n)

        # 2a — embed
        logger.info(" [2a] Embedding keywords…")
        embeddings = self._embed(all_keywords)

        # 2b — cluster
        logger.info(" [2b] Complete-linkage clustering (θ=%.2f)…",
                    self.similarity_threshold)
        groups = self._cluster(all_keywords, embeddings)
        singletons = [g[0] for g in groups if len(g) == 1]
        non_singletons = [g for g in groups if len(g) > 1]
        logger.info(
            " %d singletons, %d candidate groups",
            len(singletons), len(non_singletons),
        )

        # 2c — LLM verification
        logger.info(" [2c] LLM verification (%d groups)…",
                    len(non_singletons))
        self._llm_calls = 0
        synonym_table = self._verify_with_llm(non_singletons)

        # 2d — build structures
        # Keywords that sat in a candidate group but were left unmerged by the
        # LLM (standalone, unmentioned, or part of a failed batch) have no
        # synonym-table entry. _apply() keeps them as graph nodes unchanged, so
        # they must be listed (and embedded) as canonical keywords as well.
        unmerged = {
            kw
            for group in non_singletons
            for kw in group
            if kw not in synonym_table
        }
        canonical_keywords = sorted(
            set(synonym_table.values()) | set(singletons) | unmerged)

        logger.info(" [2d] Embedding %d canonical keywords…",
                    len(canonical_keywords))
        canonical_embeddings = self._embed(canonical_keywords)

        counts = Counter(synonym_table.values())
        merges_performed = sum(c - 1 for c in counts.values() if c > 1)

        stats = {
            "keywords_after_stage1": n,
            "candidate_groups": len(non_singletons),
            "singletons": len(singletons),
            "merges_performed": merges_performed,
            "canonical_keywords_final": len(canonical_keywords),
            "llm_calls": self._llm_calls,
        }

        logger.info(
            "Canonicalization done: %d → %d keywords, %d merges, %d LLM calls",
            n, len(canonical_keywords), merges_performed, self._llm_calls,
        )

        updated = self._apply(extractions, synonym_table)
        result = CanonicalizationResult(
            synonym_table=synonym_table,
            canonical_keywords=canonical_keywords,
            canonical_embeddings=canonical_embeddings,
            stats=stats,
        )
        return updated, result

    @staticmethod
    def _collect_keywords(extractions: list[ExtractionResult]) -> list[str]:
        """Return every distinct keyword in first-seen order."""
        # List preserves stable order for embedding index alignment, set provides dedup.
        seen: set[str] = set()
        keywords: list[str] = []
        for er in extractions:
            for kw in er.keywords:
                if kw not in seen:
                    keywords.append(kw)
                    seen.add(kw)
        return keywords

    def _embed(self, keywords: list[str]) -> np.ndarray:
        """Encode *keywords* with the sentence-transformer model."""
        return self._model.encode(keywords, show_progress_bar=False)

    def _cluster(self, keywords: list[str], embeddings: np.ndarray) -> list[list[str]]:
        """Complete-linkage clustering.

        A group forms only when ALL pairs within it have cosine similarity ≥
        self.similarity_threshold (equivalently, distance ≤ 1 − threshold).
        Oversized groups are force-split into max_group_size chunks.

        Returns an empty list when there are no keywords.
        """
        n = len(keywords)
        if n == 0:
            # Nothing to cluster; the similarity/linkage calls below need at
            # least two observations.
            return []
        if n == 1:
            return [keywords]

        sim = cosine_similarity(embeddings)
        np.fill_diagonal(sim, 1.0)
        # Clip tiny negative distances caused by floating-point error.
        dist = np.clip(1.0 - sim, 0.0, None)

        condensed = squareform(dist, checks=False)
        Z = linkage(condensed, method="complete")
        labels = fcluster(Z, t=1.0 - self.similarity_threshold,
                          criterion="distance")

        raw_groups: dict[int, list[str]] = {}
        for kw, label in zip(keywords, labels):
            raw_groups.setdefault(int(label), []).append(kw)

        result: list[list[str]] = []
        for group in raw_groups.values():
            if len(group) <= self.max_group_size:
                result.append(group)
            else:
                for i in range(0, len(group), self.max_group_size):
                    result.append(group[i: i + self.max_group_size])
        return result

    def _verify_with_llm(self, groups: list[list[str]]) -> dict[str, str]:
        """Return a partial synonym table for all keywords in non-singleton groups."""
        partial: dict[str, str] = {}

        # Groups of up to 5 keywords share one LLM call (batch_size groups per
        # call); larger groups get a call of their own.
        small = [g for g in groups if len(g) <= 5]
        large = [g for g in groups if len(g) > 5]

        for i in range(0, len(small), self.batch_size):
            partial.update(self._llm_call(small[i: i + self.batch_size]))

        for group in large:
            partial.update(self._llm_call([group]))

        return partial

    def _normalize_kw(self, kw: str) -> str:
        """Normalize a single keyword using the configured Normalizer or strip+lower."""
        result = self._normalizer.normalize([kw])
        return result[0] if result else kw.strip().lower()

    def _llm_call(self, groups: list[list[str]]) -> dict[str, str]:
        """One OpenRouter API call covering a batch of candidate groups.

        Returns a keyword → canonical mapping only for keywords that the LLM
        confirms are true synonyms. Standalone or unmentioned keywords are omitted;
        callers treat a missing entry as "no synonym found". Synonym groups
        without a usable canonical label are skipped. Any failure (request or
        response parsing) is logged and yields whatever was collected so far.
        """
        groups_text = "\n".join(
            f"Group {i + 1}: {json.dumps(g)}" for i, g in enumerate(groups)
        )

        system_prompt = SYNONYM_SYSTEM_PROMPT.format(
            corpus_description=self.corpus_description)
        user_prompt = SYNONYM_PROMPT.format(groups_text=groups_text)

        partial: dict[str, str] = {}

        try:
            content = self._client.chat(
                model=self.llm_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                response_format={"type": "json_object"},
            )
            self._llm_calls += 1

            parsed = json.loads(content)
            for group_result in parsed.get("groups", []):
                for sg in group_result.get("synonym_groups", []):
                    # `or ""` also covers an explicit JSON null label.
                    canonical = self._normalize_kw(sg.get("canonical") or "")
                    if not canonical:
                        # Without a label every member would be rewritten to
                        # the empty string and collapse into one bogus node.
                        continue
                    for member in sg.get("members", []):
                        if not member:
                            continue
                        key = self._normalize_kw(member)
                        if key:
                            partial[key] = canonical

        except Exception as e:
            # Deliberate best-effort: a failed batch leaves its keywords unmerged.
            logger.warning(
                "LLM call failed after all attempts (%s) — batch skipped", e)

        return partial

    @staticmethod
    def _apply(
        extractions: list[ExtractionResult], synonym_table: dict[str, str]
    ) -> list[ExtractionResult]:
        """Rewrite each extraction's keywords to canonical form, dropping repeats.

        Keywords without a synonym-table entry are kept as they are; order of
        first appearance is preserved per chunk.
        """
        updated = []
        for er in extractions:
            seen: set[str] = set()
            canonical_nodes: list[str] = []
            for kw in er.keywords:
                canonical = synonym_table.get(kw, kw)
                if canonical not in seen:
                    canonical_nodes.append(canonical)
                    seen.add(canonical)
            updated.append(ExtractionResult(
                chunk_id=er.chunk_id, keywords=canonical_nodes))
        return updated
def load_graph(path: str) -> nx.Graph:
    """Read a ``graph.json`` node-link file and rebuild the NetworkX graph.

    Args:
        path: Location of the node-link JSON file.

    Returns:
        The graph reconstructed by ``nx.node_link_graph``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        node_link_payload = json.load(handle)
    return nx.node_link_graph(node_link_payload)
def resolve_run_dir(path: str) -> str:
    """Work out which concrete run directory *path* refers to.

    Resolution order:

    1. *path* itself, when it directly contains ``graph.json``.
    2. The target of the ``latest`` symlink inside *path*, when that target
       contains ``graph.json``.

    Raises:
        FileNotFoundError: If neither candidate holds a ``graph.json``.
    """

    def _holds_graph(directory: str) -> bool:
        # A run directory is recognised solely by its graph.json file.
        return os.path.isfile(os.path.join(directory, "graph.json"))

    if _holds_graph(path):
        return path

    latest_link = os.path.join(path, "latest")
    target = os.path.realpath(latest_link) if os.path.islink(latest_link) else None
    if target is not None and _holds_graph(target):
        return target

    raise FileNotFoundError(
        f"Cannot resolve run dir from {path!r}: "
        "no graph.json found and no valid 'latest' symlink."
    )
+ """ + run_dir = resolve_run_dir(output_dir) + graph, chunks = load_graph_and_chunks(run_dir) + try: + tree = load_section_tree(run_dir) + except FileNotFoundError: + tree = None + return graph, chunks, tree + + +def load_canonicalization_data( + run_dir: str, +) -> tuple[dict[str, str], list[str], np.ndarray] | tuple[None, None, None]: + """Load synonym table, canonical keywords, and embeddings from a run directory. + + Returns ``(None, None, None)`` when canonicalization artifacts are absent. + """ + synonym_path = os.path.join(run_dir, "synonym_table.json") + keywords_path = os.path.join(run_dir, "canonical_keywords.json") + embeddings_path = os.path.join(run_dir, "canonical_embeddings.npy") + + if not all(os.path.exists(p) for p in [synonym_path, keywords_path, embeddings_path]): + return None, None, None + + with open(synonym_path, "r", encoding="utf-8") as f: + synonym_table: dict[str, str] = json.load(f) + with open(keywords_path, "r", encoding="utf-8") as f: + canonical_keywords: list[str] = json.load(f) + canonical_embeddings = np.load(embeddings_path) + + return synonym_table, canonical_keywords, canonical_embeddings diff --git a/src/knowledge_graph/models.py b/src/knowledge_graph/models.py index ef66ff21..9e102bf1 100644 --- a/src/knowledge_graph/models.py +++ b/src/knowledge_graph/models.py @@ -1,5 +1,8 @@ from dataclasses import dataclass, field from typing import Any +from enum import Enum + +import numpy as np import yaml @@ -16,6 +19,92 @@ class ExtractionResult: keywords: list[str] = field(default_factory=list) +@dataclass +class CanonicalizationResult: + synonym_table: dict[str, str] + canonical_keywords: list[str] + canonical_embeddings: np.ndarray + stats: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class QueryFeatures: + query_node_count: int = 0 + component_count: int = 0 + max_path_length: int = 0 + avg_path_length: float = 0.0 + avg_degree: float = 0.0 + max_degree: int = 0 + subgraph_node_count: int = 0 + 
class DifficultyCategory(Enum):
    """Overall difficulty bucket assigned to a query."""

    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"


@dataclass
class DifficultyComponents:
    """Per-dimension scores that are summed into the total difficulty score."""

    multihop: int
    fragmentation: int
    subgraph_size: int
    branching: int
    dispersion: int

    def to_dict(self) -> dict:
        """Return the component scores as a new plain dict."""
        return {
            "multihop": self.multihop,
            "fragmentation": self.fragmentation,
            "subgraph_size": self.subgraph_size,
            "branching": self.branching,
            "dispersion": self.dispersion,
        }


@dataclass
class DifficultyScore:
    """Total difficulty score together with its category and breakdown."""

    score: int
    category: DifficultyCategory
    components: DifficultyComponents

    def to_dict(self) -> dict:
        """Return a JSON-serializable snapshot of the score.

        The category is emitted as its string value. The components are
        emitted through ``DifficultyComponents.to_dict()`` so the result is a
        fresh dict: returning ``components.__dict__`` would hand out the live
        attribute storage, and editing the result would silently edit this
        object.
        """
        return {
            "score": self.score,
            "category": self.category.value,
            "components": self.components.to_dict(),
        }
def extract_ngrams(text: str, pattern: str) -> set[str]:
    """Collect every unigram, bigram and trigram found in *text*.

    Args:
        text: Input string to tokenize.
        pattern: Regex pattern used to extract tokens (e.g. ``KW_PATTERN``).

    Returns:
        Set of all n-gram strings (n = 1, 2, 3); multi-token n-grams are
        joined with a single space.
    """
    tokens = re.findall(pattern, text)
    terms = set(tokens)  # unigrams
    for size in (2, 3):
        terms.update(" ".join(gram) for gram in ngrams(tokens, size))
    return terms
+ """ + + def __init__(self, spacy_model: str = "en_core_web_sm"): + self.nlp = spacy.load(spacy_model, disable=["ner", "parser"]) + + def _lemmatize(self, text: str) -> str: + """Return the lemmatized form of *text*.""" + doc = self.nlp(text) + return " ".join(token.lemma_ for token in doc) + + def normalize(self, keywords: list[str]) -> list[str]: + """Normalize and deduplicate a list of keywords. + Strips leading/trailing whitespace, lowercases, lemmatizes, and deduplicates. + Keeps the first occurrence of each unique normalized keyword, preserving order. + + Args: + keywords: Raw keyword strings. + + Returns: + Deduplicated, normalized keywords. + """ + result: list[str] = [] + seen: set[str] = set() # to track duplicates after normalization + for kw in keywords: + normalized = kw.strip().lower() + if not normalized: + continue + normalized = self._lemmatize(normalized) + if normalized not in seen: + seen.add(normalized) + result.append(normalized) + + return result diff --git a/src/knowledge_graph/openrouter_client.py b/src/knowledge_graph/openrouter_client.py index f332094a..74f45c5c 100644 --- a/src/knowledge_graph/openrouter_client.py +++ b/src/knowledge_graph/openrouter_client.py @@ -1,9 +1,3 @@ -"""Thin wrapper around the OpenRouter chat-completions endpoint. - -Centralises the POST request, auth headers, and retry loop so that -``Canonicalizer`` and ``OpenRouterExtractor`` don't duplicate them. 
-""" - import logging import requests diff --git a/src/knowledge_graph/pipeline.py b/src/knowledge_graph/pipeline.py index f02fdf68..a451df80 100644 --- a/src/knowledge_graph/pipeline.py +++ b/src/knowledge_graph/pipeline.py @@ -4,36 +4,27 @@ from time import time import networkx as nx +import numpy as np -from src.knowledge_graph.extractors import BaseExtractor +from src.knowledge_graph.models import Chunk, RunMetadata, CanonicalizationResult from src.knowledge_graph.linkers import BaseLinker -from src.knowledge_graph.models import Chunk, RunMetadata +from src.knowledge_graph.extractors import BaseExtractor +from src.knowledge_graph.canonicalizer import Canonicalizer + + +logger = logging.getLogger(__name__) + -from src.knowledge_graph.build import ( - CHUNKS_PKL, - META_PKL, - load_chunks, -) logger = logging.getLogger(__name__) def build_kg( output_dir: str, + chunks: list[Chunk], extractor: BaseExtractor, linker: BaseLinker, - chapter_filter: str | None = None, - exclude_chapters: list[str] | None = None, - chunk_ids: list[int] | None = None, + canonicalizer: Canonicalizer, ) -> nx.Graph: - chunks = load_chunks( - CHUNKS_PKL, - META_PKL, - chapter_filter=chapter_filter, - exclude_chapters=exclude_chapters, - chunk_ids=chunk_ids, - ) - logger.info("Loaded %d chunks", len(chunks)) - logger.info("Extracting keywords...") t0 = time() extractions = extractor.extract(chunks) @@ -42,6 +33,17 @@ def build_kg( f" {len(extractions)} extractions created in {t1 - t0:.2f} seconds" ) + logger.info("Canonicalizing keywords...") + t0 = time() + extractions, canon_result = canonicalizer.canonicalize(extractions) + t1 = time() + s = canon_result.stats + logger.info( + f" {s['keywords_after_stage1']} → {s['canonical_keywords_final']} keywords, " + f"{s['merges_performed']} merges, {s['llm_calls']} LLM calls " + f"in {t1 - t0:.2f} seconds" + ) + logger.info("Linking keywords...") t0 = time() graph = linker.link(extractions) @@ -63,8 +65,11 @@ def build_kg( ) _persist( - graph, 
chunks, output_dir, + graph, + chunks, + output_dir, run_metadata=run_metadata, + canonicalization_result=canon_result, ) t1 = time() logger.info(f" Graph persisted in {t1 - t0:.2f} seconds") @@ -83,7 +88,8 @@ def _persist( graph: nx.Graph, chunks: list[Chunk], output_dir: str, - run_metadata: RunMetadata | None = None, + run_metadata: RunMetadata, + canonicalization_result: CanonicalizationResult, ) -> None: os.makedirs(output_dir, exist_ok=True) @@ -95,25 +101,38 @@ def _persist( with open(os.path.join(output_dir, "chunks.json"), "w", encoding="utf-8") as f: json.dump(chunk_store, f, indent=2, ensure_ascii=False) - if run_metadata: - num_nodes = graph.number_of_nodes() - num_edges = graph.number_of_edges() - comp_list = list(nx.connected_components(graph)) - largest_comp_size = len( - max(comp_list, key=len)) if comp_list else 0 - - run_metadata.statistics["graph"] = { - "nodes": num_nodes, - "edges": num_edges, - "density": nx.density(graph), - "avg_degree": (2 * num_edges / num_nodes) if num_nodes > 0 else 0.0, - "avg_clustering": nx.average_clustering(graph), - "num_connected_components": len(comp_list), - "largest_component_size": largest_comp_size, - "max_degree": max(dict(graph.degree()).values(), default=0), - } - with open( - os.path.join(output_dir, "run_metadata.json"), "w", encoding="utf-8" - ) as f: - json.dump(run_metadata.to_dict(), f, - indent=2, ensure_ascii=False) + num_nodes = graph.number_of_nodes() + num_edges = graph.number_of_edges() + comp_list = list(nx.connected_components(graph)) + largest_comp_size = len( + max(comp_list, key=len)) if comp_list else 0 + run_metadata.statistics["graph"] = { + "nodes": num_nodes, + "edges": num_edges, + "density": nx.density(graph), + "avg_degree": (2 * num_edges / num_nodes) if num_nodes > 0 else 0.0, + "avg_clustering": nx.average_clustering(graph), + "num_connected_components": len(comp_list), + "largest_component_size": largest_comp_size, + "max_degree": max(dict(graph.degree()).values(), default=0), 
+ } + with open( + os.path.join(output_dir, "run_metadata.json"), "w", encoding="utf-8" + ) as f: + json.dump(run_metadata.to_dict(), f, + indent=2, ensure_ascii=False) + + with open( + os.path.join(output_dir, "synonym_table.json"), "w", encoding="utf-8" + ) as f: + json.dump(canonicalization_result.synonym_table, f, indent=2, ensure_ascii=False) + + with open( + os.path.join(output_dir, "canonical_keywords.json"), "w", encoding="utf-8" + ) as f: + json.dump(canonicalization_result.canonical_keywords, f, indent=2, ensure_ascii=False) + + np.save( + os.path.join(output_dir, "canonical_embeddings.npy"), + canonicalization_result.canonical_embeddings, + ) diff --git a/src/knowledge_graph/prompts.py b/src/knowledge_graph/prompts.py index c1edd4ae..b5e1b306 100644 --- a/src/knowledge_graph/prompts.py +++ b/src/knowledge_graph/prompts.py @@ -13,9 +13,56 @@ <|im_start|>assistant """ +SYNONYM_PROMPT = """Given the following groups of keywords extracted from the corpus, \ +determine which keywords within each group are true synonyms. +{groups_text} +For each group: +1. Identify sets of TRUE synonyms (keywords that refer to the EXACT same concept and \ +are fully interchangeable word-for-word in any sentence without changing meaning). \ +Topical relatedness, part-whole relationships, abbreviation-expansion pairs with \ +different scope, and general-vs-specific pairs do NOT qualify. When in doubt, keep them separate. +2. Choose the best canonical label — prefer the form used in academic/textbook literature. +3. List keywords that are NOT synonymous with any other keyword as standalone. +Respond in JSON only: +{{ + "groups": [ + {{ + "group_id": 1, + "synonym_groups": [ + {{"canonical": "label", "members": ["kw1", "kw2"], "reason": "..."}} + ], + "standalone": ["kw_x"] + }} + ] +}} +""" + +SYNONYM_SYSTEM_PROMPT = """You are a terminology expert analyzing keywords extracted from: +{corpus_description}. 
Identify keywords that refer to exactly the same concept and should +be merged. Be conservative, prefer keeping terms separate over incorrectly merging distinct +concepts. +""" + OPENROUTER_KEYWORD_EXTRACTION_PROMPT = """You are a linguistic analysis expert. Analyze the provided text and identify the {top_n} most relevant and descriptive keywords or short phrases (1-3 words). Focus on terms that carry the most information density, such as technical terms, -proper nouns, and central concepts. Return the result as a raw JSON list of strings. -Do not include any other text or explanation in your response. +proper nouns, and central concepts. Return the result as a raw JSON list of strings. Do not +include any other text or explanation in your response. """ + +GRADE_PROMPT = """\ +You are evaluating a retrieval system for a question-answering application. + +Question: {query} + +Retrieved passages: +{passages} + +Rate each passage for how well it helps answer the question. +Return a JSON object with key "grades" containing one entry per passage (same order): +{{"grades": [{{"id": 1, "score": 0, "reason": "brief reason"}}]}} + +Scoring: +0 = Not relevant — passage is unrelated to the question +1 = Partially relevant — passage touches the topic but doesn't directly answer it +2 = Highly relevant — passage directly helps answer the question""" diff --git a/src/knowledge_graph/query.py b/src/knowledge_graph/query.py new file mode 100644 index 00000000..bfa9c53b --- /dev/null +++ b/src/knowledge_graph/query.py @@ -0,0 +1,311 @@ +import logging + +import networkx as nx +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity as cos_sim + +from src.retriever import Retriever +from src.knowledge_graph.io import RUNS_DIR, load_graph_and_chunks +from src.knowledge_graph.section_tree import SectionTree +from src.knowledge_graph.ngrams import KW_PATTERN, extract_ngrams +from src.knowledge_graph.normalizer import Normalizer + +logger = logging.getLogger(__name__) 
+ +# Shared normalizer instance, spaCy model is expensive to load +_normalizer = Normalizer() + + +class CanonicalLookup: + """Resolves a normalized keyword to its canonical form at query time. + + Uses a pre-built synonym table (dict lookup, O(1)) for known keywords. + For unknown keywords, falls back to embedding-based nearest-neighbor search + against canonical keyword embeddings, gated by a similarity threshold. + + Args: + synonym_table: Mapping of normalized keyword → canonical form (synonyms only, + no identity entries). + canonical_keywords: Ordered list of canonical forms (aligned with embeddings). + canonical_embeddings: Embedding matrix for canonical keywords (shape N × D). + embedding_model: Sentence-transformer model name (must match the model used + during offline canonicalization). + fallback_threshold: Minimum cosine similarity for the embedding fallback to + accept a canonical match (default 0.85). + """ + + def __init__( + self, + synonym_table: dict[str, str], + canonical_keywords: list[str], + canonical_embeddings: np.ndarray, + embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2", + fallback_threshold: float = 0.85, + ): + self.synonym_table = synonym_table + self.canonical_keywords = canonical_keywords + self.canonical_embeddings = canonical_embeddings + self.fallback_threshold = fallback_threshold + self._model_name = embedding_model + self._model = None # lazy-load + + def resolve(self, keyword: str) -> str: + """Return the canonical form for *keyword*. + + 1. Dictionary lookup in synonym_table. + 2. Embedding nearest-neighbour fallback (if threshold met). + 3. Return *keyword* unchanged if no mapping found. 
+ """ + if keyword in self.synonym_table: + return self.synonym_table[keyword] + + if self._model is None: + from sentence_transformers import SentenceTransformer + self._model = SentenceTransformer(self._model_name) + + emb = self._model.encode([keyword]) + sims = cos_sim(emb, self.canonical_embeddings)[0] + best_idx = int(np.argmax(sims)) + if sims[best_idx] >= self.fallback_threshold: + synonym = self.canonical_keywords[best_idx] + # print(f"Embedding fallback: '{keyword}' → '{synonym}' (sim={sims[best_idx]:.4f})") + return synonym + + return keyword + + +def _tokens_subsumed(short: str, long: str) -> bool: + """Return True if the tokens of *short* appear contiguously inside *long*.""" + ws, wl = short.split(), long.split() + n = len(ws) + return any(wl[i : i + n] == ws for i in range(len(wl) - n + 1)) + + +def extract_query_nodes( + query: str, + graph: nx.Graph, + canonical_lookup: CanonicalLookup | None = None, +) -> list[str]: + """Match query terms against graph node labels. + + Generates unigrams, bigrams, and trigrams from *query*, normalises them, + optionally maps each to its canonical form via *canonical_lookup*, and + returns any that are present as nodes in *graph*. Shorter nodes that are + token-level substrings of a longer matched node are dropped. + + Args: + query: Natural-language query string. + graph: The knowledge graph to match against. + canonical_lookup: Optional lookup object for mapping normalized keywords + to canonical forms. When provided, enables synonym-aware matching + and an embedding-based fallback for out-of-vocabulary terms. + + Returns: + List of matched node label strings (may be empty). 
+ """ + terms = extract_ngrams(query, KW_PATTERN) + normalized_terms = _normalizer.normalize(terms) + + if canonical_lookup is not None: + resolved = {canonical_lookup.resolve(t) for t in normalized_terms} + else: + resolved = set(normalized_terms) + + matched = [t for t in resolved if graph.has_node(t)] + filtered = [ + n for n in matched + if not any(n != m and _tokens_subsumed(n, m) for m in matched) + ] + return filtered + + +class KGNodeRetriever(Retriever): + """Knowledge-graph retriever that scores chunks via BFS node matching. + + Scores are derived purely from graph topology: direct query-node matches + score +1.0, and neighbors at hop *k* contribute + ``neighbor_weight**k * (edge_weight / max_edge_weight)``. + All scores are normalized to [0, 1]. + + Plugs into ``EnsembleRanker`` via the standard ``Retriever`` interface. + Combine with ``SectionTreeRetriever`` (and others) in the ensemble to + blend complementary signals. + """ + + name = "kg_node" + + def __init__( + self, + graph: nx.Graph, + kg_chunks: dict[int, str], + neighbor_weight: float = 0.5, + num_hops: int = 1, + canonical_lookup: CanonicalLookup | None = None, + ): + self.graph = graph + self.kg_chunks = kg_chunks + self.neighbor_weight = neighbor_weight + self.num_hops = num_hops + self.canonical_lookup = canonical_lookup + + def get_scores(self, query: str, pool_size: int, chunks: list) -> dict[int, float]: + """Return BFS-based relevance scores keyed by global chunk index. + + Args: + query: Natural-language query string. + pool_size: Maximum number of chunks to return scores for. + chunks: The RAG pipeline's chunk list (used only for length). + + Returns: + ``Dict[chunk_id, score]`` normalized to [0, 1]. + Returns an empty dict if no query nodes match the graph. 
+ """ + query_nodes = extract_query_nodes( + query, self.graph, self.canonical_lookup) + logger.debug("Query: %r", query) + logger.debug("Matched query nodes (%d): %s", + len(query_nodes), query_nodes) + if not query_nodes: + logger.debug("No query nodes matched — returning empty.") + return {} + + max_edge_weight = max( + (data["weight"] for _, _, data in self.graph.edges(data=True)), + default=1, + ) + max_edge_weight = max(max_edge_weight, 1) + logger.debug("Max edge weight in graph: %s", max_edge_weight) + + scores: dict[int, float] = {} + + # Hop 0: directly matched query nodes + for node in query_nodes: + for chunk_id in self.graph.nodes[node].get("chunk_ids", []): + scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 + + # BFS over hops 1..num_hops; each node is visited only at its closest hop + visited: set[str] = set(query_nodes) + frontier: set[str] = set(query_nodes) + + for hop in range(1, self.num_hops + 1): + decay = self.neighbor_weight ** hop + next_frontier: set[str] = set() + for node in frontier: + for neighbor in self.graph.neighbors(node): + if neighbor in visited: + continue + next_frontier.add(neighbor) + edge_weight = self.graph[node][neighbor].get("weight", 1) + contribution = decay * (edge_weight / max_edge_weight) + for chunk_id in self.graph.nodes[neighbor].get("chunk_ids", []): + scores[chunk_id] = scores.get( + chunk_id, 0.0) + contribution + visited |= next_frontier + frontier = next_frontier + logger.debug("Hop %d: %d new node(s) explored.", + hop, len(next_frontier)) + if not frontier: + break + + if not scores: + logger.debug("No chunks scored — returning empty.") + return {} + + max_score = max(scores.values()) + if max_score <= 0: + logger.debug("Max score is %s — returning empty.", max_score) + return {} + + normalized = {cid: s / max_score for cid, s in scores.items()} + logger.debug( + "Normalized scores: %s", + dict(sorted(normalized.items(), key=lambda x: x[1], reverse=True)), + ) + return normalized + + +class 
SectionTreeRetriever(Retriever): + """Retriever that scores chunks based on section-heading relevance. + + Uses ``SectionTree.get_chunk_scores`` which blends: + - Heading keyword overlap (structural signal). + - KG keyword overlap aggregated from the graph (lexical signal). + - Top-down score inheritance from parent sections to children. + + Plugs into ``EnsembleRanker`` via the standard ``Retriever`` interface. + Combine with ``KGNodeRetriever`` (and others) in the ensemble to blend + complementary signals. + """ + + name = "section_tree" + + def __init__( + self, + section_tree: SectionTree, + graph: nx.Graph, + canonical_lookup: CanonicalLookup | None = None, + heading_alpha: float = 0.5, + inheritance_decay: float = 0.5, + ): + self.section_tree = section_tree + self.graph = graph + self.canonical_lookup = canonical_lookup + self.heading_alpha = heading_alpha + self.inheritance_decay = inheritance_decay + + def get_scores(self, query: str, pool_size: int, chunks: list) -> dict[int, float]: + """Return section-relevance scores keyed by global chunk index. + + Args: + query: Natural-language query string. + pool_size: Maximum number of chunks to return scores for. + chunks: The RAG pipeline's chunk list (unused; present for interface compat). + + Returns: + ``Dict[chunk_id, score]`` normalized to [0, 1]. 
+ """ + query_keywords = set(extract_query_nodes( + query, self.graph, self.canonical_lookup)) + return self.section_tree.get_chunk_scores( + query_keywords, + query=query, + heading_alpha=self.heading_alpha, + inheritance_decay=self.inheritance_decay, + ) + + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Test the KG node retriever.") + parser.add_argument( + "output_dir", + nargs="?", + default=RUNS_DIR, + help="Run directory or runs/ parent (default: latest run).", + ) + parser.add_argument("--query", default="What is SQL?") + parser.add_argument("--top_k", type=int, default=10) + parser.add_argument("--neighbor_weight", type=float, default=0.5) + parser.add_argument("--num_hops", type=int, default=1) + args = parser.parse_args() + + _graph, _chunks = load_graph_and_chunks(args.output_dir) + _retriever = KGNodeRetriever( + _graph, _chunks, + neighbor_weight=args.neighbor_weight, + num_hops=args.num_hops, + ) + _scores = _retriever.get_scores( + args.query, args.top_k, list(_chunks.values())) + _results = sorted( + [(cid, _chunks[cid], score) + for cid, score in _scores.items() if cid in _chunks], + key=lambda x: x[2], reverse=True, + )[:args.top_k] + + print(f"\nTop {len(_results)} results for query: {args.query!r}\n") + for i, (chunk_id, chunk_text, score) in enumerate(_results, 1): + print(f"{i}. 
Chunk ID: {chunk_id}, Score: {score:.4f}") + print(f" Text: {chunk_text[:200]}...\n") diff --git a/src/knowledge_graph/scripts/analyze_query.py b/src/knowledge_graph/scripts/analyze_query.py new file mode 100644 index 00000000..457f31a9 --- /dev/null +++ b/src/knowledge_graph/scripts/analyze_query.py @@ -0,0 +1,34 @@ +import json +import argparse +import os +import logging + +from src.knowledge_graph.analysis import analyze_query +from src.knowledge_graph.io import RUNS_DIR, load_graph +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Analyze query difficulty against a Knowledge Graph." + ) + parser.add_argument( + "--graph", + default=os.path.join(RUNS_DIR, "latest", "graph.json"), + help="Path to the NetworkX JSON graph file (default: latest run).", + ) + parser.add_argument("--query", required=True, help="The query string to analyze.") + parser.add_argument("--debug", action="store_true", help="Print debug information during analysis.") + args = parser.parse_args() + + if args.debug: + logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(name)s %(levelname)s %(message)s") + + graph = load_graph(args.graph) + logger.debug(f"Loaded graph with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges.") + result = analyze_query(args.query, graph) + print(json.dumps(result.to_dict(), indent=2)) + + +if __name__ == "__main__": + main() diff --git a/src/knowledge_graph/scripts/benchmark_retrieval.py b/src/knowledge_graph/scripts/benchmark_retrieval.py new file mode 100644 index 00000000..df536230 --- /dev/null +++ b/src/knowledge_graph/scripts/benchmark_retrieval.py @@ -0,0 +1,369 @@ +import argparse +import json +import logging +import os + +import yaml +from dotenv import load_dotenv + +from src.knowledge_graph.build import RUNS_DIR +from src.knowledge_graph.io import ( + load_canonicalization_data, + load_graph_chunks_and_tree, + resolve_run_dir, +) +from 
src.knowledge_graph.openrouter_client import OpenRouterClient +from src.knowledge_graph.query import ( + CanonicalLookup, + KGNodeRetriever, + SectionTreeRetriever, +) +from src.knowledge_graph.prompts import GRADE_PROMPT +from src.retriever import BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts + +logger = logging.getLogger(__name__) + + +def _grade_with_llm( + client: OpenRouterClient, + model: str, + query: str, + retrieved: list[tuple[int, str, float]], +) -> list[dict]: + passages = "\n\n".join( + f"[{i + 1}] {text[:600].strip()}" + for i, (_, text, _) in enumerate(retrieved) + ) + prompt = GRADE_PROMPT.format(query=query, passages=passages) + raw = client.chat( + model, + [{"role": "user", "content": prompt}], + response_format={"type": "json_object"}, + ) + grades = json.loads(raw).get("grades", []) + + results = [] + for i, (chunk_id, _, _) in enumerate(retrieved): + grade = next((g for g in grades if g.get("id") == i + 1), {}) + results.append( + { + "chunk_id": chunk_id, + "score": int(grade["score"]) if "score" in grade else -1, + "reason": grade.get("reason", ""), + } + ) + return results + + +def _llm_metrics(grades: list[dict], top_k: int) -> dict: + scored = [g["score"] for g in grades if g["score"] >= 0] + if not scored: + return {} + relevant = sum(1 for s in scored if s >= 1) + return { + # Fraction of the top-k chunks judged relevant (score >= 1) by the LLM. + "precision_at_k": relevant / top_k, + # Average raw LLM relevance score across retrieved chunks (0=irrelevant, 1=partial, 2=relevant). 
+ "mean_relevance_score": sum(scored) / len(scored), + } + + +def run_benchmark( + run_dir: str, + queries: list[dict], + top_k: int = 5, + llm_client: OpenRouterClient | None = None, + llm_model: str = "openai/gpt-4o-mini", + num_hops: int = 1, + neighbor_weight: float = 0.5, + artifacts_dir: str | None = None, + index_prefix: str = "textbook_index", + embed_model: str = "", + extracted_index_path: str = "data/extracted_index.json", + page_to_chunk_map_path: str = "index/sections/textbook_index_page_to_chunk_map.json", +) -> list[dict]: + """Run retrieval benchmark for all queries across all available retrievers.""" + kg_graph, kg_chunks, tree = load_graph_chunks_and_tree(run_dir) + + resolved = resolve_run_dir(run_dir) + syn_table, can_kw, can_emb = load_canonicalization_data(resolved) + canonical_lookup = ( + CanonicalLookup(syn_table, can_kw, can_emb) if syn_table is not None else None + ) + + # Unified chunk lookup: RAG list takes precedence (dict-wrapped), KG dict as fallback. + chunks: dict[int, str] = kg_chunks + retrievers = [] + + if artifacts_dir: + try: + faiss_idx, bm25_idx, rag_chunks, _, _ = load_artifacts(artifacts_dir, index_prefix) + chunks = {i: t for i, t in enumerate(rag_chunks)} + + if embed_model: + retrievers.append(FAISSRetriever(faiss_idx, embed_model)) + logger.info("FAISSRetriever enabled.") + else: + logger.info("Skipping FAISSRetriever: --embed-model not provided.") + + retrievers.append(BM25Retriever(bm25_idx)) + logger.info("BM25Retriever enabled.") + + if os.path.exists(extracted_index_path) and os.path.exists(page_to_chunk_map_path): + retrievers.append(IndexKeywordRetriever(extracted_index_path, page_to_chunk_map_path)) + logger.info("IndexKeywordRetriever enabled.") + except (FileNotFoundError, RuntimeError) as e: + logger.warning("RAG artifacts not found, skipping FAISS/BM25: %s", e) + + retrievers.append( + KGNodeRetriever( + kg_graph, + kg_chunks, + neighbor_weight=neighbor_weight, + num_hops=num_hops, + 
canonical_lookup=canonical_lookup, + ) + ) + + if tree is not None: + retrievers.append(SectionTreeRetriever(tree, kg_graph, canonical_lookup=canonical_lookup)) + logger.info("SectionTreeRetriever enabled.") + else: + logger.info("No section tree found — SectionTreeRetriever skipped.") + + results = [] + for q in queries: + qid = q.get("id", "unknown") + query_text = q.get("question", q.get("query", "")) + + print(f"\n[{qid}] {query_text}") + + retriever_results: dict[str, dict] = {} + for retriever in retrievers: + scores = retriever.get_scores(query_text, top_k, list(chunks.values())) + retrieved = sorted( + [(cid, chunks[cid], score) for cid, score in scores.items() if cid in chunks], + key=lambda x: x[2], + reverse=True, + )[:top_k] + if not retrieved: + print(f" [{retriever.name}] WARNING: no chunks retrieved") + + llm_grades = None + llm_m = None + if llm_client and retrieved: + try: + llm_grades = _grade_with_llm(llm_client, llm_model, query_text, retrieved) + llm_m = _llm_metrics(llm_grades, top_k) + print( + f" [{retriever.name}] LLM " + f"P@{top_k}={llm_m.get('precision_at_k', 0):.2f} " + f"mean_score={llm_m.get('mean_relevance_score', 0):.2f}" + ) + except Exception as e: + logger.warning( + "LLM grading failed for %r / %r: %s", qid, retriever.name, e + ) + + retrieved_list = [] + for chunk_id, text, score in retrieved: + entry: dict = { + "chunk_id": chunk_id, + "score": round(score, 4), + "text_preview": text[:200], + } + if llm_grades: + grade = next((g for g in llm_grades if g["chunk_id"] == chunk_id), {}) + entry["llm_score"] = grade.get("score") + entry["llm_reason"] = grade.get("reason", "") + retrieved_list.append(entry) + + retriever_results[retriever.name] = { + "retrieved": retrieved_list, + "llm_metrics": llm_m, + } + + results.append( + { + "id": qid, + "query": query_text, + "retrievers": retriever_results, + } + ) + + return results + + +def _avg(values: list[float]) -> float | None: + clean = [v for v in values if v is not None] + return 
sum(clean) / len(clean) if clean else None + + +def print_summary(results: list[dict], top_k: int) -> None: + retriever_names: list[str] = [] + for r in results: + for name in r.get("retrievers", {}): + if name not in retriever_names: + retriever_names.append(name) + + col_id = 30 + col_w = 16 # width per retriever: "P@k Mean" each 7 chars + spacing + + # Header row: Query ID + two sub-columns (P@k, Mean) per retriever + header1 = f"{'Query ID':<{col_id}}" + header2 = " " * col_id + for name in retriever_names: + short = name[:col_w].center(col_w) + header1 += f" {short}" + sub = f"{'P@'+str(top_k):>6} {'Mean':>6}" + header2 += f" {sub}" + + sep = "-" * len(header1) + print(f"\n{sep}") + print(header1) + print(header2) + print(sep) + + # Accumulators for averages + agg: dict[str, dict[str, list[float]]] = {n: {"p": [], "m": []} for n in retriever_names} + + for r in results: + row = f"{r['id'][:col_id]:<{col_id}}" + for name in retriever_names: + lm = r.get("retrievers", {}).get(name, {}).get("llm_metrics") or {} + if lm: + p = lm.get("precision_at_k", 0.0) + m = lm.get("mean_relevance_score", 0.0) + agg[name]["p"].append(p) + agg[name]["m"].append(m) + row += f" {p:>6.2f} {m:>6.2f}" + else: + row += f" {'—':>6} {'—':>6}" + print(row) + + print(sep) + avg_row = f"{'AVERAGE':<{col_id}}" + for name in retriever_names: + a_p = _avg(agg[name]["p"]) + a_m = _avg(agg[name]["m"]) + p_str = f"{a_p:.2f}" if a_p is not None else "—" + m_str = f"{a_m:.2f}" if a_m is not None else "—" + avg_row += f" {p_str:>6} {m_str:>6}" + print(avg_row) + print(sep) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark all retrievers against a query set.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--run-dir", + default=RUNS_DIR, + help="KG run directory or runs/ parent with 'latest' symlink", + ) + parser.add_argument( + "--queries", + default="tests/benchmarks.yaml", + help="YAML file with query list", + ) + 
parser.add_argument("--top-k", type=int, default=5) + parser.add_argument( + "--model", + default="openai/gpt-4o-mini", + help="OpenRouter model for LLM grading", + ) + parser.add_argument( + "--api-key", + default=None, + help="OpenRouter API key (falls back to OPENROUTER_API_KEY env var)", + ) + parser.add_argument( + "--output", + default=None, + help="Write full results to this JSON file", + ) + parser.add_argument( + "--no-llm", + action="store_true", + help="Skip LLM relevance grading", + ) + parser.add_argument("--num-hops", type=int, default=1) + parser.add_argument("--neighbor-weight", type=float, default=0.5) + parser.add_argument( + "--artifacts-dir", + default=None, + help="RAG artifacts directory for FAISS/BM25 retrievers (e.g. index/recursive_sections/)", + ) + parser.add_argument( + "--index-prefix", + default="textbook_index", + help="Index artifact prefix used when building the RAG index", + ) + parser.add_argument( + "--embed-model", + default="", + help="Embedding model path for FAISSRetriever (GGUF or HuggingFace name)", + ) + parser.add_argument( + "--extracted-index", + default="data/extracted_index.json", + help="Path to extracted_index.json for IndexKeywordRetriever", + ) + parser.add_argument( + "--page-chunk-map", + default="index/sections/textbook_index_page_to_chunk_map.json", + help="Path to page_to_chunk_map.json for IndexKeywordRetriever", + ) + parser.add_argument("-v", "--verbose", action="store_true") + args = parser.parse_args() + + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.WARNING, + format="%(levelname)s %(name)s: %(message)s", + ) + + with open(args.queries) as f: + data = yaml.safe_load(f) + queries = data.get("benchmarks", data.get("queries", [])) + print(f"Loaded {len(queries)} queries from {args.queries}") + + llm_client = None + if not args.no_llm: + api_key = args.api_key or os.environ.get("OPENROUTER_API_KEY") + if api_key: + llm_client = OpenRouterClient(api_key, retries=2) + print(f"LLM 
grading enabled: {args.model}") + else: + print( + "No API key found — running without LLM grading. " + "Pass --api-key or set OPENROUTER_API_KEY." + ) + + results = run_benchmark( + run_dir=args.run_dir, + queries=queries, + top_k=args.top_k, + llm_client=llm_client, + llm_model=args.model, + num_hops=args.num_hops, + neighbor_weight=args.neighbor_weight, + artifacts_dir=args.artifacts_dir, + index_prefix=args.index_prefix, + embed_model=args.embed_model, + extracted_index_path=args.extracted_index, + page_to_chunk_map_path=args.page_chunk_map, + ) + + print_summary(results, args.top_k) + + if args.output: + with open(args.output, "w") as f: + json.dump(results, f, indent=2, default=lambda o: int(o) if hasattr(o, "__index__") else str(o)) + print(f"\nFull results written to {args.output}") + + +if __name__ == "__main__": + load_dotenv() + main() diff --git a/src/knowledge_graph/scripts/generate_canon_cache.py b/src/knowledge_graph/scripts/generate_canon_cache.py new file mode 100644 index 00000000..660587ef --- /dev/null +++ b/src/knowledge_graph/scripts/generate_canon_cache.py @@ -0,0 +1,93 @@ +import argparse +import json +import logging +import os + +from dotenv import load_dotenv + +from src.knowledge_graph.build import ( + CHUNKS_PKL, + JSON_KW_PATH, + META_PKL, + PROJECT_ROOT, + load_chunks, +) +from src.knowledge_graph.canonicalizer import Canonicalizer +from src.knowledge_graph.extractors import JsonExtractor +from src.knowledge_graph.scripts.run_kg_pipeline import KGPipelineConfig + +logger = logging.getLogger(__name__) + +DEFAULT_CACHE_PATH = os.path.join(PROJECT_ROOT, "debug", "canonicalization_cache.json") + + +def main() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s %(levelname)s %(message)s", + ) + + parser = argparse.ArgumentParser( + description="Run LLM canonicalization once and save the full result to a cache file." 
+ ) + parser.add_argument( + "--config", + default=os.path.join(PROJECT_ROOT, "config", "config.yaml"), + help="Path to project config YAML (default: config/config.yaml)", + ) + parser.add_argument( + "--output", + default=DEFAULT_CACHE_PATH, + help=f"Path to write the cache JSON (default: {DEFAULT_CACHE_PATH})", + ) + args = parser.parse_args() + + cfg = KGPipelineConfig.from_yaml(args.config) + logger.info("Loaded config from %s", args.config) + + api_key = os.environ.get("OPENROUTER_API_KEY", "") + if not api_key: + raise EnvironmentError("OPENROUTER_API_KEY environment variable must be set.") + + logger.info("Loading chunks from:\n %s\n %s", CHUNKS_PKL, META_PKL) + chunks = load_chunks(CHUNKS_PKL, META_PKL) + logger.info("Loaded %d chunks", len(chunks)) + + extractor = JsonExtractor(input_path=JSON_KW_PATH) + extractions = extractor.extract(chunks) + logger.info("Extracted %d results", len(extractions)) + + c = cfg.canonicalization + canonicalizer = Canonicalizer( + embedding_model=c.embed_model, + corpus_description=cfg.corpus_description, + api_key=api_key, + llm_model=c.llm_model, + similarity_threshold=c.similarity_threshold, + max_group_size=c.max_group_size, + batch_size=c.batch_size, + ) + + updated_extractions, canon_result = canonicalizer.canonicalize(extractions) + + cache = { + "updated_extractions": [ + {"chunk_id": e.chunk_id, "keywords": e.keywords} + for e in updated_extractions + ], + "synonym_table": canon_result.synonym_table, + "canonical_keywords": canon_result.canonical_keywords, + "canonical_embeddings": canon_result.canonical_embeddings.tolist(), + "stats": canon_result.stats, + } + + os.makedirs(os.path.dirname(args.output), exist_ok=True) + with open(args.output, "w", encoding="utf-8") as f: + json.dump(cache, f, indent=2, ensure_ascii=False) + + logger.info("Saved cache to %s", args.output) + + +if __name__ == "__main__": + load_dotenv() + main() diff --git a/src/knowledge_graph/llm_extract_keywords.py 
b/src/knowledge_graph/scripts/llm_extract_keywords.py similarity index 100% rename from src/knowledge_graph/llm_extract_keywords.py rename to src/knowledge_graph/scripts/llm_extract_keywords.py diff --git a/src/knowledge_graph/run_kg_pipeline.py b/src/knowledge_graph/scripts/run_kg_pipeline.py similarity index 76% rename from src/knowledge_graph/run_kg_pipeline.py rename to src/knowledge_graph/scripts/run_kg_pipeline.py index 92551487..35006106 100644 --- a/src/knowledge_graph/run_kg_pipeline.py +++ b/src/knowledge_graph/scripts/run_kg_pipeline.py @@ -1,3 +1,5 @@ +from src.knowledge_graph.section_tree import build_section_tree, save_section_tree +from src.knowledge_graph.canonicalizer import Canonicalizer import argparse import logging import os @@ -12,6 +14,9 @@ setup_input_dir, write_config, update_latest_symlink, + load_chunks, + META_PKL, + CHUNKS_PKL, ) from src.knowledge_graph.linkers import CooccurrenceLinker @@ -138,15 +143,47 @@ def main() -> None: chapter_filter = f"Chapter {args.chapter} " if args.chapter else None exclude_chapters = [f"Chapter {c} " for c in args.exclude_chapters] + c = cfg.canonicalization + + canonicalizer = Canonicalizer( + embedding_model=c.embed_model, + corpus_description=cfg.corpus_description, + api_key=args.api_key or os.environ.get("OPENROUTER_API_KEY", ""), + llm_model=c.llm_model, + similarity_threshold=c.similarity_threshold, + max_group_size=c.max_group_size, + batch_size=c.batch_size, + ) + linker = CooccurrenceLinker(min_cooccurrence=cfg.min_cooccurrence) - build_kg( + + chunks = load_chunks( + CHUNKS_PKL, + META_PKL, + chapter_filter=chapter_filter, + exclude_chapters=exclude_chapters, + ) + logger.info("Loaded %d chunks", len(chunks)) + graph = build_kg( output_dir=run_dir, + chunks=chunks, extractor=extractor, linker=linker, - chapter_filter=chapter_filter, - exclude_chapters=exclude_chapters or None, + canonicalizer=canonicalizer, ) + logger.info("Building section tree...") + tree = build_section_tree(chunks, graph) 
+ tree_path = save_section_tree(tree, run_dir) + level_counts: dict[int, int] = {} + for node in tree.node_index.values(): + level_counts[node.level] = level_counts.get(node.level, 0) + 1 + level_labels = {1: "chapters", 2: "sections", 3: "subsections"} + for level, count in sorted(level_counts.items()): + label = level_labels.get(level, f"level-{level} nodes") + logger.info(" %4d %s", count, label) + logger.info(" Saved: %s", tree_path) + update_latest_symlink(run_dir) logger.info("Updated: %s -> %s", os.path.join(RUNS_DIR, "latest"), run_dir) diff --git a/src/knowledge_graph/section_tree.py b/src/knowledge_graph/section_tree.py new file mode 100644 index 00000000..24a54e43 --- /dev/null +++ b/src/knowledge_graph/section_tree.py @@ -0,0 +1,398 @@ +from __future__ import annotations + +import json +import os +import re +from dataclasses import dataclass, field +from typing import Optional + +import networkx as nx + +from src.knowledge_graph.models import Chunk +from src.knowledge_graph.normalizer import Normalizer +from src.knowledge_graph.ngrams import extract_ngrams, HEADING_PATTERN, KW_PATTERN + +_NUMBER_RE = re.compile(r"(\d+(?:\.\d+)*)") + +# Tokens to strip from heading text before building heading_keywords +_HEADING_PREFIX_RE = re.compile(r"\b(section|chapter)\b", re.IGNORECASE) + + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _extract_section_number(heading: str) -> str | None: + """Return the section number from a heading like 'Section 13.1 ...'.""" + m = _NUMBER_RE.search(heading) + return m.group(1) if m else None + + +def _parent_number(number: str) -> str | None: + """Return the parent section number, or None for a top-level number.""" + parts = number.split(".") + return ".".join(parts[:-1]) if len(parts) > 1 else None + + +def _build_heading_keywords(heading: str) -> set[str]: + """Tokenize a section heading into a normalized keyword set. 
@dataclass
class SectionNode:
    """One heading of the textbook hierarchy plus the chunks/keywords under it."""

    heading: str  # e.g. "Section 13.1 Physical Storage Media"
    level: int  # 1 = chapter, 2 = section, 3 = subsection
    chapter: int  # e.g. 13
    section_number: str  # e.g. "13.1"
    chunk_ids: list[int] = field(default_factory=list)
    keyword_set: set[str] = field(default_factory=set)
    children: list["SectionNode"] = field(default_factory=list)
    # Back-reference to the parent. Kept out of repr/eq: a child's repr would
    # otherwise print its parent, which prints its children, and so on.
    parent: Optional["SectionNode"] = field(default=None, repr=False, compare=False)
    heading_keywords: set[str] = field(default_factory=set, repr=False, compare=False)


class SectionTree:
    """Tree mirroring the textbook's heading hierarchy with aggregated KG keywords."""

    def __init__(self, root: SectionNode) -> None:
        self.root = root
        self.node_index: dict[str, SectionNode] = {}  # heading -> node
        self._number_index: dict[str, SectionNode] = {}  # section_number -> node
        self.chunk_to_sections: dict[int, list[SectionNode]] = {}  # chunk_id -> nodes

    # ── Index helpers ─────────────────────────────────────────────────────────

    def _register(self, node: SectionNode) -> None:
        """Index *node* by its heading and by its section number."""
        self.node_index[node.heading] = node
        self._number_index[node.section_number] = node

    def get_nodes_at_level(self, level: int) -> list[SectionNode]:
        """Return every registered node whose ``level`` equals *level*."""
        return [n for n in self.node_index.values() if n.level == level]

    # ── Query-time scoring ────────────────────────────────────────────────────

    @staticmethod
    def _overlap_score(
        candidates: set[str],
        query_terms: set[str],
        alpha: float,
    ) -> float:
        """Return ``alpha * coverage + (1 - alpha) * specificity``.

        Coverage is the fraction of *query_terms* found in *candidates*;
        specificity is the fraction of *candidates* that are query terms.
        Returns 0.0 when either set is empty or they do not intersect.
        """
        if not candidates or not query_terms:
            return 0.0
        matched = query_terms & candidates
        if not matched:
            return 0.0
        coverage = len(matched) / len(query_terms)
        specificity = len(matched) / len(candidates)
        return alpha * coverage + (1 - alpha) * specificity

    def _score_section_kg(
        self,
        node: SectionNode,
        query_keywords: set[str],
        alpha: float = 0.6,
    ) -> float:
        """KG keyword overlap score: coverage × alpha + specificity × (1 - alpha).

        Coverage: fraction of query keywords present in the section.
        Specificity: fraction of the section's keywords that are query keywords.
        """
        return self._overlap_score(node.keyword_set, query_keywords, alpha)

    def _score_section_heading(
        self,
        node: SectionNode,
        query_tokens: set[str],
        alpha: float = 0.6,
    ) -> float:
        """Heading keyword overlap score, on the same scale as ``_score_section_kg``.

        Matches independently-tokenized query tokens against the node's
        pre-built ``heading_keywords`` set.
        """
        return self._overlap_score(node.heading_keywords, query_tokens, alpha)

    def get_all_descendant_chunk_ids(self, node: SectionNode) -> list[int]:
        """Return *node*'s chunk ids followed by those of all descendants (pre-order)."""
        ids: list[int] = list(node.chunk_ids)
        for child in node.children:
            ids.extend(self.get_all_descendant_chunk_ids(child))
        return ids

    def get_chunk_scores(
        self,
        query_keywords: set[str],
        query: Optional[str] = None,
        heading_alpha: float = 0.5,
        inheritance_decay: float = 0.5,
        alpha: float = 0.6,
    ) -> dict[int, float]:
        """Return chunk_id → normalized section-relevance score.

        Hybrid scoring blends two independent signals per section node:

        - **Heading keyword match** (structural): overlap between
          independently-tokenized query tokens and the pre-built heading
          keyword set.
        - **KG keyword overlap** (lexical): coverage × alpha + specificity ×
          (1 - alpha) using the node's aggregated KG keyword set.

        ``heading_alpha`` controls the blend (1.0 = heading-only, 0.0 =
        KG-only). A node falls back to KG-only when ``query`` is None, the
        query yields no tokens, or the node has no heading keywords.

        **Top-down inheritance** propagates a parent's score to its children:

            effective(node) = own_score(node) + inheritance_decay × effective(parent)

        A chunk receives the highest effective score among the registered
        nodes that list it in ``chunk_ids``. Nodes whose effective score is
        not positive contribute nothing.

        Args:
            query_keywords: Keywords matched against each node's ``keyword_set``.
            query: Raw query text, tokenized with the module-level
                ``_tokenize_query`` for heading matching; ``None`` disables
                the heading signal.
            heading_alpha: Weight of the heading score in the blend.
            inheritance_decay: Multiplier applied to the parent's effective
                score before adding it to the child's own score.
            alpha: Coverage weight inside both overlap scores.

        Returns:
            Mapping of chunk id to a score divided by the maximum score, or
            an empty dict when the tree has no registered nodes or nothing
            scored above zero.
        """
        if not self.node_index:
            return {}

        # Tokenize the raw query independently for heading matching.
        # (No Normalizer instance is needed here; the tokenizer helper owns
        # normalization.)
        query_tokens: set[str] = set()
        if query is not None:
            query_tokens = _tokenize_query(query)

        # ── Step 1: own score for every registered node ───────────────────────
        own_scores: dict[str, float] = {}
        for heading, node in self.node_index.items():
            kg_score = self._score_section_kg(node, query_keywords, alpha)
            if query_tokens and node.heading_keywords:
                heading_score = self._score_section_heading(node, query_tokens, alpha)
                own_scores[heading] = (
                    heading_alpha * heading_score + (1 - heading_alpha) * kg_score
                )
            else:
                own_scores[heading] = kg_score

        # ── Step 2: top-down DFS — effective = own + decay × parent_effective ─
        effective: dict[str, float] = {}

        def _propagate(node: SectionNode, parent_eff: float) -> None:
            eff = own_scores.get(node.heading, 0.0) + inheritance_decay * parent_eff
            effective[node.heading] = eff
            for child in node.children:
                _propagate(child, eff)

        # The root itself is never scored; its children start with no inheritance.
        for top_level in self.root.children:
            _propagate(top_level, 0.0)

        # ── Step 3: chunk score = best effective score of a node listing it ───
        chunk_scores: dict[int, float] = {}
        for heading, node in self.node_index.items():
            eff = effective.get(heading, 0.0)
            if eff <= 0.0:
                continue
            for chunk_id in node.chunk_ids:
                chunk_scores[chunk_id] = max(chunk_scores.get(chunk_id, 0.0), eff)

        if not chunk_scores:
            return {}

        max_score = max(chunk_scores.values())
        if max_score > 0:
            chunk_scores = {cid: s / max_score for cid, s in chunk_scores.items()}
        return chunk_scores

    # ── Serialization ─────────────────────────────────────────────────────────

    def to_dict(self) -> dict:
        """Serialize the tree (starting at the root) into nested plain dicts.

        Keyword sets are emitted as sorted lists. ``chunk_ids`` is copied so
        that mutating the returned structure cannot alter the live tree.
        """

        def node_to_dict(n: SectionNode) -> dict:
            return {
                "heading": n.heading,
                "level": n.level,
                "chapter": n.chapter,
                "section_number": n.section_number,
                "chunk_ids": list(n.chunk_ids),
                "keyword_set": sorted(n.keyword_set),
                "heading_keywords": sorted(n.heading_keywords),
                "children": [node_to_dict(c) for c in n.children],
            }

        return node_to_dict(self.root)

    @classmethod
    def from_dict(cls, data: dict) -> "SectionTree":
        """Rebuild a tree from the structure produced by :meth:`to_dict`.

        ``heading_keywords`` and ``children`` are optional in *data*; the
        other keys are required (``KeyError`` if missing). Lists are copied
        so the new tree does not share state with *data*.
        """

        def dict_to_node(d: dict, parent: Optional[SectionNode]) -> SectionNode:
            node = SectionNode(
                heading=d["heading"],
                level=d["level"],
                chapter=d["chapter"],
                section_number=d["section_number"],
                chunk_ids=list(d["chunk_ids"]),
                keyword_set=set(d["keyword_set"]),
                heading_keywords=set(d.get("heading_keywords", [])),
                parent=parent,
            )
            node.children = [dict_to_node(c, node) for c in d.get("children", [])]
            return node

        root = dict_to_node(data, None)
        tree = cls(root)
        tree._rebuild_indexes(root)
        return tree

    def _rebuild_indexes(self, node: SectionNode) -> None:
        """Recursively re-populate the heading/number/chunk indexes from *node* down."""
        # The synthetic root is identified by its heading and is not indexed.
        if node.heading != "root":
            self._register(node)
        for chunk_id in node.chunk_ids:
            self.chunk_to_sections.setdefault(chunk_id, []).append(node)
        for child in node.children:
            self._rebuild_indexes(child)
``level`` and + ``chapter`` are derived from the section number via regex. + graph: NetworkX graph from the KG pipeline; each node has a + ``chunk_ids`` attribute listing which chunks contain it. + + Returns: + A fully populated ``SectionTree``. + """ + root = SectionNode(heading="root", level=0, chapter=0, section_number="") + tree = SectionTree(root) + + # ── Step 1: Collect unique sections ────────────────────────────────────── + seen: dict[str, SectionNode] = {} # section_number → SectionNode + for chunk in chunks: + meta = chunk.metadata + heading = meta.get("section", "") + if not heading: + continue + section_number = _extract_section_number(heading) + if section_number is None: + continue + if section_number not in seen: + level = section_number.count(".") + 1 + chapter = int(section_number.split(".")[0]) + seen[section_number] = SectionNode( + heading=heading, + level=level, + chapter=chapter, + section_number=section_number, + ) + + # ── Step 2: Build tree structure (shortest numbers first = parents first) ─ + for section_number, node in sorted( + seen.items(), key=lambda x: (x[0].count("."), x[0]) + ): + parent_num = _parent_number(section_number) + parent_node = seen.get(parent_num, root) if parent_num else root + node.parent = parent_node + parent_node.children.append(node) + tree._register(node) + + # ── Step 3: Assign chunk_ids to leaf nodes ──────────────────────────────── + for chunk in chunks: + meta = chunk.metadata + section_number = _extract_section_number(meta.get("section", "")) + if not section_number or section_number not in seen: + continue + leaf = seen[section_number] + if chunk.id not in leaf.chunk_ids: + leaf.chunk_ids.append(chunk.id) + tree.chunk_to_sections.setdefault(chunk.id, []).append(leaf) + + # ── Step 4: Populate keyword sets from KG graph ─────────────────────────── + for kg_node_name, kg_node_data in graph.nodes(data=True): + for chunk_id in kg_node_data.get("chunk_ids", []): + for leaf in 
def save_section_tree(tree: "SectionTree", run_dir: str) -> str:
    """Write *tree* as ``section_tree.json`` under *run_dir*.

    The directory is created if it does not exist. The JSON is produced from
    ``tree.to_dict()``, indented by two spaces, with non-ASCII characters
    written as-is in UTF-8.

    Returns:
        The full path of the written file.
    """
    os.makedirs(run_dir, exist_ok=True)
    target = os.path.join(run_dir, "section_tree.json")
    with open(target, "w", encoding="utf-8") as handle:
        json.dump(tree.to_dict(), handle, indent=2, ensure_ascii=False)
    return target
+ """ + path = os.path.join(run_dir, "section_tree.json") + if not os.path.isfile(path): + raise FileNotFoundError(f"No section_tree.json found in {run_dir!r}") + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + return SectionTree.from_dict(data) diff --git a/src/main.py b/src/main.py index 5846a5e5..0f5e1c4a 100644 --- a/src/main.py +++ b/src/main.py @@ -27,6 +27,8 @@ get_page_numbers, load_artifacts ) +from src.knowledge_graph.io import load_graph_chunks_and_tree +from src.knowledge_graph.query import KGRetriever from src.ranking.reranker import rerank from src.cache import get_cache @@ -384,7 +386,16 @@ def run_chat_session(args: argparse.Namespace, cfg: RAGConfig): retrievers = [FAISSRetriever(faiss_idx, cfg.embed_model), BM25Retriever(bm25_idx)] if cfg.ranker_weights.get("index_keywords", 0) > 0: retrievers.append(IndexKeywordRetriever(cfg.extracted_index_path, cfg.page_to_chunk_map_path)) - + # Add knowledge graph retriever if weight > 0 and graph dir is configured + if cfg.ranker_weights.get("kg", 0) > 0 and cfg.kg_graph_dir: + kg_graph, kg_chunks, kg_tree = load_graph_chunks_and_tree(cfg.kg_graph_dir) + retrievers.append(KGRetriever( + kg_graph, kg_chunks, + section_tree=kg_tree, + beta=cfg.kg_beta, + heading_alpha=cfg.kg_heading_alpha, + inheritance_decay=cfg.kg_inheritance_decay, + )) ranker = EnsembleRanker(ensemble_method=cfg.ensemble_method, weights=cfg.ranker_weights, rrf_k=int(cfg.rrf_k)) print("Loaded retrievers and initialized ranker.") artifacts = {"chunks": chunks, "sources": sources, "retrievers": retrievers, "ranker": ranker, "meta": meta}