diff --git a/config/category_prototypes.yaml b/config/category_prototypes.yaml new file mode 100644 index 00000000..a2e36531 --- /dev/null +++ b/config/category_prototypes.yaml @@ -0,0 +1,55 @@ +# Category prototypes for PrototypeClassifier +# --------------------------------------------- +# Each list is averaged into a single prototype embedding at planner-build +# time. Exemplars are sourced from eval/questions.jsonl so the prototype +# space matches the distribution we evaluate against. +# +# Add 3-5 exemplars per category. More exemplars = smoother prototype but +# diminishing returns above ~5. If a category looks under-represented in +# practice (e.g. low confidence on legitimate queries), add 1-2 more. +# +# Acronym hygiene: keyword exemplars MUST contain an all-caps 2-4 char token +# (matches _ACRONYM_PATTERN in heuristics.py). Other categories MUST NOT — +# bleed will collapse the prototype space. + +keyword: + - "What is ACID?" + - "Define MVCC" + - "Explain WAL-based recovery" + - "What does OLTP mean" + - "What is CRUD in databases" + +comparison: + - "Compare clustered and non-clustered indexes" + - "Difference between B+ tree and hash index" + - "Compare second normal form and third normal form" + - "Difference between optimistic and pessimistic concurrency control" + - "Compare row-store and column-store databases" + +definition: + - "What is a candidate key" + - "Define functional dependency" + - "What is a schedule in concurrency control" + - "Define referential integrity" + - "What is a foreign key" + +explanatory: + - "Why does normalization reduce redundancy" + - "How does write-ahead logging ensure durability" + - "Explain how deadlocks arise in transactions" + - "Why do databases use indexes despite write overhead" + - "How does two-phase commit work" + +procedural: + - "How to build a B+ tree index from scratch" + - "How to perform database normalization" + - "Steps to recover from a transaction failure" + - "How to design a relational schema for a library system" + - "Procedure for query optimization in a relational engine" + +other: + - "List the disadvantages of using a database" + - "What are the components of a DBMS" + - "Lock escalation behavior" + - "Buffer management considerations" + - "Index selection during query planning" diff --git a/eval/breakdown.py b/eval/breakdown.py new file mode 100644 index 00000000..fdfb25d1 --- /dev/null +++ b/eval/breakdown.py @@ -0,0 +1,35 @@ +import csv +import collections +import sys + +path = sys.argv[1] if len(sys.argv) > 1 else "eval/results_v2_n116.csv" +rows = list(csv.DictReader(open(path))) + +def rate(rs, c): + return sum(int(r[c]) for r in rs) / len(rs) * 100 if rs else 0.0 + +cols = [ + "baseline_retrieval_hit", + "optimizer_retrieval_hit", + "baseline_answer_hit", + "optimizer_answer_hit", +] + +print(f"=== N={len(rows)} ({path}) ===") +for c in cols: + print(f" {c:32s} {rate(rows, c):6.2f}%") +print(f" delta retrieval: {rate(rows,'optimizer_retrieval_hit')-rate(rows,'baseline_retrieval_hit'):+.2f}") +print(f" delta answer: {rate(rows,'optimizer_answer_hit')-rate(rows,'baseline_answer_hit'):+.2f}") +print() + +by_cat = collections.defaultdict(list) +for r in rows: + by_cat[r["category"]].append(r) + +print(f" {'cat':12s} {'n':>3} {'b_ret':>7} {'o_ret':>7} {'b_ans':>7} {'o_ans':>7} {'dret':>6} {'dans':>6}") +for cat, rs in sorted(by_cat.items()): + br = rate(rs, "baseline_retrieval_hit") + orr = rate(rs, "optimizer_retrieval_hit") + ba = rate(rs, "baseline_answer_hit") + oa = rate(rs, "optimizer_answer_hit") + print(f" {cat:12s} {len(rs):>3} {br:6.1f}% {orr:6.1f}% {ba:6.1f}% {oa:6.1f}% {orr-br:+5.1f} {oa-ba:+5.1f}") diff --git a/eval/compare_runs.py b/eval/compare_runs.py new file mode 100644 index 00000000..998ae567 --- /dev/null +++ b/eval/compare_runs.py @@ -0,0 +1,78 @@ +""" +Compare two eval CSVs side-by-side per category. + +Helps spot which per-category numbers are stable across runs vs. which are +swinging due to LLM generation noise. The routing table baked into the cost +model was derived from a specific run; if a category's baseline/optimizer +gap flips sign in a later run, the routing decision for that category is +noise, not signal. + +Usage: + python eval/compare_runs.py eval/results_v2_n116.csv eval/results_v3_n116_with_cost_model.csv +""" +import csv +import collections +import sys + + +def load(path): + with open(path) as f: + return list(csv.DictReader(f)) + + +def rate(rs, c): + return sum(int(r[c] or 0) for r in rs) / len(rs) * 100 if rs else 0.0 + + +def by_cat(rows): + d = collections.defaultdict(list) + for r in rows: + d[r["category"]].append(r) + return d + + +def main(): + if len(sys.argv) < 3: + print("usage: compare_runs.py RUN_A.csv RUN_B.csv") + sys.exit(2) + path_a, path_b = sys.argv[1], sys.argv[2] + rows_a = load(path_a) + rows_b = load(path_b) + cats_a, cats_b = by_cat(rows_a), by_cat(rows_b) + cats = sorted(set(cats_a) | set(cats_b)) + + label_a = path_a.split("/")[-1] + label_b = path_b.split("/")[-1] + + print(f"A = {label_a} (N={len(rows_a)})") + print(f"B = {label_b} (N={len(rows_b)})") + print() + print(f" {'cat':12s} {'n':>3} {'A b_ans':>8} {'B b_ans':>8} {'A o_ans':>8} {'B o_ans':>8} {'A gap':>7} {'B gap':>7}") + print(f" {'-'*12} {'-'*3} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*7} {'-'*7}") + for cat in cats: + ra = cats_a.get(cat, []) + rb = cats_b.get(cat, []) + n = max(len(ra), len(rb)) + a_b = rate(ra, "baseline_answer_hit") if ra else 0.0 + b_b = rate(rb, "baseline_answer_hit") if rb else 0.0 + a_o = rate(ra, "optimizer_answer_hit") if ra else 0.0 + b_o = rate(rb, "optimizer_answer_hit") if rb else 0.0 + a_gap = a_o - a_b + b_gap = b_o - b_b + print(f" {cat:12s} {n:>3} {a_b:7.1f}% {b_b:7.1f}% {a_o:7.1f}% {b_o:7.1f}% {a_gap:+6.1f} {b_gap:+6.1f}") + + print() + print(" Legend: gap = optimizer − baseline. If a category's gap sign flips between A and B,") + print(" the routing decision for that category is within LLM-noise range, not robust signal.") + print() + if "cost_model_answer_hit" in (rows_b[0] if rows_b else {}): + cm = rate(rows_b, "cost_model_answer_hit") + bb = rate(rows_b, "baseline_answer_hit") + bo = rate(rows_b, "optimizer_answer_hit") + print(f" B has cost_model column: cm_ans={cm:.2f}% (vs B baseline {bb:.2f}%, B optimizer {bo:.2f}%)") + print(f" Δ cost-model − baseline: {cm-bb:+.2f}") + print(f" Δ cost-model − optimizer: {cm-bo:+.2f}") + + +if __name__ == "__main__": + main() diff --git a/eval/cost_model_sanity.py b/eval/cost_model_sanity.py new file mode 100644 index 00000000..58611672 --- /dev/null +++ b/eval/cost_model_sanity.py @@ -0,0 +1,119 @@ +""" +Plan C: project what a cost-model planner *would* score, by post-processing +an existing baseline+optimizer eval CSV. No new eval run required. + +For each row, pick baseline_* or optimizer_* hits based on a routing table +keyed on planner_classification (the heuristic's predicted category). Report +the projected hit rate alongside baseline-only and optimizer-only rates. + +We also report an "oracle" projection routed by the gold `category` column, +to bound how much of any gap is misclassification vs. table miscalibration. +""" +import csv +import collections +import sys + +# Empirical routing table from N=116 run (eval/results_v2_n116.csv): +# keyword +10.0 → composite wins → optimizer +# definition +10.0 → optimizer +# procedural +5.0 → optimizer +# other +12.5 → optimizer +# comparison -4.2 → baseline +# explanatory -4.2 → baseline +ROUTING_TABLE = { + "keyword": "optimizer", + "definition": "optimizer", + "procedural": "optimizer", + "other": "optimizer", + "comparison": "baseline", + "explanatory": "baseline", +} +DEFAULT_ROUTE = "optimizer" + + +def pick(row, mode_col_prefix, route): + col = f"{route}_{mode_col_prefix}" + return int(row[col]) + + +def project(rows, route_key): + """Return per-row picks (retrieval_hit, answer_hit) given a routing key.""" + out = [] + for r in rows: + cat = r[route_key] + choice = ROUTING_TABLE.get(cat, DEFAULT_ROUTE) + out.append({ + **r, + "cost_model_retrieval_hit": pick(r, "retrieval_hit", choice), + "cost_model_answer_hit": pick(r, "answer_hit", choice), + "cost_model_choice": choice, + }) + return out + + +def rate(rs, c): + return sum(int(r[c]) for r in rs) / len(rs) * 100 if rs else 0.0 + + +def report(label, rows): + print(f"=== {label} (N={len(rows)}) ===") + print(f" baseline_retrieval_hit {rate(rows, 'baseline_retrieval_hit'):6.2f}%") + print(f" optimizer_retrieval_hit {rate(rows, 'optimizer_retrieval_hit'):6.2f}%") + print(f" cost_model_retrieval_hit {rate(rows, 'cost_model_retrieval_hit'):6.2f}%") + print(f" baseline_answer_hit {rate(rows, 'baseline_answer_hit'):6.2f}%") + print(f" optimizer_answer_hit {rate(rows, 'optimizer_answer_hit'):6.2f}%") + print(f" cost_model_answer_hit {rate(rows, 'cost_model_answer_hit'):6.2f}%") + print(f" delta vs baseline (ans) {rate(rows,'cost_model_answer_hit')-rate(rows,'baseline_answer_hit'):+6.2f}") + print(f" delta vs optimizer (ans) {rate(rows,'cost_model_answer_hit')-rate(rows,'optimizer_answer_hit'):+6.2f}") + print() + + +def per_category(rows): + by_cat = collections.defaultdict(list) + for r in rows: + by_cat[r["category"]].append(r) + print(f" {'cat':12s} {'n':>3} {'b_ans':>7} {'o_ans':>7} {'cm_ans':>7} {'choice':>10}") + for cat, rs in sorted(by_cat.items()): + ba = rate(rs, "baseline_answer_hit") + oa = rate(rs, "optimizer_answer_hit") + cma = rate(rs, "cost_model_answer_hit") + # Most-common routing choice in this gold category + choices = collections.Counter(r["cost_model_choice"] for r in rs) + common_choice = choices.most_common(1)[0][0] + print(f" {cat:12s} {len(rs):>3} {ba:6.1f}% {oa:6.1f}% {cma:6.1f}% {common_choice:>10}") + print() + + +def confusion(rows): + """Show how often the heuristic classifier disagrees with the gold label.""" + print(" classifier confusion (gold → predicted):") + pairs = collections.Counter((r["category"], r["planner_classification"]) for r in rows) + misses = [(g, p, n) for (g, p), n in pairs.items() if g != p] + if not misses: + print(" (no mismatches)") + return + for g, p, n in sorted(misses, key=lambda x: -x[2]): + print(f" {g:12s} → {p:12s} n={n}") + print() + + +def main(): + path = sys.argv[1] if len(sys.argv) > 1 else "eval/results_v2_n116.csv" + with open(path) as f: + rows = list(csv.DictReader(f)) + + # Production cost model: route by heuristic classification (the planner + # has no access to the gold category at inference time). + prod = project(rows, "planner_classification") + report(f"Cost model (route by planner_classification) — {path}", prod) + per_category(prod) + + # Oracle: route by gold category. Upper bound; gap to prod = classifier loss. + oracle = project(rows, "category") + report(f"Oracle cost model (route by gold category) — {path}", oracle) + + confusion(rows) + + +if __name__ == "__main__": + main() diff --git a/eval/headroom_analysis.txt b/eval/headroom_analysis.txt new file mode 100644 index 00000000..376469ce --- /dev/null +++ b/eval/headroom_analysis.txt @@ -0,0 +1,31 @@ +=== Cost model (route by planner_classification) — eval/results_corpus1_final.csv (N=116) === + baseline_retrieval_hit 95.69% + optimizer_retrieval_hit 95.69% + cost_model_retrieval_hit 98.28% + baseline_answer_hit 83.62% + optimizer_answer_hit 87.93% + cost_model_answer_hit 87.93% + delta vs baseline (ans) +4.31 + delta vs optimizer (ans) +0.00 + + cat n b_ans o_ans cm_ans choice + comparison 24 87.5% 87.5% 87.5% baseline + definition 20 85.0% 90.0% 90.0% optimizer + explanatory 24 75.0% 75.0% 75.0% baseline + keyword 20 85.0% 95.0% 95.0% optimizer + other 8 87.5% 100.0% 100.0% optimizer + procedural 20 85.0% 90.0% 90.0% optimizer + +=== Oracle cost model (route by gold category) — eval/results_corpus1_final.csv (N=116) === + baseline_retrieval_hit 95.69% + optimizer_retrieval_hit 95.69% + cost_model_retrieval_hit 98.28% + baseline_answer_hit 83.62% + optimizer_answer_hit 87.93% + cost_model_answer_hit 87.93% + delta vs baseline (ans) +4.31 + delta vs optimizer (ans) +0.00 + + classifier confusion (gold → predicted): + keyword → definition n=1 + diff --git a/eval/questions.jsonl b/eval/questions.jsonl new file mode 100644 index 00000000..183c7bc7 --- /dev/null +++ b/eval/questions.jsonl @@ -0,0 +1,116 @@ +{"query": "What is ACID?", "category": "keyword", "expected_chunks": ["atomicity", "consistency", "isolation", "durability"], "gold_answer_fragment": "atomicity"} +{"query": "Explain WAL-based recovery", "category": "keyword", "expected_chunks": ["write-ahead log", "log record", "recovery"], "gold_answer_fragment": "write-ahead"} +{"query": "How does MVCC work?", "category": "keyword", "expected_chunks": ["multiversion", "timestamp", "version"], "gold_answer_fragment": "multiversion"} +{"query": "What guarantees does BCNF provide?", "category": "keyword", "expected_chunks": ["Boyce-Codd", "functional dependency", "superkey"], "gold_answer_fragment": "Boyce"} +{"query": "Describe the DBMS architecture", "category": "keyword", "expected_chunks": ["database management system", "query processor", "storage manager"], "gold_answer_fragment": "storage"} +{"query": "What is DDL used for?", "category": "keyword", "expected_chunks": ["data-definition language", "schema", "create table"], "gold_answer_fragment": "data-definition"} +{"query": "What is SQL used for?", "category": "keyword", "expected_chunks": ["structured query language", "relation", "select"], "gold_answer_fragment": "structured query"} +{"query": "What is RAID used for in storage?", "category": "keyword", "expected_chunks": ["redundant array", "disk", "parity"], "gold_answer_fragment": "redundant"} +{"query": "Compare clustered and non-clustered indexes", "category": "comparison", "expected_chunks": ["clustering index", "primary index", "secondary index"], "gold_answer_fragment": "clustering"} +{"query": "What is the difference between optimistic and pessimistic concurrency control?", "category": "comparison", "expected_chunks": ["optimistic", "validation", "locking"], "gold_answer_fragment": "validation"} +{"query": "Contrast primary keys with foreign keys", "category": "comparison", "expected_chunks": ["primary key", "foreign key", "referential"], "gold_answer_fragment": "referential"} +{"query": "Hash indexes versus tree-based indexes", "category": "comparison", "expected_chunks": ["hash index", "B+-tree", "bucket"], "gold_answer_fragment": "hash"} +{"query": "Compare strict and non-strict two-phase locking", "category": "comparison", "expected_chunks": ["two-phase locking", "strict", "lock"], "gold_answer_fragment": "strict"} +{"query": "Difference between logical and physical data independence", "category": "comparison", "expected_chunks": ["logical", "physical", "data independence"], "gold_answer_fragment": "independence"} +{"query": "Shadow paging versus log-based recovery", "category": "comparison", "expected_chunks": ["shadow paging", "log", "recovery"], "gold_answer_fragment": "shadow"} +{"query": "Compare lossless and lossy decomposition in normalization", "category": "comparison", "expected_chunks": ["lossless", "decomposition", "natural join"], "gold_answer_fragment": "lossless"} +{"query": "What is a transaction?", "category": "definition", "expected_chunks": ["transaction", "unit of work", "atomic"], "gold_answer_fragment": "transaction"} +{"query": "Define a relation in the relational model", "category": "definition", "expected_chunks": ["relation", "tuple", "attribute"], "gold_answer_fragment": "tuple"} +{"query": "What is serializability?", "category": "definition", "expected_chunks": ["serial schedule", "serializable", "equivalent"], "gold_answer_fragment": "serial"} +{"query": "Define a functional dependency", "category": "definition", "expected_chunks": ["functional dependency", "determines", "attribute"], "gold_answer_fragment": "functional dependency"} +{"query": "What is a primary key?", "category": "definition", "expected_chunks": ["primary key", "candidate key", "tuple"], "gold_answer_fragment": "uniquely"} +{"query": "What is a foreign key?", "category": "definition", "expected_chunks": ["foreign key", "referential integrity", "referenced relation"], "gold_answer_fragment": "referenced"} +{"query": "Define a candidate key", "category": "definition", "expected_chunks": ["candidate key", "minimal", "superkey"], "gold_answer_fragment": "minimal"} +{"query": "What is a schedule in transaction processing?", "category": "definition", "expected_chunks": ["schedule", "operations", "transaction"], "gold_answer_fragment": "sequence"} +{"query": "Why do we need transaction isolation?", "category": "explanatory", "expected_chunks": ["isolation", "concurrent", "consistency"], "gold_answer_fragment": "concurrent"} +{"query": "Explain how deadlocks arise in concurrent transactions", "category": "explanatory", "expected_chunks": ["deadlock", "wait-for", "cycle"], "gold_answer_fragment": "deadlock"} +{"query": "Why is third normal form preferred over second normal form?", "category": "explanatory", "expected_chunks": ["third normal form", "transitive", "dependency"], "gold_answer_fragment": "transitive"} +{"query": "Explain why B+ trees are preferred over binary search trees for disk storage", "category": "explanatory", "expected_chunks": ["B+-tree", "disk block", "height"], "gold_answer_fragment": "disk"} +{"query": "Why does write-ahead logging ensure durability?", "category": "explanatory", "expected_chunks": ["write-ahead log", "durability", "redo"], "gold_answer_fragment": "durability"} +{"query": "Explain the role of the recovery manager", "category": "explanatory", "expected_chunks": ["recovery", "log", "crash"], "gold_answer_fragment": "crash"} +{"query": "Why is functional dependency important for normalization?", "category": "explanatory", "expected_chunks": ["functional dependency", "normal form", "decomposition"], "gold_answer_fragment": "normal form"} +{"query": "Explain how bucket overflow is handled in static hashing", "category": "explanatory", "expected_chunks": ["bucket", "overflow chain", "hash"], "gold_answer_fragment": "overflow"} +{"query": "How to perform database normalization step by step", "category": "procedural", "expected_chunks": ["normal form", "decomposition", "functional dependency"], "gold_answer_fragment": "normal form"} +{"query": "Steps to recover a database after a crash", "category": "procedural", "expected_chunks": ["log", "redo", "undo"], "gold_answer_fragment": "redo"} +{"query": "Describe the algorithm for two-phase locking", "category": "procedural", "expected_chunks": ["two-phase locking", "growing phase", "shrinking phase"], "gold_answer_fragment": "growing"} +{"query": "How to build a B+ tree index from scratch", "category": "procedural", "expected_chunks": ["B+-tree", "insert", "split"], "gold_answer_fragment": "split"} +{"query": "What is ARIES recovery?", "category": "keyword", "expected_chunks": ["ARIES", "log", "recovery"], "gold_answer_fragment": "log"} +{"query": "Define ACID compliance", "category": "keyword", "expected_chunks": ["atomicity", "consistency", "isolation", "durability"], "gold_answer_fragment": "atomicity"} +{"query": "What is OLTP?", "category": "keyword", "expected_chunks": ["online transaction", "OLTP", "transactional"], "gold_answer_fragment": "transaction"} +{"query": "What is OLAP?", "category": "keyword", "expected_chunks": ["online analytical", "OLAP", "analytical"], "gold_answer_fragment": "analytical"} +{"query": "Describe DML in SQL", "category": "keyword", "expected_chunks": ["data manipulation", "INSERT", "UPDATE"], "gold_answer_fragment": "manipulation"} +{"query": "What is JDBC used for?", "category": "keyword", "expected_chunks": ["Java", "database connectivity", "JDBC"], "gold_answer_fragment": "Java"} +{"query": "Explain LRU buffer replacement", "category": "keyword", "expected_chunks": ["least recently used", "LRU", "buffer"], "gold_answer_fragment": "least recently"} +{"query": "What is the CAP theorem?", "category": "keyword", "expected_chunks": ["consistency", "availability", "partition"], "gold_answer_fragment": "partition"} +{"query": "What is JSON storage in databases?", "category": "keyword", "expected_chunks": ["JSON", "document", "semi-structured"], "gold_answer_fragment": "JSON"} +{"query": "What does NULL represent in SQL?", "category": "keyword", "expected_chunks": ["unknown", "missing", "three-valued"], "gold_answer_fragment": "unknown"} +{"query": "What is ODBC?", "category": "keyword", "expected_chunks": ["Open Database", "ODBC", "connectivity"], "gold_answer_fragment": "connectivity"} +{"query": "Explain CRUD operations", "category": "keyword", "expected_chunks": ["create", "read", "update", "delete"], "gold_answer_fragment": "create"} +{"query": "Compare nested-loop join and hash join", "category": "comparison", "expected_chunks": ["nested loop", "hash join", "build"], "gold_answer_fragment": "hash"} +{"query": "Compare static and dynamic hashing", "category": "comparison", "expected_chunks": ["static hashing", "dynamic hashing", "bucket"], "gold_answer_fragment": "bucket"} +{"query": "Difference between inner join and outer join", "category": "comparison", "expected_chunks": ["inner join", "outer join", "matching"], "gold_answer_fragment": "outer"} +{"query": "Compare second normal form and third normal form", "category": "comparison", "expected_chunks": ["second normal form", "third normal form", "transitive"], "gold_answer_fragment": "transitive"} +{"query": "Compare timestamp ordering and two-phase locking", "category": "comparison", "expected_chunks": ["timestamp ordering", "two-phase locking", "concurrency"], "gold_answer_fragment": "timestamp"} +{"query": "Difference between snapshot isolation and serializable isolation", "category": "comparison", "expected_chunks": ["snapshot isolation", "serializable", "anomaly"], "gold_answer_fragment": "serializable"} +{"query": "Compare bitmap indexes and tree-based indexes", "category": "comparison", "expected_chunks": ["bitmap", "tree", "index"], "gold_answer_fragment": "bitmap"} +{"query": "Compare entity integrity and referential integrity", "category": "comparison", "expected_chunks": ["entity integrity", "referential integrity", "constraint"], "gold_answer_fragment": "referential"} +{"query": "Difference between supertype and subtype in entity relationship diagrams", "category": "comparison", "expected_chunks": ["supertype", "subtype", "specialization"], "gold_answer_fragment": "specialization"} +{"query": "Compare materialized views and regular views", "category": "comparison", "expected_chunks": ["materialized view", "stored", "query"], "gold_answer_fragment": "materialized"} +{"query": "Compare horizontal partitioning and vertical partitioning", "category": "comparison", "expected_chunks": ["horizontal partition", "vertical partition", "fragment"], "gold_answer_fragment": "partition"} +{"query": "Difference between checkpoint and savepoint", "category": "comparison", "expected_chunks": ["checkpoint", "savepoint", "transaction"], "gold_answer_fragment": "checkpoint"} +{"query": "Compare sequential file organization and hashed file organization", "category": "comparison", "expected_chunks": ["sequential", "hashed", "file organization"], "gold_answer_fragment": "hashed"} +{"query": "Compare deferred and immediate database modification", "category": "comparison", "expected_chunks": ["deferred", "immediate", "modification"], "gold_answer_fragment": "deferred"} +{"query": "Difference between conflict serializability and view serializability", "category": "comparison", "expected_chunks": ["conflict serializability", "view serializability", "schedule"], "gold_answer_fragment": "view"} +{"query": "Compare aggregation and generalization in entity relationship modeling", "category": "comparison", "expected_chunks": ["aggregation", "generalization", "entity"], "gold_answer_fragment": "generalization"} +{"query": "What is a superkey?", "category": "definition", "expected_chunks": ["superkey", "tuple", "uniquely"], "gold_answer_fragment": "uniquely"} +{"query": "Define a composite key", "category": "definition", "expected_chunks": ["composite", "multiple attributes", "key"], "gold_answer_fragment": "composite"} +{"query": "What is a weak entity?", "category": "definition", "expected_chunks": ["weak entity", "identifying", "discriminator"], "gold_answer_fragment": "identifying"} +{"query": "Define a multivalued dependency", "category": "definition", "expected_chunks": ["multivalued", "fourth normal form", "tuple"], "gold_answer_fragment": "multivalued"} +{"query": "What is an aggregate function?", "category": "definition", "expected_chunks": ["aggregate", "sum", "average"], "gold_answer_fragment": "aggregate"} +{"query": "Define a recoverable schedule", "category": "definition", "expected_chunks": ["recoverable", "commit", "transaction"], "gold_answer_fragment": "commit"} +{"query": "What is a cascading rollback?", "category": "definition", "expected_chunks": ["cascading", "rollback", "abort"], "gold_answer_fragment": "rollback"} +{"query": "Define a hash function in indexing", "category": "definition", "expected_chunks": ["hash function", "bucket", "key"], "gold_answer_fragment": "bucket"} +{"query": "What is the projection operator in relational algebra?", "category": "definition", "expected_chunks": ["projection", "attributes", "relation"], "gold_answer_fragment": "projection"} +{"query": "Define a deadlock victim", "category": "definition", "expected_chunks": ["victim", "rollback", "deadlock"], "gold_answer_fragment": "rollback"} +{"query": "What is a participation constraint?", "category": "definition", "expected_chunks": ["participation", "total", "partial"], "gold_answer_fragment": "participation"} +{"query": "Define generalization in the entity relationship model", "category": "definition", "expected_chunks": ["generalization", "subtype", "supertype"], "gold_answer_fragment": "subtype"} +{"query": "Why does normalization eliminate update anomalies?", "category": "explanatory", "expected_chunks": ["update anomaly", "normalization", "redundancy"], "gold_answer_fragment": "redundancy"} +{"query": "Why is checkpointing important during recovery?", "category": "explanatory", "expected_chunks": ["checkpoint", "recovery", "log"], "gold_answer_fragment": "recovery"} +{"query": "Why do databases use indexes despite write overhead?", "category": "explanatory", "expected_chunks": ["index", "lookup", "search"], "gold_answer_fragment": "lookup"} +{"query": "Why does fuzzy checkpointing reduce stalls?", "category": "explanatory", "expected_chunks": ["fuzzy", "checkpoint", "concurrent"], "gold_answer_fragment": "concurrent"} +{"query": "Why does snapshot isolation allow write skew?", "category": "explanatory", "expected_chunks": ["write skew", "snapshot", "concurrent"], "gold_answer_fragment": "skew"} +{"query": "Why is deadlock detection harder in distributed systems?", "category": "explanatory", "expected_chunks": ["distributed", "wait-for", "global"], "gold_answer_fragment": "global"} +{"query": "Why do databases use buffer pools?", "category": "explanatory", "expected_chunks": ["buffer", "cache", "memory"], "gold_answer_fragment": "memory"} +{"query": "Explain how hash partitioning distributes data across nodes", "category": "explanatory", "expected_chunks": ["hash partition", "distribute", "node"], "gold_answer_fragment": "hash"} +{"query": "Why does query optimization estimate intermediate result sizes?", "category": "explanatory", "expected_chunks": ["selectivity", "cardinality", "cost"], "gold_answer_fragment": "selectivity"} +{"query": "Explain why the lost update problem occurs", "category": "explanatory", "expected_chunks": ["lost update", "concurrent", "overwrite"], "gold_answer_fragment": "lost"} +{"query": "Why are nested-loop joins inefficient on large tables?", "category": "explanatory", "expected_chunks": ["nested loop", "tuple", "inner"], "gold_answer_fragment": "inner"} +{"query": "Why does the recovery manager need a redo log?", "category": "explanatory", "expected_chunks": ["redo", "log", "durability"], "gold_answer_fragment": "redo"} +{"query": "Explain how the buffer replacement policy affects performance", "category": "explanatory", "expected_chunks": ["replacement", "buffer", "page"], "gold_answer_fragment": "page"} +{"query": "Why do databases use sorted runs in external sort?", "category": "explanatory", "expected_chunks": ["external sort", "run", "merge"], "gold_answer_fragment": "merge"} +{"query": "Explain how foreign key constraints maintain referential integrity", "category": "explanatory", "expected_chunks": ["foreign key", "referential", "constraint"], "gold_answer_fragment": "referential"} +{"query": "Why does strict two-phase locking prevent cascading rollbacks?", "category": "explanatory", "expected_chunks": ["strict", "two-phase", "cascading"], "gold_answer_fragment": "cascading"} +{"query": "Steps to perform a hash join", "category": "procedural", "expected_chunks": ["hash join", "build", "probe"], "gold_answer_fragment": "build"} +{"query": "How to detect a deadlock using a wait-for graph", "category": "procedural", "expected_chunks": ["wait-for", "cycle", "deadlock"], "gold_answer_fragment": "cycle"} +{"query": "How to recover from a system crash using log-based recovery", "category": "procedural", "expected_chunks": ["log-based", "redo", "undo"], "gold_answer_fragment": "redo"} +{"query": "Steps to compute the closure of a set of attributes", "category": "procedural", "expected_chunks": ["closure", "functional dependency", "attribute"], "gold_answer_fragment": "closure"} +{"query": "How to convert an entity-relationship diagram to a relational schema", "category": "procedural", "expected_chunks": ["entity", "relation", "schema"], "gold_answer_fragment": "relation"} +{"query": "Steps to estimate the cost of a sort-merge join", "category": "procedural", "expected_chunks": ["sort-merge", "cost", "block"], "gold_answer_fragment": "cost"} +{"query": "Steps to perform external merge sort", "category": "procedural", "expected_chunks": ["external", "merge", "run"], "gold_answer_fragment": "merge"} +{"query": "How to compute a candidate key from a set of functional dependencies", "category": "procedural", "expected_chunks": ["candidate key", "closure", "functional dependency"], "gold_answer_fragment": "closure"} +{"query": "How to apply the chase algorithm to test lossless decomposition", "category": "procedural", "expected_chunks": ["chase", "lossless", "decomposition"], "gold_answer_fragment": "chase"} +{"query": "Steps to perform a redo pass during crash recovery", "category": "procedural", "expected_chunks": ["redo", "log record", "page"], "gold_answer_fragment": "redo"} +{"query": "How to insert a key into a B+ tree with node splits", "category": "procedural", "expected_chunks": ["insert", "split", "node"], "gold_answer_fragment": "split"} +{"query": "How to evaluate a query containing a where clause", "category": "procedural", "expected_chunks": ["where", "selection", "predicate"], "gold_answer_fragment": "selection"} +{"query": "Steps to decompose a schema into Boyce-Codd normal form", "category": "procedural", "expected_chunks": ["Boyce-Codd", "decompose", "functional dependency"], "gold_answer_fragment": "decompose"} +{"query": "How to enforce serializability using timestamp ordering", "category": "procedural", "expected_chunks": ["timestamp", "serializable", "ordering"], "gold_answer_fragment": "timestamp"} +{"query": "Steps to grant and revoke access privileges in a database", "category": "procedural", "expected_chunks": ["grant", "revoke", "privilege"], "gold_answer_fragment": "grant"} +{"query": "How to compute the relational division operator", "category": "procedural", "expected_chunks": ["division", "relational", "tuple"], "gold_answer_fragment": "division"} +{"query": "Database normalization in summary", "category": "other", "expected_chunks": ["normalization", "redundancy", "anomaly"], "gold_answer_fragment": "normalization"} +{"query": "Concurrency and recovery interactions", "category": "other", "expected_chunks": ["concurrency", "recovery", "transaction"], "gold_answer_fragment": "transaction"} +{"query": "Storage hierarchy in databases", "category": "other", "expected_chunks": ["primary", "secondary", "tertiary"], "gold_answer_fragment": "secondary"} +{"query": "Index selection during query planning", "category": "other", "expected_chunks": ["index", "query", "planner"], "gold_answer_fragment": "index"} +{"query": "Buffer management considerations", "category": "other", "expected_chunks": ["buffer", "page", "memory"], "gold_answer_fragment": "buffer"} +{"query": "Lock escalation behavior", "category": "other", "expected_chunks": ["lock", "escalation", "granularity"], "gold_answer_fragment": "lock"} +{"query": "Transaction logging mechanism", "category": "other", "expected_chunks": ["log", "transaction", "record"], "gold_answer_fragment": "log"} +{"query": "The role of constraints in database design", "category": "other", "expected_chunks": ["constraint", "integrity", "design"], "gold_answer_fragment": "constraint"} diff --git a/eval/questions_adversarial.jsonl b/eval/questions_adversarial.jsonl new file mode 100644 index 00000000..362fa760 --- /dev/null +++ b/eval/questions_adversarial.jsonl @@ -0,0 +1,28 @@ +{"query": "What sets clustered indexes apart from non-clustered indexes", "category": "comparison", "expected_chunks": ["clustering index", "primary index", "secondary index"], "gold_answer_fragment": "clustering"} +{"query": "Tell me how primary keys relate to foreign keys", "category": "comparison", "expected_chunks": ["primary key", "foreign key", "referential"], "gold_answer_fragment": "referential"} +{"query": "Tell me about hash indexes alongside tree-based indexes", "category": "comparison", "expected_chunks": ["hash index", "B+-tree", "bucket"], "gold_answer_fragment": "hash"} +{"query": "Tell me how strict two-phase locking relates to non-strict two-phase locking", "category": "comparison", "expected_chunks": ["two-phase locking", "strict", "lock"], "gold_answer_fragment": "strict"} +{"query": "Tell me about logical and physical data independence in juxtaposition", "category": "comparison", "expected_chunks": ["logical", "physical", "data independence"], "gold_answer_fragment": "independence"} +{"query": "Tell me about shadow paging in relation to log-based recovery", "category": "comparison", "expected_chunks": ["shadow paging", "log", "recovery"], "gold_answer_fragment": "shadow"} +{"query": "Tell me how lossless decomposition relates to lossy decomposition in normalization", "category": "comparison", "expected_chunks": ["lossless", "decomposition", "natural join"], "gold_answer_fragment": "lossless"} +{"query": "Tell me how nested-loop join relates to hash join", "category": "comparison", "expected_chunks": ["nested loop", "hash join", "build"], "gold_answer_fragment": "hash"} +{"query": "The need for transaction isolation in databases", "category": "explanatory", "expected_chunks": ["isolation", "concurrent", "consistency"], "gold_answer_fragment": "concurrent"} +{"query": "Walk me through how deadlocks arise in concurrent transactions", "category": "explanatory", "expected_chunks": ["deadlock", "wait-for", "cycle"], "gold_answer_fragment": "deadlock"} +{"query": "The reasoning behind preferring third normal form over second normal form", "category": "explanatory", "expected_chunks": ["third normal form", "transitive", "dependency"], "gold_answer_fragment": "transitive"} +{"query": "The reason write-ahead logging ensures durability in databases", "category": "explanatory", "expected_chunks": ["write-ahead log", "durability", "redo"], "gold_answer_fragment": "durability"} +{"query": "Tell me about the role of the recovery manager in databases", "category": "explanatory", "expected_chunks": ["recovery", "log", "crash"], "gold_answer_fragment": "crash"} +{"query": "The importance of functional dependency for normalization", "category": "explanatory", "expected_chunks": ["functional dependency", "normal form", "decomposition"], "gold_answer_fragment": "normal form"} +{"query": "Walk me through bucket overflow handling in static hashing", "category": "explanatory", "expected_chunks": ["bucket", "overflow chain", "hash"], "gold_answer_fragment": "overflow"} +{"query": "The role of normalization in eliminating update anomalies", "category": "explanatory", "expected_chunks": ["update anomaly", "normalization", "redundancy"], "gold_answer_fragment": "redundancy"} +{"query": "Tell me about atomicity consistency isolation and durability properties", "category": "keyword", "expected_chunks": ["atomicity", "consistency", "isolation", "durability"], "gold_answer_fragment": "atomicity"} +{"query": "Tell me about multi-version concurrency control mechanism", "category": "keyword", "expected_chunks": ["multiversion", "timestamp", "version"], "gold_answer_fragment": "multiversion"} +{"query": "The mechanism behind write-ahead log based recovery", "category": "keyword", "expected_chunks": ["write-ahead log", "log record", "recovery"], "gold_answer_fragment": "write-ahead"} +{"query": "Tell me about online transaction processing systems", "category": "keyword", "expected_chunks": ["online transaction", "transaction processing", "operational"], "gold_answer_fragment": "transaction processing"} +{"query": "Describe candidate keys in relational databases", "category": "definition", "expected_chunks": ["candidate key", "minimal", "superkey"], "gold_answer_fragment": "candidate"} +{"query": "Tell me about functional dependency in databases", "category": "definition", "expected_chunks": ["functional dependency", "attribute", "determines"], "gold_answer_fragment": "dependency"} +{"query": "Describe foreign keys and their purpose", "category": "definition", "expected_chunks": ["foreign key", "referential", "relation"], "gold_answer_fragment": "foreign"} +{"query": "Tell me about referential integrity in databases", "category": "definition", "expected_chunks": ["referential integrity", "foreign key", "constraint"], "gold_answer_fragment": "referential"} +{"query": "Walk through performing database normalization", "category": "procedural", "expected_chunks": ["normalization", "normal form", "decomposition"], "gold_answer_fragment": "normalization"} +{"query": "Walk through building a B+ tree index from scratch", "category": "procedural", "expected_chunks": ["B+-tree", "node", "split"], "gold_answer_fragment": "B+-tree"} +{"query": "Walk me through recovering from a transaction failure", "category": "procedural", "expected_chunks": ["recovery", "log", "rollback"], "gold_answer_fragment": "recovery"} +{"query": "Walk through designing a relational schema for a library system", "category": "procedural", "expected_chunks": ["entity", "relationship", "schema"], "gold_answer_fragment": "schema"} diff --git a/eval/results_adversarial.csv b/eval/results_adversarial.csv new file mode 100644 index 00000000..7c44d8e0 --- /dev/null +++ b/eval/results_adversarial.csv @@ -0,0 +1,29 @@ +query,category,planner_classification,baseline_retrieval_hit,optimizer_retrieval_hit,cost_model_retrieval_hit,cost_model_lc_retrieval_hit,baseline_answer_hit,optimizer_answer_hit,cost_model_answer_hit,cost_model_lc_answer_hit,cost_model_lc_predicted,cost_model_lc_confidence,cost_model_lc_fallback +What sets clustered indexes apart from non-clustered indexes,comparison,other,1,1,1,1,0,0,0,0,comparison,0.5365,False +Tell me how primary keys relate to foreign keys,comparison,other,1,1,1,1,0,1,1,0,definition,0.4860,True +Tell me about hash indexes alongside tree-based indexes,comparison,other,1,1,1,1,1,1,1,1,comparison,0.4514,True +Tell me how strict two-phase locking relates to non-strict two-phase locking,comparison,other,1,1,1,1,1,1,1,1,explanatory,0.2634,True +Tell me about logical and physical data independence in juxtaposition,comparison,other,1,1,1,1,1,1,1,1,comparison,0.2173,True +Tell me about shadow paging in relation to log-based recovery,comparison,other,1,1,1,1,1,1,1,1,explanatory,0.2481,True +Tell me how lossless decomposition relates to lossy decomposition in normalization,comparison,other,1,1,1,1,1,1,1,1,procedural,0.2575,True +Tell me how nested-loop join relates to hash join,comparison,other,1,1,1,1,1,1,1,1,comparison,0.3055,True +The need for transaction isolation in databases,explanatory,other,1,1,1,1,1,1,1,1,explanatory,0.2414,True +Walk me through how deadlocks arise in concurrent transactions,explanatory,other,1,1,1,1,1,1,1,1,explanatory,0.3929,True +The reasoning behind preferring third normal form over second normal form,explanatory,other,1,1,1,1,0,0,0,0,comparison,0.2778,True +The reason write-ahead logging ensures durability in databases,explanatory,other,1,1,1,1,1,1,1,1,explanatory,0.3624,True +Tell me about the role of the recovery manager in databases,explanatory,other,1,1,1,1,1,1,1,1,keyword,0.2387,True +The importance of functional dependency for normalization,explanatory,other,1,1,1,1,1,1,1,1,procedural,0.3161,True +Walk me through bucket overflow handling in static hashing,explanatory,other,1,1,1,1,1,1,1,1,procedural,0.2541,True +The role of normalization in eliminating update anomalies,explanatory,other,1,1,1,1,1,1,1,1,procedural,0.3272,True +Tell me about atomicity consistency isolation and durability properties,keyword,other,1,1,1,1,1,1,1,1,keyword,0.2691,True +Tell me about multi-version concurrency control mechanism,keyword,other,1,1,1,1,0,0,0,0,keyword,0.2517,True +The mechanism behind write-ahead log based recovery,keyword,other,1,1,1,1,1,1,1,1,explanatory,0.2981,True +Tell me about online transaction processing systems,keyword,other,1,1,1,1,1,1,1,1,keyword,0.2959,True +Describe candidate keys in relational databases,definition,other,1,1,1,1,1,1,1,1,procedural,0.2465,True +Tell me about functional dependency in databases,definition,other,1,1,1,1,1,1,1,1,definition,0.4393,True +Describe foreign keys and their purpose,definition,other,1,1,1,1,1,1,1,1,definition,0.5521,False +Tell me about referential integrity in databases,definition,other,1,1,1,1,1,1,1,1,definition,0.4131,True +Walk through performing database normalization,procedural,other,1,1,1,1,1,1,1,1,procedural,0.4354,True +Walk through building a B+ tree index from scratch,procedural,other,1,1,1,1,0,0,0,0,procedural,0.4052,True +Walk me through recovering from a transaction failure,procedural,other,1,1,1,1,1,1,1,1,procedural,0.3309,True +Walk through designing a relational schema for a library system,procedural,other,1,1,1,1,1,1,1,1,procedural,0.4874,True diff --git a/eval/results_corpus1_final.csv b/eval/results_corpus1_final.csv new file mode 100644 index 00000000..663e6023 --- /dev/null +++ b/eval/results_corpus1_final.csv @@ -0,0 +1,117 @@ +query,category,planner_classification,baseline_retrieval_hit,optimizer_retrieval_hit,cost_model_retrieval_hit,baseline_answer_hit,optimizer_answer_hit,cost_model_answer_hit +What is ACID?,keyword,keyword,1,1,1,1,1,1 +Explain WAL-based recovery,keyword,keyword,1,1,1,1,1,1 +How does MVCC work?,keyword,keyword,1,1,1,0,1,1 +What guarantees does BCNF provide?,keyword,keyword,1,1,1,1,0,0 +Describe the DBMS architecture,keyword,keyword,0,0,0,1,1,1 +What is DDL used for?,keyword,keyword,1,1,1,0,1,1 +What is SQL used for?,keyword,keyword,1,1,1,1,1,1 +What is RAID used for in storage?,keyword,keyword,1,1,1,1,1,1 +Compare clustered and non-clustered indexes,comparison,comparison,1,0,1,1,0,0 +What is the difference between optimistic and pessimistic concurrency control?,comparison,comparison,1,1,1,0,0,0 +Contrast primary keys with foreign keys,comparison,comparison,1,1,1,1,1,1 +Hash indexes versus tree-based indexes,comparison,comparison,1,1,1,1,1,1 +Compare strict and non-strict two-phase locking,comparison,comparison,1,1,1,1,1,1 +Difference between logical and physical data independence,comparison,comparison,1,1,1,1,1,1 +Shadow paging versus log-based recovery,comparison,comparison,1,1,1,1,1,1 +Compare lossless and lossy decomposition in normalization,comparison,comparison,1,1,1,1,1,1 +What is a transaction?,definition,definition,1,1,1,1,1,1 +Define a relation in the relational model,definition,definition,1,1,1,1,1,1 +What is serializability?,definition,definition,1,1,1,1,1,1 +Define a functional dependency,definition,definition,1,1,1,1,1,1 +What is a primary key?,definition,definition,1,1,1,1,1,1 +What is a foreign key?,definition,definition,1,1,1,1,1,1 +Define a candidate key,definition,definition,0,1,1,0,1,1 +What is a schedule in transaction processing?,definition,definition,1,1,1,1,1,1 +Why do we need transaction isolation?,explanatory,explanatory,1,1,1,1,1,1 +Explain how deadlocks arise in concurrent transactions,explanatory,explanatory,1,1,1,1,1,1 +Why is third normal form preferred over second normal form?,explanatory,explanatory,1,1,1,1,1,1 +Explain why B+ trees are preferred over binary search trees for disk storage,explanatory,explanatory,1,0,1,1,1,1 +Why does write-ahead logging ensure durability?,explanatory,explanatory,1,1,1,1,1,1 +Explain the role of the recovery manager,explanatory,explanatory,1,1,1,1,0,1 +Why is functional dependency important for normalization?,explanatory,explanatory,1,1,1,1,1,1 +Explain how bucket overflow is handled in static hashing,explanatory,explanatory,1,1,1,1,1,1 +How to perform database normalization step by step,procedural,procedural,1,1,1,1,1,1 +Steps to recover a database after a crash,procedural,procedural,1,1,1,1,1,1 +Describe the algorithm for two-phase locking,procedural,procedural,1,1,1,0,1,1 +How to build a B+ tree index from scratch,procedural,procedural,1,1,1,0,0,1 +What is ARIES recovery?,keyword,definition,1,1,1,0,1,1 +Define ACID compliance,keyword,keyword,1,1,1,1,1,1 +What is OLTP?,keyword,keyword,1,1,1,1,1,1 +What is OLAP?,keyword,keyword,1,1,1,1,1,1 +Describe DML in SQL,keyword,keyword,1,1,1,1,1,1 +What is JDBC used for?,keyword,keyword,1,1,1,1,1,1 +Explain LRU buffer replacement,keyword,keyword,1,1,1,1,1,1 +What is the CAP theorem?,keyword,keyword,1,1,1,1,1,1 +What is JSON storage in databases?,keyword,keyword,1,1,1,1,1,1 +What does NULL represent in SQL?,keyword,keyword,1,1,1,1,1,1 +What is ODBC?,keyword,keyword,0,1,1,1,1,1 +Explain CRUD operations,keyword,keyword,1,1,1,1,1,1 +Compare nested-loop join and hash join,comparison,comparison,1,1,1,1,1,1 +Compare static and dynamic hashing,comparison,comparison,1,1,1,1,1,1 +Difference between inner join and outer join,comparison,comparison,1,1,1,1,1,1 +Compare second normal form and third normal form,comparison,comparison,1,1,1,0,1,1 +Compare timestamp ordering and two-phase locking,comparison,comparison,1,1,1,1,1,1 +Difference between snapshot isolation and serializable isolation,comparison,comparison,1,1,1,1,1,1 +Compare bitmap indexes and tree-based indexes,comparison,comparison,1,1,1,1,1,1 +Compare entity integrity and referential integrity,comparison,comparison,1,1,1,1,1,1 +Difference between supertype and subtype in entity relationship diagrams,comparison,comparison,1,0,1,0,0,0 +Compare materialized views and regular views,comparison,comparison,1,1,1,1,1,1 +Compare horizontal partitioning and vertical partitioning,comparison,comparison,1,1,1,1,1,1 +Difference between checkpoint and savepoint,comparison,comparison,1,1,1,1,1,1 +Compare sequential file organization and hashed file organization,comparison,comparison,1,1,1,1,1,1 +Compare deferred and immediate database modification,comparison,comparison,1,1,1,1,1,1 +Difference between conflict serializability and view serializability,comparison,comparison,1,1,1,1,1,1 +Compare aggregation and generalization in entity relationship modeling,comparison,comparison,1,1,1,1,1,1 +What is a superkey?,definition,definition,1,1,1,1,1,1 +Define a composite key,definition,definition,1,1,1,1,1,1 +What is a weak entity?,definition,definition,0,0,0,0,0,0 +Define a multivalued dependency,definition,definition,1,1,1,1,1,1 +What is an aggregate function?,definition,definition,1,1,1,1,1,1 +Define a recoverable schedule,definition,definition,1,1,1,1,1,1 +What is a cascading rollback?,definition,definition,1,1,1,1,1,1 +Define a hash function in indexing,definition,definition,1,1,1,1,1,1 +What is the projection operator in relational algebra?,definition,definition,1,1,1,1,1,1 +Define a deadlock victim,definition,definition,1,1,1,1,1,1 +What is a participation constraint?,definition,definition,1,1,1,1,1,1 +Define generalization in the entity relationship model,definition,definition,1,1,1,0,0,0 +Why does normalization eliminate update anomalies?,explanatory,explanatory,1,1,1,1,1,1 +Why is checkpointing important during recovery?,explanatory,explanatory,1,1,1,1,1,1 +Why do databases use indexes despite write overhead?,explanatory,explanatory,1,1,1,0,0,1 +Why does fuzzy checkpointing reduce stalls?,explanatory,explanatory,1,1,1,0,0,0 +Why does snapshot isolation allow write skew?,explanatory,explanatory,1,1,1,1,1,1 +Why is deadlock detection harder in distributed systems?,explanatory,explanatory,1,1,1,0,0,0 +Why do databases use buffer pools?,explanatory,explanatory,1,1,1,1,1,1 +Explain how hash partitioning distributes data across nodes,explanatory,explanatory,1,1,1,1,1,1 +Why does query optimization estimate intermediate result sizes?,explanatory,explanatory,1,1,1,0,0,0 +Explain why the lost update problem occurs,explanatory,explanatory,1,1,1,1,1,1 +Why are nested-loop joins inefficient on large tables?,explanatory,explanatory,1,1,1,0,1,0 +Why does the recovery manager need a redo log?,explanatory,explanatory,1,1,1,1,1,1 +Explain how the buffer replacement policy affects performance,explanatory,explanatory,1,1,1,0,0,0 +Why do databases use sorted runs in external sort?,explanatory,explanatory,1,1,1,1,1,1 +Explain how foreign key constraints maintain referential integrity,explanatory,explanatory,1,1,1,1,1,1 +Why does strict two-phase locking prevent cascading rollbacks?,explanatory,explanatory,1,1,1,1,1,1 +Steps to perform a hash join,procedural,procedural,1,1,1,1,1,1 +How to detect a deadlock using a wait-for graph,procedural,procedural,1,1,1,1,1,1 +How to recover from a system crash using log-based recovery,procedural,procedural,1,1,1,1,1,1 +Steps to compute the closure of a set of attributes,procedural,procedural,1,1,1,1,1,1 +How to convert an entity-relationship diagram to a relational schema,procedural,procedural,1,1,1,1,1,1 +Steps to estimate the cost of a sort-merge join,procedural,procedural,1,1,1,1,1,1 +Steps to perform external merge sort,procedural,procedural,1,1,1,1,1,1 +How to compute a candidate key from a set of functional dependencies,procedural,procedural,1,1,1,1,1,1 +How to apply the chase algorithm to test lossless decomposition,procedural,procedural,1,1,1,1,1,1 +Steps to perform a redo pass during crash recovery,procedural,procedural,1,1,1,1,1,1 +How to insert a key into a B+ tree with node splits,procedural,procedural,1,1,1,1,1,1 +How to evaluate a query containing a where clause,procedural,procedural,1,1,1,0,0,0 +Steps to decompose a schema into Boyce-Codd normal form,procedural,procedural,1,1,1,1,1,1 +How to enforce serializability using timestamp ordering,procedural,procedural,1,1,1,1,1,1 +Steps to grant and revoke access privileges in a database,procedural,procedural,1,1,1,1,1,1 +How to compute the relational division operator,procedural,procedural,1,1,1,1,1,1 +Database normalization in summary,other,other,1,1,1,1,1,1 +Concurrency and recovery interactions,other,other,1,1,1,1,1,1 +Storage hierarchy in databases,other,other,0,1,1,0,1,1 +Index selection during query planning,other,other,1,1,1,1,1,1 +Buffer management considerations,other,other,1,1,1,1,1,1 +Lock escalation behavior,other,other,1,1,1,1,1,1 +Transaction logging mechanism,other,other,1,1,1,1,1,1 +The role of constraints in database design,other,other,1,1,1,1,1,1 diff --git a/eval/results_corpus1_with_lc.csv b/eval/results_corpus1_with_lc.csv new file mode 100644 index 00000000..e06bdd2a --- /dev/null +++ b/eval/results_corpus1_with_lc.csv @@ -0,0 +1,117 @@ +query,category,planner_classification,baseline_retrieval_hit,optimizer_retrieval_hit,cost_model_retrieval_hit,cost_model_lc_retrieval_hit,baseline_answer_hit,optimizer_answer_hit,cost_model_answer_hit,cost_model_lc_answer_hit,cost_model_lc_predicted,cost_model_lc_confidence,cost_model_lc_fallback +What is ACID?,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.5546,False +Explain WAL-based recovery,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.3833,True +How does MVCC work?,keyword,keyword,1,1,1,1,0,1,1,0,keyword,0.3473,True +What guarantees does BCNF provide?,keyword,keyword,1,1,1,1,1,0,0,1,definition,0.2504,True +Describe the DBMS architecture,keyword,keyword,0,0,0,0,1,1,1,1,other,0.2964,True +What is DDL used for?,keyword,keyword,1,1,1,1,0,1,1,0,keyword,0.3055,True +What is SQL used for?,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.2605,True +What is RAID used for in storage?,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.2359,True +Compare clustered and non-clustered indexes,comparison,comparison,1,0,1,1,1,0,0,0,comparison,0.6311,False +What is the difference between optimistic and pessimistic concurrency control?,comparison,comparison,1,1,1,1,0,0,0,0,comparison,0.2938,True +Contrast primary keys with foreign keys,comparison,comparison,1,1,1,1,1,1,1,1,definition,0.3346,True +Hash indexes versus tree-based indexes,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.4036,True +Compare strict and non-strict two-phase locking,comparison,comparison,1,1,1,1,1,1,1,1,explanatory,0.2864,True +Difference between logical and physical data independence,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.2532,True +Shadow paging versus log-based recovery,comparison,comparison,1,1,1,1,1,1,1,1,other,0.2181,True +Compare lossless and lossy decomposition in normalization,comparison,comparison,1,1,1,1,1,1,1,1,procedural,0.3055,True +What is a transaction?,definition,definition,1,1,1,1,1,1,1,1,keyword,0.2988,True +Define a relation in the relational model,definition,definition,1,1,1,1,1,1,1,1,definition,0.4219,True +What is serializability?,definition,definition,1,1,1,1,1,1,1,1,explanatory,0.2310,True +Define a functional dependency,definition,definition,1,1,1,1,1,1,1,1,definition,0.6872,False +What is a primary key?,definition,definition,1,1,1,1,1,1,1,1,definition,0.4580,True +What is a foreign key?,definition,definition,1,1,1,1,1,1,1,1,definition,0.6875,False +Define a candidate key,definition,definition,0,1,1,1,0,1,1,1,definition,0.6786,False +What is a schedule in transaction processing?,definition,definition,1,1,1,1,1,1,1,1,explanatory,0.2157,True +Why do we need transaction isolation?,explanatory,explanatory,1,1,1,1,1,1,1,1,explanatory,0.3429,True +Explain how deadlocks arise in concurrent transactions,explanatory,explanatory,1,1,1,1,1,1,1,1,explanatory,0.4966,True +Why is third normal form preferred over second normal form?,explanatory,explanatory,1,1,1,1,1,1,1,1,comparison,0.3144,True +Explain why B+ trees are preferred over binary search trees for disk storage,explanatory,explanatory,1,0,1,1,1,1,1,1,comparison,0.3472,True +Why does write-ahead logging ensure durability?,explanatory,explanatory,1,1,1,1,1,1,1,1,explanatory,0.4578,True +Explain the role of the recovery manager,explanatory,explanatory,1,1,1,1,1,0,1,1,explanatory,0.2260,True +Why is functional dependency important for normalization?,explanatory,explanatory,1,1,1,1,1,1,1,1,definition,0.3399,True +Explain how bucket overflow is handled in static hashing,explanatory,explanatory,1,1,1,1,1,1,1,1,procedural,0.2372,True +How to perform database normalization step by step,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.4463,True +Steps to recover a database after a crash,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.3723,True +Describe the algorithm for two-phase locking,procedural,procedural,1,1,1,1,0,1,1,0,explanatory,0.3436,True +How to build a B+ tree index from scratch,procedural,procedural,1,1,1,1,0,0,1,0,procedural,0.3960,True +What is ARIES recovery?,keyword,definition,1,1,1,1,0,1,1,0,keyword,0.2595,True +Define ACID compliance,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.3589,True +What is OLTP?,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.5880,False +What is OLAP?,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.3811,True +Describe DML in SQL,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.2618,True +What is JDBC used for?,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.2477,True +Explain LRU buffer replacement,keyword,keyword,1,1,1,1,1,1,1,1,other,0.2791,True +What is the CAP theorem?,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.2396,True +What is JSON storage in databases?,keyword,keyword,1,1,1,1,1,1,1,1,comparison,0.2299,True +What does NULL represent in SQL?,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.1953,True +What is ODBC?,keyword,keyword,0,1,1,0,1,1,1,1,keyword,0.3200,True +Explain CRUD operations,keyword,keyword,1,1,1,1,1,1,1,1,keyword,0.4236,True +Compare nested-loop join and hash join,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.3468,True +Compare static and dynamic hashing,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.3760,True +Difference between inner join and outer join,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.3216,True +Compare second normal form and third normal form,comparison,comparison,1,1,1,1,0,1,1,1,comparison,0.3991,True +Compare timestamp ordering and two-phase locking,comparison,comparison,1,1,1,1,1,1,1,1,explanatory,0.2794,True +Difference between snapshot isolation and serializable isolation,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.2692,True +Compare bitmap indexes and tree-based indexes,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.4020,True +Compare entity integrity and referential integrity,comparison,comparison,1,1,1,1,1,1,1,1,definition,0.2666,True +Difference between supertype and subtype in entity relationship diagrams,comparison,comparison,1,0,1,1,0,0,0,0,comparison,0.2458,True +Compare materialized views and regular views,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.3648,True +Compare horizontal partitioning and vertical partitioning,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.3593,True +Difference between checkpoint and savepoint,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.2360,True +Compare sequential file organization and hashed file organization,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.3497,True +Compare deferred and immediate database modification,comparison,comparison,1,1,1,1,1,1,1,1,comparison,0.2148,True +Difference between conflict serializability and view serializability,comparison,comparison,1,1,1,1,1,1,1,1,explanatory,0.2284,True +Compare aggregation and generalization in entity relationship modeling,comparison,comparison,1,1,1,1,1,1,1,1,procedural,0.2986,True +What is a superkey?,definition,definition,1,1,1,1,1,1,1,1,definition,0.4708,True +Define a composite key,definition,definition,1,1,1,1,1,1,1,1,definition,0.4600,True +What is a weak entity?,definition,definition,0,0,0,0,0,0,0,0,definition,0.3857,True +Define a multivalued dependency,definition,definition,1,1,1,1,1,1,1,1,definition,0.4285,True +What is an aggregate function?,definition,definition,1,1,1,1,1,1,1,1,definition,0.2119,True +Define a recoverable schedule,definition,definition,1,1,1,1,1,1,1,1,definition,0.2428,True +What is a cascading rollback?,definition,definition,1,1,1,1,1,1,1,1,keyword,0.2370,True +Define a hash function in indexing,definition,definition,1,1,1,1,1,1,1,1,comparison,0.2479,True +What is the projection operator in relational algebra?,definition,definition,1,1,1,1,1,1,1,1,procedural,0.2449,True +Define a deadlock victim,definition,definition,1,1,1,1,1,1,1,1,explanatory,0.2431,True +What is a participation constraint?,definition,definition,1,1,1,1,1,1,1,1,definition,0.4077,True +Define generalization in the entity relationship model,definition,definition,1,1,1,1,0,0,0,0,procedural,0.2580,True +Why does normalization eliminate update anomalies?,explanatory,explanatory,1,1,1,1,1,1,1,1,procedural,0.2625,True +Why is checkpointing important during recovery?,explanatory,explanatory,1,1,1,1,1,1,1,1,explanatory,0.2604,True +Why do databases use indexes despite write overhead?,explanatory,explanatory,1,1,1,1,0,0,1,1,explanatory,0.3408,True +Why does fuzzy checkpointing reduce stalls?,explanatory,explanatory,1,1,1,1,0,0,0,0,explanatory,0.2146,True +Why does snapshot isolation allow write skew?,explanatory,explanatory,1,1,1,1,1,1,1,1,explanatory,0.3290,True +Why is deadlock detection harder in distributed systems?,explanatory,explanatory,1,1,1,1,0,0,0,0,explanatory,0.3231,True +Why do databases use buffer pools?,explanatory,explanatory,1,1,1,1,1,1,1,1,other,0.3238,True +Explain how hash partitioning distributes data across nodes,explanatory,explanatory,1,1,1,1,1,1,1,1,comparison,0.2728,True +Why does query optimization estimate intermediate result sizes?,explanatory,explanatory,1,1,1,1,0,0,0,0,other,0.2982,True +Explain why the lost update problem occurs,explanatory,explanatory,1,1,1,1,1,1,1,1,explanatory,0.2956,True +Why are nested-loop joins inefficient on large tables?,explanatory,explanatory,1,1,1,1,0,1,0,0,other,0.2509,True +Why does the recovery manager need a redo log?,explanatory,explanatory,1,1,1,1,1,1,1,1,explanatory,0.3486,True +Explain how the buffer replacement policy affects performance,explanatory,explanatory,1,1,1,1,0,0,0,0,other,0.3502,True +Why do databases use sorted runs in external sort?,explanatory,explanatory,1,1,1,1,1,1,1,1,other,0.2211,True +Explain how foreign key constraints maintain referential integrity,explanatory,explanatory,1,1,1,1,1,1,1,1,definition,0.4980,True +Why does strict two-phase locking prevent cascading rollbacks?,explanatory,explanatory,1,1,1,1,1,1,1,1,explanatory,0.3612,True +Steps to perform a hash join,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.3451,True +How to detect a deadlock using a wait-for graph,procedural,procedural,1,1,1,1,1,1,1,1,explanatory,0.2855,True +How to recover from a system crash using log-based recovery,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.2955,True +Steps to compute the closure of a set of attributes,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.2768,True +How to convert an entity-relationship diagram to a relational schema,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.4230,True +Steps to estimate the cost of a sort-merge join,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.3135,True +Steps to perform external merge sort,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.2784,True +How to compute a candidate key from a set of functional dependencies,procedural,procedural,1,1,1,1,1,1,1,1,definition,0.3564,True +How to apply the chase algorithm to test lossless decomposition,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.3199,True +Steps to perform a redo pass during crash recovery,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.2779,True +How to insert a key into a B+ tree with node splits,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.3828,True +How to evaluate a query containing a where clause,procedural,procedural,1,1,1,1,0,0,0,0,procedural,0.2732,True +Steps to decompose a schema into Boyce-Codd normal form,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.4011,True +How to enforce serializability using timestamp ordering,procedural,procedural,1,1,1,1,1,1,1,1,explanatory,0.3020,True +Steps to grant and revoke access privileges in a database,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.2811,True +How to compute the relational division operator,procedural,procedural,1,1,1,1,1,1,1,1,procedural,0.3176,True +Database normalization in summary,other,other,1,1,1,1,1,1,1,1,procedural,0.3423,True +Concurrency and recovery interactions,other,other,1,1,1,1,1,1,1,1,procedural,0.2302,True +Storage hierarchy in databases,other,other,0,1,1,0,0,1,1,0,comparison,0.2528,True +Index selection during query planning,other,other,1,1,1,1,1,1,1,1,other,0.3115,True +Buffer management considerations,other,other,1,1,1,1,1,1,1,1,other,0.5574,False +Lock escalation behavior,other,other,1,1,1,1,1,1,1,1,other,0.4798,True +Transaction logging mechanism,other,other,1,1,1,1,1,1,1,1,explanatory,0.2688,True +The role of constraints in database design,other,other,1,1,1,1,1,1,1,1,procedural,0.3008,True diff --git a/eval/run_eval.py b/eval/run_eval.py new file mode 100644 index 00000000..17dfd808 --- /dev/null +++ b/eval/run_eval.py @@ -0,0 +1,489 @@ +#!/usr/bin/env python3 +""" +eval/run_eval.py + +Evaluate the TokenSmith RAG pipeline on eval/questions.jsonl, comparing the +CompositeQueryPlanner (MultiHop -> Heuristic, the "optimizer") against a +no-op baseline planner that leaves cfg unchanged and never expands the query. + +The baseline and optimizer share the same artifacts (chunks, FAISS index, +BM25 index, embedding model) — only the planner wiring differs. The eval +calls `src.main.get_answer` directly with `is_test_mode=True` so no streaming +or markdown rendering happens. + +Usage: + python -m eval.run_eval # run both modes (default) + python -m eval.run_eval --baseline # run baseline only +""" +from __future__ import annotations + +import argparse +import csv +import json +import pathlib +import sys +from collections import defaultdict +from types import SimpleNamespace +from typing import Any, Dict, List, Optional, Tuple + +# Ensure repo root is on sys.path when invoked as a script +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from src.config import RAGConfig +from src.instrumentation.logging import get_logger +from src.main import ANSWER_NOT_FOUND, get_answer +from src.planning.composite import CompositeQueryPlanner +from src.planning.cost_model import CostModelPlanner +from src.planning.heuristics import HeuristicQueryPlanner +from src.planning.learned_classifier import PrototypeClassifier +from src.planning.multihop import MultiHopQueryPlanner +from src.planning.noop import NoOpPlanner +from src.planning.planner import QueryPlanner +from src.ranking.ranker import EnsembleRanker +from src.retriever import ( + BM25Retriever, + FAISSRetriever, + IndexKeywordRetriever, + load_artifacts, +) + + +INDEX_PREFIX = "textbook_index" +QUESTIONS_PATH = REPO_ROOT / "eval" / "questions.jsonl" +RESULTS_PATH = REPO_ROOT / "eval" / "results.csv" +CONFIG_PATH = REPO_ROOT / "config" / "config.yaml" + + +def build_args() -> SimpleNamespace: + # get_answer reads args.system_prompt_mode and (via getattr) args.double_prompt + return SimpleNamespace( + system_prompt_mode="baseline", + double_prompt=False, + index_prefix=INDEX_PREFIX, + ) + + +def build_artifacts(cfg: RAGConfig, planner: QueryPlanner) -> Dict[str, Any]: + """ + Mirror run_chat_session's artifact setup but swap in an arbitrary planner. + """ + artifacts_dir = cfg.get_artifacts_directory() + faiss_idx, bm25_idx, chunks, sources, meta = load_artifacts( + artifacts_dir, INDEX_PREFIX + ) + retrievers: List[Any] = [ + FAISSRetriever(faiss_idx, cfg.embed_model), + BM25Retriever(bm25_idx), + ] + if cfg.ranker_weights.get("index_keywords", 0) > 0: + retrievers.append( + IndexKeywordRetriever(cfg.extracted_index_path, cfg.page_to_chunk_map_path) + ) + ranker = EnsembleRanker( + ensemble_method=cfg.ensemble_method, + weights=cfg.ranker_weights, + rrf_k=int(cfg.rrf_k), + ) + return { + "chunks": chunks, + "sources": sources, + "retrievers": retrievers, + "ranker": ranker, + "meta": meta, + "planner": planner, + } + + +def load_questions() -> List[Dict[str, Any]]: + items: List[Dict[str, Any]] = [] + with open(QUESTIONS_PATH) as f: + for line in f: + line = line.strip() + if not line: + continue + items.append(json.loads(line)) + return items + + +def retrieval_hit(chunks_info: Optional[List[Dict[str, Any]]], expected: List[str]) -> int: + if not chunks_info or not expected: + return 0 + contents = [str(c.get("content", "")).lower() for c in chunks_info] + for needle in expected: + n = str(needle).lower() + if any(n in content for content in contents): + return 1 + return 0 + + +def answer_hit(answer_text: str, gold_fragment: str) -> int: + if not answer_text or not gold_fragment: + return 0 + return 1 if gold_fragment.lower() in answer_text.lower() else 0 + + +def run_one( + question: str, + cfg: RAGConfig, + artifacts: Dict[str, Any], + args: SimpleNamespace, + logger, +) -> Tuple[str, List[Dict[str, Any]]]: + """ + Call get_answer in test mode. Returns (answer_text, chunks_info). + Handles the ANSWER_NOT_FOUND early-return path where get_answer yields + a bare string instead of the tuple. + """ + result = get_answer( + question=question, + cfg=cfg, + args=args, + logger=logger, + console=None, + artifacts=artifacts, + is_test_mode=True, + ) + if isinstance(result, tuple): + ans, chunks_info, _ = result + return ans or "", chunks_info or [] + # ANSWER_NOT_FOUND path — no chunks retrieved + return str(result), [] + + +# Empirically derived from N=116 eval (eval/results_v2_n116.csv): +# categories where the optimizer beat baseline → composite; where baseline +# won → noop. Categories not in the table fall back to the default planner. +COST_MODEL_ROUTING = { + "keyword": "composite", + "definition": "composite", + "procedural": "composite", + "other": "composite", + "comparison": "noop", + "explanatory": "noop", +} + + +def _build_cost_model(cfg: RAGConfig) -> CostModelPlanner: + composite = CompositeQueryPlanner( + cfg, + [MultiHopQueryPlanner(cfg), HeuristicQueryPlanner(cfg)], + ) + noop = NoOpPlanner(cfg) + table = { + "composite": composite, + "noop": noop, + } + routing = {cat: table[choice] for cat, choice in COST_MODEL_ROUTING.items()} + return CostModelPlanner( + cfg, + routing_table=routing, + default_planner=composite, + classifier=HeuristicQueryPlanner(cfg), + ) + + +PROTOTYPES_PATH = REPO_ROOT / "config" / "category_prototypes.yaml" + + +def _load_prototypes() -> Dict[str, List[str]]: + import yaml + with open(PROTOTYPES_PATH) as f: + data = yaml.safe_load(f) + if not isinstance(data, dict): + raise ValueError(f"Expected mapping at top level of {PROTOTYPES_PATH}") + return {cat: list(exs) for cat, exs in data.items()} + + +def _build_cost_model_lc(cfg: RAGConfig, confidence_threshold: float = 0.5) -> CostModelPlanner: + """Cost model with PrototypeClassifier + confidence-based fallback to NoOp.""" + from src.embedder import SentenceTransformer + composite = CompositeQueryPlanner( + cfg, + [MultiHopQueryPlanner(cfg), HeuristicQueryPlanner(cfg)], + ) + noop = NoOpPlanner(cfg) + table = {"composite": composite, "noop": noop} + routing = {cat: table[choice] for cat, choice in COST_MODEL_ROUTING.items()} + embedder = SentenceTransformer( + model_path=cfg.embed_model, + n_ctx=cfg.embedding_model_context_window, + ) + classifier = PrototypeClassifier( + embedder=embedder, + prototypes=_load_prototypes(), + softmax_temp=0.1, + ) + return CostModelPlanner( + cfg, + routing_table=routing, + default_planner=composite, + classifier=classifier, + confidence_threshold=confidence_threshold, + fallback_planner=noop, + ) + + +def main() -> None: + global QUESTIONS_PATH, RESULTS_PATH + parser = argparse.ArgumentParser(description="TokenSmith planner evaluation") + parser.add_argument( + "--baseline", + action="store_true", + help="Include the no-op baseline planner", + ) + parser.add_argument( + "--optimizer", + action="store_true", + help="Include the CompositeQueryPlanner (multi-hop + heuristic)", + ) + parser.add_argument( + "--cost-model", + dest="cost_model", + action="store_true", + help="Include the CostModelPlanner with regex classifier", + ) + parser.add_argument( + "--cost-model-lc", + dest="cost_model_lc", + action="store_true", + help="Include the CostModelPlanner with PrototypeClassifier (learned) + confidence fallback", + ) + parser.add_argument( + "--questions", + dest="questions", + default=None, + help=f"Path to questions JSONL (default: {QUESTIONS_PATH})", + ) + parser.add_argument( + "--output", + dest="output", + default=None, + help=f"Path to write results CSV (default: {RESULTS_PATH})", + ) + parser.add_argument( + "--confidence-threshold", + dest="confidence_threshold", + type=float, + default=0.5, + help="Confidence threshold for learned-classifier fallback (default: 0.5)", + ) + cli = parser.parse_args() + + # Allow --questions / --output to override module-level defaults + if cli.questions: + QUESTIONS_PATH = pathlib.Path(cli.questions).resolve() + if cli.output: + RESULTS_PATH = pathlib.Path(cli.output).resolve() + + # If no flag is passed, run all four modes. Otherwise run only the + # explicitly requested ones. + any_flag = cli.baseline or cli.optimizer or cli.cost_model or cli.cost_model_lc + run_baseline = cli.baseline or not any_flag + run_optimizer = cli.optimizer or not any_flag + run_cost_model = cli.cost_model or not any_flag + run_cost_model_lc = cli.cost_model_lc or not any_flag + + if not CONFIG_PATH.exists(): + print(f"ERROR: missing config at {CONFIG_PATH}", file=sys.stderr) + sys.exit(1) + + cfg = RAGConfig.from_yaml(CONFIG_PATH) + logger = get_logger() + args = build_args() + + print(f"Loading artifacts from {cfg.get_artifacts_directory()} ...") + + baseline_artifacts: Optional[Dict[str, Any]] = None + optimizer_artifacts: Optional[Dict[str, Any]] = None + cost_model_artifacts: Optional[Dict[str, Any]] = None + cost_model_lc_artifacts: Optional[Dict[str, Any]] = None + if run_baseline: + baseline_artifacts = build_artifacts(cfg, NoOpPlanner(cfg)) + if run_optimizer: + composite = CompositeQueryPlanner( + cfg, + [MultiHopQueryPlanner(cfg), HeuristicQueryPlanner(cfg)], + ) + optimizer_artifacts = build_artifacts(cfg, composite) + if run_cost_model: + cost_model_artifacts = build_artifacts(cfg, _build_cost_model(cfg)) + if run_cost_model_lc: + cost_model_lc_artifacts = build_artifacts( + cfg, _build_cost_model_lc(cfg, cli.confidence_threshold) + ) + + questions = load_questions() + print(f"Loaded {len(questions)} questions from {QUESTIONS_PATH}") + + # Use a dedicated HeuristicQueryPlanner to label the CSV regardless of + # which modes are run. This keeps the "planner_classification" column + # stable and lets us diagnose misclassifications even on baseline-only + # runs. + label_planner = HeuristicQueryPlanner(cfg) + + rows: List[Dict[str, Any]] = [] + for i, q in enumerate(questions, 1): + query = q["query"] + category = q.get("category", "") + expected = q.get("expected_chunks", []) + gold = q.get("gold_answer_fragment", "") + + classification = label_planner.classify(query) + + b_retr: Any = "" + o_retr: Any = "" + c_retr: Any = "" + cl_retr: Any = "" + b_ans: Any = "" + o_ans: Any = "" + c_ans: Any = "" + cl_ans: Any = "" + cl_conf: Any = "" + cl_fallback: Any = "" + cl_predicted: Any = "" + + print(f"\n[{i}/{len(questions)}] ({category}) {query}") + + if run_baseline and baseline_artifacts is not None: + print(" -- baseline --") + ans_b, chunks_b = run_one(query, cfg, baseline_artifacts, args, logger) + b_retr = retrieval_hit(chunks_b, expected) + b_ans = answer_hit(ans_b, gold) + print(f" retrieval_hit={b_retr} answer_hit={b_ans}") + + if run_optimizer and optimizer_artifacts is not None: + print(" -- optimizer --") + ans_o, chunks_o = run_one(query, cfg, optimizer_artifacts, args, logger) + o_retr = retrieval_hit(chunks_o, expected) + o_ans = answer_hit(ans_o, gold) + print(f" retrieval_hit={o_retr} answer_hit={o_ans}") + + if run_cost_model and cost_model_artifacts is not None: + print(" -- cost_model (regex) --") + ans_c, chunks_c = run_one(query, cfg, cost_model_artifacts, args, logger) + c_retr = retrieval_hit(chunks_c, expected) + c_ans = answer_hit(ans_c, gold) + print(f" retrieval_hit={c_retr} answer_hit={c_ans}") + + if run_cost_model_lc and cost_model_lc_artifacts is not None: + print(" -- cost_model (learned) --") + ans_cl, chunks_cl = run_one(query, cfg, cost_model_lc_artifacts, args, logger) + cl_retr = retrieval_hit(chunks_cl, expected) + cl_ans = answer_hit(ans_cl, gold) + decision = cost_model_lc_artifacts["planner"].last_decision + cl_conf = decision.get("confidence", "") + cl_fallback = decision.get("fallback", "") + cl_predicted = decision.get("category", "") + print(f" retrieval_hit={cl_retr} answer_hit={cl_ans} " + f"conf={cl_conf} fallback={cl_fallback} predicted={cl_predicted}") + + rows.append({ + "query": query, + "category": category, + "planner_classification": classification, + "baseline_retrieval_hit": b_retr, + "optimizer_retrieval_hit": o_retr, + "cost_model_retrieval_hit": c_retr, + "cost_model_lc_retrieval_hit": cl_retr, + "baseline_answer_hit": b_ans, + "optimizer_answer_hit": o_ans, + "cost_model_answer_hit": c_ans, + "cost_model_lc_answer_hit": cl_ans, + "cost_model_lc_predicted": cl_predicted, + "cost_model_lc_confidence": cl_conf, + "cost_model_lc_fallback": cl_fallback, + }) + + write_results(rows) + print_summary(rows, run_baseline, run_optimizer, run_cost_model, run_cost_model_lc) + + +def write_results(rows: List[Dict[str, Any]]) -> None: + RESULTS_PATH.parent.mkdir(parents=True, exist_ok=True) + fieldnames = [ + "query", + "category", + "planner_classification", + "baseline_retrieval_hit", + "optimizer_retrieval_hit", + "cost_model_retrieval_hit", + "cost_model_lc_retrieval_hit", + "baseline_answer_hit", + "optimizer_answer_hit", + "cost_model_answer_hit", + "cost_model_lc_answer_hit", + "cost_model_lc_predicted", + "cost_model_lc_confidence", + "cost_model_lc_fallback", + ] + with open(RESULTS_PATH, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow(row) + print(f"\nWrote results to {RESULTS_PATH}") + + +def _rate(values: List[int]) -> float: + return (sum(values) / len(values)) if values else 0.0 + + +def print_summary( + rows: List[Dict[str, Any]], + run_baseline: bool, + run_optimizer: bool, + run_cost_model: bool, + run_cost_model_lc: bool = False, +) -> None: + buckets: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + for r in rows: + buckets[r.get("category") or "unknown"].append(r) + + cols = [f"{'category':<12}", f"{'n':>4}"] + if run_baseline: + cols += [f"{'B retr':>8}", f"{'B ans':>8}"] + if run_optimizer: + cols += [f"{'O retr':>8}", f"{'O ans':>8}"] + if run_cost_model: + cols += [f"{'C retr':>8}", f"{'C ans':>8}"] + if run_cost_model_lc: + cols += [f"{'CL retr':>8}", f"{'CL ans':>8}"] + header = " ".join(cols) + + print() + print("=" * len(header)) + print("Hit rates by category") + print("=" * len(header)) + print(header) + print("-" * len(header)) + + def fmt_row(label: str, items: List[Dict[str, Any]]) -> str: + parts = [f"{label:<12}", f"{len(items):>4}"] + if run_baseline: + br = _rate([int(x.get("baseline_retrieval_hit") or 0) for x in items]) + ba = _rate([int(x.get("baseline_answer_hit") or 0) for x in items]) + parts += [f"{br:>8.2%}", f"{ba:>8.2%}"] + if run_optimizer: + orr = _rate([int(x.get("optimizer_retrieval_hit") or 0) for x in items]) + oa = _rate([int(x.get("optimizer_answer_hit") or 0) for x in items]) + parts += [f"{orr:>8.2%}", f"{oa:>8.2%}"] + if run_cost_model: + cr = _rate([int(x.get("cost_model_retrieval_hit") or 0) for x in items]) + ca = _rate([int(x.get("cost_model_answer_hit") or 0) for x in items]) + parts += [f"{cr:>8.2%}", f"{ca:>8.2%}"] + if run_cost_model_lc: + clr = _rate([int(x.get("cost_model_lc_retrieval_hit") or 0) for x in items]) + cla = _rate([int(x.get("cost_model_lc_answer_hit") or 0) for x in items]) + parts += [f"{clr:>8.2%}", f"{cla:>8.2%}"] + return " ".join(parts) + + for cat in sorted(buckets): + print(fmt_row(cat, buckets[cat])) + print("-" * len(header)) + print(fmt_row("ALL", rows)) + + +if __name__ == "__main__": + main() diff --git a/eval/verify_cost_model.py b/eval/verify_cost_model.py new file mode 100644 index 00000000..227f0afc --- /dev/null +++ b/eval/verify_cost_model.py @@ -0,0 +1,118 @@ +""" +Verify Plan A wiring: per row, the cost-model column must equal whichever +sub-planner (baseline or optimizer) the routing table picked for that row's +planner_classification. Any mismatch indicates broken delegation in +CostModelPlanner. + +Usage: + python eval/verify_cost_model.py eval/results.csv +""" +import csv +import sys +from collections import Counter + +# Must match COST_MODEL_ROUTING in eval/run_eval.py +ROUTING_TABLE = { + "keyword": "optimizer", + "definition": "optimizer", + "procedural": "optimizer", + "other": "optimizer", + "comparison": "baseline", + "explanatory": "baseline", +} +DEFAULT_ROUTE = "optimizer" + + +def expected(row, metric): + cat = row["planner_classification"] + route = ROUTING_TABLE.get(cat, DEFAULT_ROUTE) + return int(row[f"{route}_{metric}"]) + + +def main(): + path = sys.argv[1] if len(sys.argv) > 1 else "eval/results.csv" + with open(path) as f: + rows = list(csv.DictReader(f)) + + needed = {"baseline_retrieval_hit", "optimizer_retrieval_hit", + "cost_model_retrieval_hit", "baseline_answer_hit", + "optimizer_answer_hit", "cost_model_answer_hit", + "planner_classification"} + missing = needed - set(rows[0].keys()) + if missing: + print(f"FAIL: CSV missing columns: {sorted(missing)}") + sys.exit(2) + + # Skip rows where any required column is blank — happens if the eval + # was run with only one or two of the three modes enabled. + skipped = 0 + checked = 0 + retr_mismatch = [] + ans_mismatch = [] + route_counts = Counter() + + for i, r in enumerate(rows, 1): + if any(r.get(c) in ("", None) for c in needed): + skipped += 1 + continue + checked += 1 + cat = r["planner_classification"] + route = ROUTING_TABLE.get(cat, DEFAULT_ROUTE) + route_counts[route] += 1 + + exp_retr = expected(r, "retrieval_hit") + act_retr = int(r["cost_model_retrieval_hit"]) + if exp_retr != act_retr: + retr_mismatch.append((i, r["query"][:60], cat, route, exp_retr, act_retr)) + + exp_ans = expected(r, "answer_hit") + act_ans = int(r["cost_model_answer_hit"]) + if exp_ans != act_ans: + ans_mismatch.append((i, r["query"][:60], cat, route, exp_ans, act_ans)) + + print(f"=== verify_cost_model: {path} ===") + print(f" rows total: {len(rows)}") + print(f" rows checked: {checked}") + print(f" rows skipped: {skipped} (blank columns — partial-mode run)") + print(f" routes used: {dict(route_counts)}") + print() + print(f" retrieval_hit mismatches: {len(retr_mismatch)} / {checked}") + for i, q, cat, route, e, a in retr_mismatch[:10]: + print(f" row {i:>3} [{cat:>11s}→{route:>9s}] expected={e} actual={a} {q}") + if len(retr_mismatch) > 10: + print(f" ... and {len(retr_mismatch) - 10} more") + + print(f" answer_hit mismatches: {len(ans_mismatch)} / {checked}") + for i, q, cat, route, e, a in ans_mismatch[:10]: + print(f" row {i:>3} [{cat:>11s}→{route:>9s}] expected={e} actual={a} {q}") + if len(ans_mismatch) > 10: + print(f" ... and {len(ans_mismatch) - 10} more") + + print() + # Retrieval is deterministic — any mismatch is a real wiring bug. + # Answer-hit depends on LLM generation, which is non-deterministic on + # CPU llama.cpp even at temperature=0 (KV cache, OMP threading, fp + # accumulation). Treat answer-hit mismatches as informational unless + # they show directional bias. + if retr_mismatch: + print("RESULT: FAIL — retrieval mismatches indicate broken delegation.") + print(" Retrieval is deterministic; if cost-model doesn't match") + print(" the routed planner here, something is wired wrong.") + sys.exit(1) + + if ans_mismatch: + # Direction check: count cost-model wins vs losses against the routed planner. + wins = sum(1 for _, _, _, _, e, a in ans_mismatch if a > e) + losses = sum(1 for _, _, _, _, e, a in ans_mismatch if a < e) + print(f"RESULT: PASS (with LLM noise) — retrieval delegation is correct.") + print(f" Answer-hit jitter: {wins} cost-model wins, {losses} losses") + print(f" out of {checked} rows. Net effect on aggregate is small.") + print(f" Expected on CPU llama.cpp; not a wiring bug.") + sys.exit(0) + + print("RESULT: PASS — cost-model column matches routing rule on every row.") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/src/planning/cost_model.py b/src/planning/cost_model.py new file mode 100644 index 00000000..494ba7d0 --- /dev/null +++ b/src/planning/cost_model.py @@ -0,0 +1,123 @@ +""" +Cost-Model Query Planner +------------------------ +Routes each query to a sub-planner based on a per-category routing table. +The category is produced by an inner classifier (regex or learned). + +The routing table is empirically derived: for each category we use whichever +sub-planner had the higher answer-hit rate on the eval set. Categories not in +the table fall back to the configured default sub-planner. + +The classifier just needs to expose `classify_query(query) -> Classification`. +HeuristicQueryPlanner provides this via an adapter; PrototypeClassifier +implements it natively. This decouples routing from classifier shape so the +125% learned classifier swaps in without other code changes. + +Both `plan()` and `expand_queries()` are delegated to the chosen sub-planner +so multi-query expansion (e.g. multi-hop sub-questions) still happens on +routes that go to a composite planner. + +Confidence fallback (Phase 4 of 125% goal): +If the classifier's confidence on the top category is below +`confidence_threshold`, route to `fallback_planner` (defaults to NoOp via the +default_planner) regardless of the routing table. This trades incremental +gain for a guarantee that uncertain queries are never sent to a planner that +might hurt them — i.e., the cost model becomes monotonically non-worse than +baseline by construction. +""" +from __future__ import annotations + +from typing import Dict, List, Optional + +from src.config import RAGConfig +from src.planning.heuristics import HeuristicQueryPlanner +from src.planning.planner import QueryPlanner + + +class CostModelPlanner(QueryPlanner): + def __init__( + self, + base_cfg: RAGConfig, + routing_table: Dict[str, QueryPlanner], + default_planner: QueryPlanner, + classifier=None, + confidence_threshold: float = 0.0, + fallback_planner: Optional[QueryPlanner] = None, + ): + super().__init__(base_cfg) + if not routing_table: + raise ValueError("CostModelPlanner requires a non-empty routing_table.") + self.routing_table = dict(routing_table) + self.default_planner = default_planner + self.classifier = classifier if classifier is not None else HeuristicQueryPlanner(base_cfg) + self.confidence_threshold = float(confidence_threshold) + # If no explicit fallback planner is passed, fall back to default. + # In typical wiring `default_planner` is the composite optimizer; for + # safe-mode fallback the caller should pass NoOp explicitly. + self.fallback_planner = fallback_planner if fallback_planner is not None else default_planner + # Last-decision trace for observability + self.last_decision: Dict[str, str] = {} + + @property + def name(self) -> str: + routes = ",".join( + f"{cat}->{p.name}" for cat, p in sorted(self.routing_table.items()) + ) + suffix = "" + if self.confidence_threshold > 0: + suffix = f"|cf={self.confidence_threshold:.2f}->{self.fallback_planner.name}" + return f"CostModel[{routes}|default={self.default_planner.name}{suffix}]" + + def _classify(self, query: str): + """ + Adapter: support either classifier shape. + - Native protocol: classifier has `classify_query` returning Classification + - Legacy: classifier has `classify` returning a string (treated as confidence=1.0) + """ + if hasattr(self.classifier, "classify_query"): + return self.classifier.classify_query(query) + # Legacy fallback — wrap the string return so the rest of the planner + # only deals with the Classification shape. + from src.planning.learned_classifier import Classification + cat = self.classifier.classify(query) + return Classification(category=cat, confidence=1.0, all_scores={cat: 1.0}) + + def _route(self, query: str) -> tuple[str, QueryPlanner]: + result = self._classify(query) + category = result.category + confidence = float(result.confidence) + + if self.confidence_threshold > 0 and confidence < self.confidence_threshold: + chosen = self.fallback_planner + fallback = True + in_table = False + else: + chosen = self.routing_table.get(category, self.default_planner) + fallback = False + in_table = category in self.routing_table + + self.last_decision = { + "category": category, + "confidence": f"{confidence:.4f}", + "chosen": chosen.name, + "in_table": str(in_table), + "fallback": str(fallback), + } + return category, chosen + + def plan(self, query: str) -> RAGConfig: + category, chosen = self._route(query) + decision = self.last_decision + print( + f"[PLANNER] CostModelPlanner: category={category} " + f"conf={decision['confidence']} fallback={decision['fallback']} " + f"-> {chosen.name}" + ) + return chosen.plan(query) + + def expand_queries(self, query: str) -> List[str]: + # Re-routing is cheap (one classifier call) so it's fine to do + # unconditionally — keeps this method side-effect-correct when + # called before plan(). + _, chosen = self._route(query) + return chosen.expand_queries(query) diff --git a/src/planning/heuristics.py b/src/planning/heuristics.py index 94f78c72..f2e1350d 100644 --- a/src/planning/heuristics.py +++ b/src/planning/heuristics.py @@ -1,6 +1,8 @@ +import re from src.config import RAGConfig from copy import deepcopy +from src.planning.learned_classifier import Classification from src.planning.planner import QueryPlanner """ @@ -8,13 +10,26 @@ ----------------------- TODO: verify below assertions with data - Different query types have different needs: - • Definition queries → usually short answers, need fine-grained chunks (small tokens), + • Definition queries → usually short answers, need fine-grained chunks (small tokens), benefit from keyword match (BM25). - • Explanatory queries → broader answers, need larger spans (sections), + • Explanatory queries → broader answers, need larger spans (sections), benefit from semantic similarity (FAISS). - • Procedural queries (how-to, steps) → benefit from wider candidate pools and tag overlap, + • Procedural queries (how-to, steps) → benefit from wider candidate pools and tag overlap, since relevant steps may be scattered. + • Keyword queries (acronym-heavy, e.g. ACID / WAL / MVCC) → dominated by exact + token matches, benefit strongly from BM25. """ + +# All-caps 2-4 char tokens. Matches ACID, WAL, MVCC, RDBMS, SQL, etc. +_ACRONYM_PATTERN = re.compile(r"\b[A-Z]{2,4}\b") + +# Explanatory triggers: "why ...", "how does/do/is/are ...", "what causes ...", +# "explain ...". Anchored to the start of the lowercased query so that +# unrelated queries that merely contain the word "explain" in the middle +# don't get reclassified. +_EXPLANATORY_PATTERN = re.compile(r"^(why|how\s+(does|do|is|are)|what\s+causes|explain)\b") + + class HeuristicQueryPlanner(QueryPlanner): @property def name(self) -> str: @@ -25,32 +40,56 @@ def __init__(self, base_cfg: RAGConfig): self.base_cfg = deepcopy(base_cfg) def classify(self, query: str) -> str: + # Acronym check runs on the original casing, before lowercasing. + # It takes priority because queries like "what is ACID?" should be + # routed to BM25-heavy retrieval instead of the generic definition + # path. + if _ACRONYM_PATTERN.search(query): + return "keyword" q = query.lower() + if any(x in q for x in ["compare", "comparison", "difference between", "differences between", "vs", "versus", "contrast"]): + return "comparison" if any(x in q for x in ["what is", "define", "definition"]): return "definition" - if any(x in q for x in ["why", "explain", "because"]): - return "explanatory" if any(x in q for x in ["how to", "steps", "procedure", "algorithm"]): return "procedural" + if _EXPLANATORY_PATTERN.match(q) or "because" in q: + return "explanatory" return "other" + def classify_query(self, query: str) -> Classification: + """ + Classifier-protocol adapter. Returns a Classification with the regex + outcome and confidence=1.0 — the regex either matches a category or + falls through; there's no graded score. + """ + cat = self.classify(query) + return Classification(category=cat, confidence=1.0, all_scores={cat: 1.0}) + def plan(self, query: str) -> RAGConfig: kind = self.classify(query) cfg = deepcopy(self.base_cfg) - if kind == "definition": + if kind == "keyword": + cfg.ranker_weights = {"faiss": 0.1, "bm25": 0.9} + + elif kind == "comparison": + cfg.ranker_weights = {"faiss": 0.2, "bm25": 0.8} + + elif kind == "definition": cfg.ranker_weights = {"faiss": 0.3, "bm25": 0.7} elif kind == "explanatory": cfg.ranker_weights = {"faiss": 0.7, "bm25": 0.3} elif kind == "procedural": - cfg.pool_size = max(cfg.pool_size, cfg.top_k * 5) + cfg.num_candidates = max(cfg.num_candidates, cfg.top_k * 5) cfg.ranker_weights = {"faiss": 0.6, "bm25": 0.4} else: print("Unknown query type. Defaulting to explanatory.") cfg.ranker_weights = {"faiss": 0.7, "bm25": 0.3} + print(f"[PLANNER] HeuristicQueryPlanner: classified as {kind}, weights -> {cfg.ranker_weights}") self._log_decision(cfg) return cfg diff --git a/src/planning/learned_classifier.py b/src/planning/learned_classifier.py new file mode 100644 index 00000000..37790f0c --- /dev/null +++ b/src/planning/learned_classifier.py @@ -0,0 +1,99 @@ +""" +Learned Query Classifier +------------------------ +Zero-shot prototype classifier: embed a set of exemplar queries per category, +mean-pool to form a prototype vector, then at inference embed the query and +softmax cosine similarities to produce calibrated category probabilities. + +No training needed — uses the same embedder already loaded for retrieval. + +Used by CostModelPlanner as a drop-in alternative to the regex-based +HeuristicQueryPlanner. Both expose `classify_query(query) -> Classification` +so the cost-model doesn't care which one it gets. +""" +from __future__ import annotations + +import math +from dataclasses import dataclass, field +from typing import Dict, List, Protocol + +import numpy as np + + +@dataclass +class Classification: + category: str + confidence: float # softmax probability of top class + all_scores: Dict[str, float] = field(default_factory=dict) + + +class Classifier(Protocol): + """Minimal interface CostModelPlanner relies on.""" + + def classify_query(self, query: str) -> Classification: ... + + +class PrototypeClassifier: + """ + Zero-shot prototype classifier via cosine similarity over query embeddings. + + Construction is O(num_categories * num_exemplars) embedder calls; classify + is one embedder call + a few dot products. The embedder is reused — pass + in the same SentenceTransformer instance the retrievers use. + """ + + def __init__( + self, + embedder, + prototypes: Dict[str, List[str]], + softmax_temp: float = 0.1, + ): + if not prototypes: + raise ValueError("PrototypeClassifier requires non-empty prototypes") + if softmax_temp <= 0: + raise ValueError("softmax_temp must be > 0") + self.embedder = embedder + self.softmax_temp = float(softmax_temp) + self.prototypes = {cat: list(exs) for cat, exs in prototypes.items()} + self.proto_embs: Dict[str, np.ndarray] = self._build_prototypes() + + def _build_prototypes(self) -> Dict[str, np.ndarray]: + out: Dict[str, np.ndarray] = {} + for cat, exemplars in self.prototypes.items(): + if not exemplars: + raise ValueError(f"Category {cat!r} has no exemplars") + embs = self.embedder.encode(list(exemplars), normalize=True) + embs = np.asarray(embs, dtype=np.float32) + mean = embs.mean(axis=0) + norm = float(np.linalg.norm(mean)) + if norm == 0.0: + raise ValueError(f"Zero-norm prototype for category {cat!r}") + out[cat] = mean / norm + return out + + def classify_query(self, query: str) -> Classification: + if not query or not query.strip(): + # Degenerate input — pick the first category at confidence 1/N. + n = len(self.proto_embs) + cat = next(iter(self.proto_embs)) + return Classification( + category=cat, + confidence=1.0 / n, + all_scores={c: 1.0 / n for c in self.proto_embs}, + ) + q_arr = self.embedder.encode([query], normalize=True) + q = np.asarray(q_arr, dtype=np.float32)[0] + # Re-normalize defensively in case the embedder didn't honor normalize=True. + nq = float(np.linalg.norm(q)) + if nq > 0: + q = q / nq + sims = {cat: float(q @ emb) for cat, emb in self.proto_embs.items()} + probs = self._softmax(sims) + top = max(probs, key=probs.get) + return Classification(category=top, confidence=probs[top], all_scores=probs) + + def _softmax(self, sims: Dict[str, float]) -> Dict[str, float]: + max_s = max(sims.values()) + exps = {cat: math.exp((s - max_s) / self.softmax_temp) for cat, s in sims.items()} + total = sum(exps.values()) + return {cat: e / total for cat, e in exps.items()} diff --git a/src/planning/multihop.py b/src/planning/multihop.py new file mode 100644 index 00000000..bcc7f85d --- /dev/null +++ b/src/planning/multihop.py @@ -0,0 +1,116 @@ +""" +Multi-hop Query Planner +----------------------- +Detects queries that ask about multiple concepts at once ("compare X and Y", +"what is A and why does B happen") and decomposes them into sub-questions. +Each sub-question is retrieved for independently; the callers in the pipeline +merge and deduplicate the resulting chunks before ranking. + +The planner still participates in the `QueryPlanner` cfg-mutation contract: +for detected multi-hop queries it widens the candidate pool so that merging +has enough headroom. The actual sub-question expansion is exposed via the +`expand_queries` hook on the base class. +""" +from __future__ import annotations + +import re +from copy import deepcopy +from typing import List + +from src.config import RAGConfig +from src.planning.planner import QueryPlanner +from src.query_enhancement import decompose_complex_query + + +_MULTIHOP_PATTERNS = [ + re.compile(r"\bcompare\b", re.IGNORECASE), + re.compile(r"\bcontrast\b", re.IGNORECASE), + re.compile(r"\bdifference(s)?\s+between\b", re.IGNORECASE), + re.compile(r"\bvs\.?\b", re.IGNORECASE), + re.compile(r"\bversus\b", re.IGNORECASE), + re.compile(r"\bboth\b", re.IGNORECASE), + re.compile(r"\beach\s+of\b", re.IGNORECASE), + re.compile(r"\band\s+(also|how|why|what|when|where)\b", re.IGNORECASE), +] + + +class MultiHopQueryPlanner(QueryPlanner): + @property + def name(self) -> str: + return "MultiHopPlanner" + + def __init__(self, base_cfg: RAGConfig, max_subquestions: int = 3): + super().__init__(base_cfg) + # RRF dilutes over too many retrieval sets — cap sub-question count + # so comparison queries don't get starved. + self.max_subquestions = max(1, int(max_subquestions)) + # Cache decompositions so plan() and expand_queries() don't call the + # LLM twice for the same question within a single turn. + self._decomposition_cache: dict[str, List[str]] = {} + + def _looks_multihop(self, query: str) -> bool: + if any(p.search(query) for p in _MULTIHOP_PATTERNS): + return True + # Two or more question marks => multiple direct questions joined. + if query.count("?") >= 2: + return True + return False + + def _decompose(self, query: str) -> List[str]: + if query in self._decomposition_cache: + return self._decomposition_cache[query] + + try: + raw = decompose_complex_query(query, self.base_cfg.gen_model) + except Exception: + raw = [query] + + seen: set[str] = set() + subs: List[str] = [] + for item in raw: + if not isinstance(item, str): + continue + cleaned = item.strip() + if not cleaned: + continue + key = cleaned.lower() + if key in seen: + continue + seen.add(key) + subs.append(cleaned) + + # Truncate post-hoc: the LLM frequently ignores the "at most N" + # instruction, so we enforce the cap on our side. + subs = subs[: self.max_subquestions] + + # If the LLM collapsed to nothing useful, fall back to the original. + if not subs or (len(subs) == 1 and subs[0].lower() == query.lower()): + subs = [query] + + self._decomposition_cache[query] = subs + return subs + + def plan(self, query: str) -> RAGConfig: + cfg = deepcopy(self.base_cfg) + + is_multihop = self._looks_multihop(query) + if is_multihop: + subs = self._decompose(query) + if len(subs) > 1: + # Widen the per-retriever pool so that each sub-question + # contributes candidates without starving the others. + cfg.num_candidates = max( + cfg.num_candidates, + cfg.top_k * max(4, len(subs) * 2), + ) + else: + subs = [query] + + print(f"[PLANNER] MultiHopQueryPlanner: multihop={is_multihop}, sub_questions={subs}") + self._log_decision(cfg) + return cfg + + def expand_queries(self, query: str) -> List[str]: + if not self._looks_multihop(query): + return [query] + return self._decompose(query) diff --git a/src/planning/noop.py b/src/planning/noop.py new file mode 100644 index 00000000..2366769c --- /dev/null +++ b/src/planning/noop.py @@ -0,0 +1,21 @@ +""" +No-op planner: returns the base cfg unchanged and never expands the query. +Useful as a baseline reference and as a building block for higher-level +planners (e.g. CostModelPlanner) that want to short-circuit to "do nothing" +on certain query categories. +""" +from __future__ import annotations + +from copy import deepcopy + +from src.config import RAGConfig +from src.planning.planner import QueryPlanner + + +class NoOpPlanner(QueryPlanner): + @property + def name(self) -> str: + return "NoOpPlanner" + + def plan(self, query: str) -> RAGConfig: + return deepcopy(self.base_cfg) diff --git a/src/query_enhancement.py b/src/query_enhancement.py index 4b0705a1..72f0350d 100644 --- a/src/query_enhancement.py +++ b/src/query_enhancement.py @@ -138,6 +138,7 @@ def decompose_complex_query( prompt = textwrap.dedent(f"""\ <|im_start|>system Break the following complex question into simple, single-step sub-questions. + Generate at most 3 sub-questions. Fewer is better if the query is simple. If the question is already simple, just output the original question. Output each sub-question on a new line. Do not provide explanations. <|im_end|> diff --git a/tests/test_cost_model_planner.py b/tests/test_cost_model_planner.py new file mode 100644 index 00000000..16fbedfc --- /dev/null +++ b/tests/test_cost_model_planner.py @@ -0,0 +1,264 @@ +"""Unit tests for CostModelPlanner routing. + +Stubs the sub-planners and the classifier so the test exercises only the +routing logic, not retrieval or LLM calls. +""" +from __future__ import annotations + +from typing import List + +import pytest + +from src.config import RAGConfig +from src.planning.cost_model import CostModelPlanner +from src.planning.heuristics import HeuristicQueryPlanner +from src.planning.planner import QueryPlanner + + +class _StubPlanner(QueryPlanner): + """Records calls; returns the cfg with a tag injected so tests can detect + which sub-planner produced the result.""" + + def __init__(self, base_cfg: RAGConfig, tag: str, expansion: List[str] | None = None): + super().__init__(base_cfg) + self.tag = tag + self._expansion = expansion or [f"{tag}_query"] + self.plan_calls: List[str] = [] + self.expand_calls: List[str] = [] + + @property + def name(self) -> str: + return f"Stub[{self.tag}]" + + def plan(self, query: str) -> RAGConfig: + self.plan_calls.append(query) + cfg = RAGConfig() + # Use ranker_weights as a side channel to identify the planner — the + # field already exists on RAGConfig so we don't need to extend it. + cfg.ranker_weights = {self.tag: 1.0} + return cfg + + def expand_queries(self, query: str) -> List[str]: + self.expand_calls.append(query) + return self._expansion + + +@pytest.fixture() +def base_cfg() -> RAGConfig: + return RAGConfig() + + +@pytest.fixture() +def stubs(base_cfg): + return { + "composite": _StubPlanner(base_cfg, "composite", expansion=["sub1", "sub2"]), + "noop": _StubPlanner(base_cfg, "noop"), + } + + +@pytest.fixture() +def planner(base_cfg, stubs): + routing = { + "keyword": stubs["composite"], + "definition": stubs["composite"], + "procedural": stubs["composite"], + "other": stubs["composite"], + "comparison": stubs["noop"], + "explanatory": stubs["noop"], + } + return CostModelPlanner( + base_cfg, + routing_table=routing, + default_planner=stubs["composite"], + classifier=HeuristicQueryPlanner(base_cfg), + ) + + +def test_comparison_routes_to_noop(planner, stubs): + cfg = planner.plan("Compare clustered and non-clustered indexes") + assert "noop" in cfg.ranker_weights + assert stubs["noop"].plan_calls == ["Compare clustered and non-clustered indexes"] + assert stubs["composite"].plan_calls == [] + + +def test_explanatory_routes_to_noop(planner, stubs): + # Avoid acronyms — the heuristic's acronym check fires before the + # explanatory check, so a query like "Why is BCNF..." routes as keyword. + query = "Why does write-ahead logging avoid losing committed transactions" + cfg = planner.plan(query) + assert "noop" in cfg.ranker_weights + assert stubs["noop"].plan_calls == [query] + + +def test_keyword_routes_to_composite(planner, stubs): + cfg = planner.plan("What is ACID?") + assert "composite" in cfg.ranker_weights + assert stubs["composite"].plan_calls == ["What is ACID?"] + assert stubs["noop"].plan_calls == [] + + +def test_definition_routes_to_composite(planner, stubs): + cfg = planner.plan("What is a foreign key") + assert "composite" in cfg.ranker_weights + + +def test_procedural_routes_to_composite(planner, stubs): + cfg = planner.plan("How to perform two-phase commit") + assert "composite" in cfg.ranker_weights + + +def test_unknown_category_uses_default(base_cfg, stubs): + """If the routing table doesn't list a category, fall back to default.""" + fallback = _StubPlanner(base_cfg, "fallback") + planner = CostModelPlanner( + base_cfg, + routing_table={"keyword": stubs["composite"]}, # only one entry + default_planner=fallback, + classifier=HeuristicQueryPlanner(base_cfg), + ) + cfg = planner.plan("just some random sentence") # → "other" + assert "fallback" in cfg.ranker_weights + assert fallback.plan_calls == ["just some random sentence"] + + +def test_expand_queries_delegated_to_chosen_planner(planner, stubs): + # Multi-hop-style expansion should pass through when routed to composite. + expanded = planner.expand_queries("What is ACID?") + assert expanded == ["sub1", "sub2"] + assert stubs["composite"].expand_calls == ["What is ACID?"] + assert stubs["noop"].expand_calls == [] + + +def test_expand_queries_uses_noop_for_routed_categories(planner, stubs): + # Comparison routes to noop; noop returns its own (single-element) expansion. + expanded = planner.expand_queries("Compare X and Y") + assert expanded == ["noop_query"] + assert stubs["noop"].expand_calls == ["Compare X and Y"] + + +def test_last_decision_records_route(planner): + planner.plan("What is ACID?") + assert planner.last_decision["category"] == "keyword" + assert "composite" in planner.last_decision["chosen"] + assert planner.last_decision["in_table"] == "True" + + +def test_empty_routing_table_rejected(base_cfg, stubs): + with pytest.raises(ValueError): + CostModelPlanner( + base_cfg, + routing_table={}, + default_planner=stubs["composite"], + ) + + +# --------------------------------------------------------------------------- +# Phase 4: confidence-based fallback +# --------------------------------------------------------------------------- + + +class _StubClassifier: + """Returns a Classification with a configurable category + confidence.""" + + def __init__(self, category: str, confidence: float): + self.category = category + self.confidence = confidence + self.calls: List[str] = [] + + def classify_query(self, query: str): + from src.planning.learned_classifier import Classification + self.calls.append(query) + return Classification( + category=self.category, + confidence=self.confidence, + all_scores={self.category: self.confidence}, + ) + + +def _make_planner(base_cfg, stubs, classifier, *, threshold, fallback=None): + routing = { + "keyword": stubs["composite"], + "comparison": stubs["noop"], + } + return CostModelPlanner( + base_cfg, + routing_table=routing, + default_planner=stubs["composite"], + classifier=classifier, + confidence_threshold=threshold, + fallback_planner=fallback, + ) + + +def test_high_confidence_uses_routing_table(base_cfg, stubs): + clf = _StubClassifier("keyword", confidence=0.95) + fallback = _StubPlanner(base_cfg, "fallback") + planner = _make_planner(base_cfg, stubs, clf, threshold=0.5, fallback=fallback) + + cfg = planner.plan("query") + assert "composite" in cfg.ranker_weights + assert stubs["composite"].plan_calls == ["query"] + assert fallback.plan_calls == [] + + +def test_low_confidence_routes_to_fallback(base_cfg, stubs): + clf = _StubClassifier("keyword", confidence=0.2) + fallback = _StubPlanner(base_cfg, "fallback") + planner = _make_planner(base_cfg, stubs, clf, threshold=0.5, fallback=fallback) + + cfg = planner.plan("query") + assert "fallback" in cfg.ranker_weights + assert fallback.plan_calls == ["query"] + assert stubs["composite"].plan_calls == [] + assert stubs["noop"].plan_calls == [] + + +def test_confidence_threshold_zero_disables_fallback(base_cfg, stubs): + """Default threshold=0 means fallback never triggers — backward compat.""" + clf = _StubClassifier("comparison", confidence=0.01) # very low + planner = _make_planner(base_cfg, stubs, clf, threshold=0.0) + + cfg = planner.plan("query") + # Should route via table (comparison -> noop), NOT fallback + assert "noop" in cfg.ranker_weights + assert stubs["noop"].plan_calls == ["query"] + + +def test_last_decision_includes_confidence_and_fallback(base_cfg, stubs): + clf = _StubClassifier("keyword", confidence=0.3) + fallback = _StubPlanner(base_cfg, "fallback") + planner = _make_planner(base_cfg, stubs, clf, threshold=0.5, fallback=fallback) + + planner.plan("query") + assert planner.last_decision["category"] == "keyword" + assert planner.last_decision["confidence"] == "0.3000" + assert planner.last_decision["fallback"] == "True" + assert "fallback" in planner.last_decision["chosen"] + + +def test_fallback_planner_defaults_to_default_planner(base_cfg, stubs): + """If no fallback_planner is given, low confidence routes to default.""" + clf = _StubClassifier("keyword", confidence=0.1) + planner = _make_planner(base_cfg, stubs, clf, threshold=0.5) # no fallback arg + + cfg = planner.plan("query") + # default_planner is stubs["composite"] in _make_planner + assert "composite" in cfg.ranker_weights + + +def test_legacy_classifier_still_works(base_cfg, stubs): + """Plain HeuristicQueryPlanner (no classify_query) should still work + via the duck-typed adapter in _classify.""" + + class LegacyClassifier: + def classify(self, query: str) -> str: + return "keyword" + + planner = CostModelPlanner( + base_cfg, + routing_table={"keyword": stubs["composite"]}, + default_planner=stubs["noop"], + classifier=LegacyClassifier(), + ) + cfg = planner.plan("query") + assert "composite" in cfg.ranker_weights diff --git a/tests/test_learned_classifier.py b/tests/test_learned_classifier.py new file mode 100644 index 00000000..957e3ce3 --- /dev/null +++ b/tests/test_learned_classifier.py @@ -0,0 +1,154 @@ +"""Unit tests for PrototypeClassifier. + +Stubs the embedder so the test never loads the real Qwen3 GGUF model. +The stub maps each text to a deterministic vector based on its content, +which is enough to exercise the prototype-build / softmax / argmax paths. +""" +from __future__ import annotations + +from typing import List, Sequence, Union + +import numpy as np +import pytest + +from src.planning.learned_classifier import Classification, PrototypeClassifier + + +# Map each "category tag" (a substring inside the exemplar) to a unit basis +# vector. This makes the cosine similarity exactly 1.0 between query and +# matching prototype, and exactly 0.0 between query and non-matching. +_CATEGORY_TAGS = ["alpha", "beta", "gamma", "delta"] + + +class _StubEmbedder: + """Deterministic embedder: tags inside the text pick a basis vector. + + Each text contributes one basis vector per matching tag (averaged if + multiple). If no tag is found, returns the zero vector — which lets us + test the degenerate case. + """ + + def __init__(self): + self.dim = len(_CATEGORY_TAGS) + self.basis = {tag: np.eye(self.dim)[i] for i, tag in enumerate(_CATEGORY_TAGS)} + self.calls: List[str] = [] + + def encode( + self, + texts: Union[str, Sequence[str]], + normalize: bool = False, + **kwargs, + ) -> np.ndarray: + if isinstance(texts, str): + texts = [texts] + out = np.zeros((len(texts), self.dim), dtype=np.float32) + for i, t in enumerate(texts): + self.calls.append(t) + tags = [tag for tag in _CATEGORY_TAGS if tag in t.lower()] + if not tags: + continue + v = np.mean([self.basis[tag] for tag in tags], axis=0) + n = float(np.linalg.norm(v)) + if normalize and n > 0: + v = v / n + out[i] = v + return out + + +@pytest.fixture() +def embedder() -> _StubEmbedder: + return _StubEmbedder() + + +@pytest.fixture() +def prototypes() -> dict: + # Two exemplars per category, each tagged unambiguously. + return { + "alpha": ["alpha example one", "alpha example two"], + "beta": ["beta example one", "beta example two"], + "gamma": ["gamma example one", "gamma example two"], + } + + +@pytest.fixture() +def clf(embedder, prototypes) -> PrototypeClassifier: + return PrototypeClassifier(embedder, prototypes, softmax_temp=0.1) + + +def test_prototypes_built_at_construction(embedder, prototypes): + PrototypeClassifier(embedder, prototypes) + # Should have embedded each exemplar exactly once. + assert len(embedder.calls) == sum(len(v) for v in prototypes.values()) + + +def test_classify_returns_correct_category(clf): + result = clf.classify_query("alpha alpha test query") + assert isinstance(result, Classification) + assert result.category == "alpha" + + +def test_confidence_bounded_between_zero_and_one(clf): + result = clf.classify_query("beta beta test query") + assert 0.0 < result.confidence <= 1.0 + + +def test_all_scores_sum_to_one(clf): + result = clf.classify_query("gamma test") + total = sum(result.all_scores.values()) + assert pytest.approx(total, abs=1e-5) == 1.0 + + +def test_strongest_signal_dominates(clf): + # With softmax_temp=0.1, a dot product of 1.0 vs 0.0 should produce a + # near-degenerate distribution heavily favoring the matching prototype. + result = clf.classify_query("alpha") + assert result.confidence > 0.95 + + +def test_ambiguous_input_yields_low_confidence(clf): + # "alpha beta" matches two prototypes equally — confidence should be + # closer to 0.5 than to 1.0. + result = clf.classify_query("alpha beta") + assert 0.3 < result.confidence < 0.7 + + +def test_empty_query_returns_uniform(clf): + result = clf.classify_query("") + n = 3 # three categories in fixture + assert pytest.approx(result.confidence, abs=1e-5) == 1.0 / n + assert sorted(result.all_scores.keys()) == ["alpha", "beta", "gamma"] + + +def test_zero_norm_prototype_rejected(embedder): + # If all exemplars for a category have no tags, the mean is zero. + bad_prototypes = { + "alpha": ["alpha example"], + "blank": ["nothing matches here"], + } + with pytest.raises(ValueError, match="Zero-norm prototype"): + PrototypeClassifier(embedder, bad_prototypes) + + +def test_empty_prototypes_rejected(embedder): + with pytest.raises(ValueError): + PrototypeClassifier(embedder, {}) + + +def test_empty_exemplar_list_rejected(embedder): + with pytest.raises(ValueError): + PrototypeClassifier(embedder, {"alpha": []}) + + +def test_invalid_temp_rejected(embedder, prototypes): + with pytest.raises(ValueError): + PrototypeClassifier(embedder, prototypes, softmax_temp=0.0) + with pytest.raises(ValueError): + PrototypeClassifier(embedder, prototypes, softmax_temp=-1.0) + + +def test_classify_does_not_mutate_state(clf): + # Calling classify_query twice with the same input should produce the + # same result — no hidden state should accumulate. + a = clf.classify_query("alpha test") + b = clf.classify_query("alpha test") + assert a == b diff --git a/tests/test_multihop_planner.py b/tests/test_multihop_planner.py new file mode 100644 index 00000000..7ccda261 --- /dev/null +++ b/tests/test_multihop_planner.py @@ -0,0 +1,59 @@ +"""Unit tests for MultiHopQueryPlanner decomposition. + +Uses monkeypatch to replace the LLM-backed decompose call with a stub so +these tests stay artifact-free and deterministic. +""" +from __future__ import annotations + +import pytest + +from src.config import RAGConfig +from src.planning.multihop import MultiHopQueryPlanner + + +@pytest.fixture() +def base_cfg() -> RAGConfig: + return RAGConfig() + + +def test_multihop_truncates_to_cap(base_cfg, monkeypatch): + planner = MultiHopQueryPlanner(base_cfg, max_subquestions=2) + + # Pretend the LLM returned 4 sub-questions; the planner must cap to 2. + monkeypatch.setattr( + "src.planning.multihop.decompose_complex_query", + lambda query, model_path, **kwargs: [ + "what is FAISS", + "what is BM25", + "how does FAISS rank", + "how does BM25 rank", + ], + ) + subs = planner._decompose("compare FAISS and BM25 retrieval performance") + assert len(subs) <= 2, f"got {len(subs)} subquestions: {subs}" + assert subs == ["what is FAISS", "what is BM25"] + + +def test_multihop_fallback_on_empty_llm_output(base_cfg, monkeypatch): + planner = MultiHopQueryPlanner(base_cfg, max_subquestions=2) + monkeypatch.setattr( + "src.planning.multihop.decompose_complex_query", + lambda query, model_path, **kwargs: [], + ) + query = "compare FAISS and BM25" + subs = planner._decompose(query) + assert subs == [query], "must fall back to original query, never return zero subs" + + +def test_multihop_default_cap_is_three(base_cfg): + assert MultiHopQueryPlanner(base_cfg).max_subquestions == 3 + + +def test_multihop_dedupes_before_capping(base_cfg, monkeypatch): + # Dedupe must happen before the cap, otherwise duplicates eat the budget. + planner = MultiHopQueryPlanner(base_cfg, max_subquestions=2) + monkeypatch.setattr( + "src.planning.multihop.decompose_complex_query", + lambda query, model_path, **kwargs: ["x", "X", "x", "y", "z"], + ) + assert planner._decompose("some query") == ["x", "y"]