Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions config/category_prototypes.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Category prototypes for PrototypeClassifier
# ---------------------------------------------
# Each list is averaged into a single prototype embedding at planner-build
# time. Exemplars are sourced from eval/questions.jsonl so the prototype
# space matches the distribution we evaluate against.
#
# Add 3-5 exemplars per category. More exemplars = smoother prototype but
# diminishing returns above ~5. If a category looks under-represented in
# practice (e.g. low confidence on legitimate queries), add 1-2 more.
#
# Acronym hygiene: keyword exemplars MUST contain an all-caps 2-4 char token
# (matches _ACRONYM_PATTERN in heuristics.py). Exemplars in every other
# category MUST NOT contain such a token — acronym bleed will collapse the
# prototype space.

keyword:
- "What is ACID?"
- "Define MVCC"
- "Explain WAL-based recovery"
- "What does OLTP mean"
- "What is CRUD in databases"

comparison:
- "Compare clustered and non-clustered indexes"
- "Difference between B+ tree and hash index"
- "Compare second normal form and third normal form"
- "Difference between optimistic and pessimistic concurrency control"
- "Compare row-store and column-store databases"

definition:
- "What is a candidate key"
- "Define functional dependency"
- "What is a schedule in concurrency control"
- "Define referential integrity"
- "What is a foreign key"

explanatory:
- "Why does normalization reduce redundancy"
- "How does write-ahead logging ensure durability"
- "Explain how deadlocks arise in transactions"
- "Why do databases use indexes despite write overhead"
- "How does two-phase commit work"

procedural:
- "How to build a B+ tree index from scratch"
- "How to perform database normalization"
- "Steps to recover from a transaction failure"
- "How to design a relational schema for a library system"
- "Procedure for query optimization in a relational engine"

other:
  # "DBMS" is spelled out below: per the acronym-hygiene rule in this file's
  # header, non-keyword exemplars must not contain an all-caps 2-4 char token.
  - "List the disadvantages of using a database"
  - "What are the components of a database management system"
  - "Lock escalation behavior"
  - "Buffer management considerations"
  - "Index selection during query planning"
35 changes: 35 additions & 0 deletions eval/breakdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import csv
import collections
import sys

# Results CSV to summarize: first CLI argument, else the default N=116 run.
path = sys.argv[1] if len(sys.argv) > 1 else "eval/results_v2_n116.csv"
# Read through a context manager so the handle is closed deterministically.
# newline="" is what the csv module expects: it lets the reader handle line
# endings itself, so line breaks inside quoted fields survive unchanged.
with open(path, newline="") as f:
    rows = list(csv.DictReader(f))

def rate(rs, c):
    """Return the percentage (0-100) of rows in ``rs`` whose column ``c`` is a hit.

    Args:
        rs: Sequence of row mappings (e.g. ``csv.DictReader`` rows).
        c: Name of a column holding integer-like hit flags.

    Returns:
        ``sum(flags) / len(rs) * 100`` as a float, or ``0.0`` when ``rs`` is
        empty.

    Raises:
        KeyError: If a row has no ``c`` key.
        ValueError: If a non-empty cell cannot be parsed as an int.
    """
    if not rs:
        return 0.0
    # An empty string or None cell counts as a miss instead of crashing int(),
    # matching the rate() helper in compare_runs.py.
    return sum(int(r[c] or 0) for r in rs) / len(rs) * 100

# Hit-flag columns reported in the overall summary, in display order.
cols = [
    "baseline_retrieval_hit",
    "optimizer_retrieval_hit",
    "baseline_answer_hit",
    "optimizer_answer_hit",
]

# Overall section: one line per hit column across every row of the file.
banner = f"=== N={len(rows)} ({path}) ==="
print(banner)
for col in cols:
    pct = rate(rows, col)
    print(f" {col:32s} {pct:6.2f}%")

# Optimizer-minus-baseline differences, in percentage points.
retrieval_delta = rate(rows, "optimizer_retrieval_hit") - rate(rows, "baseline_retrieval_hit")
print(f" delta retrieval: {retrieval_delta:+.2f}")
answer_delta = rate(rows, "optimizer_answer_hit") - rate(rows, "baseline_answer_hit")
print(f" delta answer: {answer_delta:+.2f}")
print()

# Bucket rows by their "category" column so each category gets one table row.
by_cat = collections.defaultdict(list)
for row in rows:
    by_cat[row["category"]].append(row)

# Per-category table: sample size, the four hit rates, and the two deltas.
print(f" {'cat':12s} {'n':>3} {'b_ret':>7} {'o_ret':>7} {'b_ans':>7} {'o_ans':>7} {'dret':>6} {'dans':>6}")
for cat in sorted(by_cat):
    members = by_cat[cat]
    br, orr, ba, oa = (
        rate(members, col)
        for col in (
            "baseline_retrieval_hit",
            "optimizer_retrieval_hit",
            "baseline_answer_hit",
            "optimizer_answer_hit",
        )
    )
    print(f" {cat:12s} {len(members):>3} {br:6.1f}% {orr:6.1f}% {ba:6.1f}% {oa:6.1f}% {orr-br:+5.1f} {oa-ba:+5.1f}")
78 changes: 78 additions & 0 deletions eval/compare_runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
Compare two eval CSVs side-by-side per category.

Helps spot which per-category numbers are stable across runs vs. which are
swinging due to LLM generation noise. The routing table baked into the cost
model was derived from a specific run; if a category's baseline/optimizer
gap flips sign in a later run, the routing decision for that category is
noise, not signal.

Usage:
python eval/compare_runs.py eval/results_v2_n116.csv eval/results_v3_n116_with_cost_model.csv
"""
import csv
import collections
import sys


def load(path):
    """Read the CSV at ``path`` and return its rows as a list of dicts.

    Args:
        path: Filesystem path of a CSV file with a header row.

    Returns:
        One dict per data row, keyed by the header names
        (``csv.DictReader`` semantics).

    Raises:
        OSError: If the file cannot be opened.
    """
    # newline="" hands line-ending handling to the csv module, so line breaks
    # embedded in quoted fields are returned exactly as written.
    with open(path, newline="") as f:
        return list(csv.DictReader(f))


def rate(rs, c):
    """Return the percentage (0-100) of rows in ``rs`` flagged as a hit in column ``c``.

    Falsy cells (empty string, None) are counted as 0. An empty ``rs``
    yields ``0.0``.
    """
    if not rs:
        return 0.0
    hits = 0
    for row in rs:
        hits += int(row[c] or 0)
    return hits / len(rs) * 100


def by_cat(rows):
    """Group ``rows`` by their "category" value.

    Returns a ``collections.defaultdict(list)`` mapping each category to its
    rows, preserving input order within each group.
    """
    grouped = collections.defaultdict(list)
    for record in rows:
        bucket = grouped[record["category"]]
        bucket.append(record)
    return grouped


def main():
    """Print a per-category, side-by-side comparison of two eval CSVs.

    Reads run A and run B from ``sys.argv[1]`` and ``sys.argv[2]``. With
    fewer than two paths it prints a usage line and exits with status 2.

    For every category present in either run, prints the baseline and
    optimizer answer-hit rates for both runs and each run's gap
    (optimizer minus baseline). If run B's rows carry a
    ``cost_model_answer_hit`` column, also prints B's overall cost-model
    rate and its deltas against B's baseline and optimizer rates.
    """
    import os  # function-scope import: only needed here, for the run labels

    if len(sys.argv) < 3:
        # Usage errors go to stderr so they never mix into piped table output.
        print("usage: compare_runs.py RUN_A.csv RUN_B.csv", file=sys.stderr)
        sys.exit(2)
    path_a, path_b = sys.argv[1], sys.argv[2]
    rows_a = load(path_a)
    rows_b = load(path_b)
    cats_a, cats_b = by_cat(rows_a), by_cat(rows_b)
    cats = sorted(set(cats_a) | set(cats_b))

    # basename() honours the platform's path separators, unlike split("/").
    label_a = os.path.basename(path_a)
    label_b = os.path.basename(path_b)

    print(f"A = {label_a} (N={len(rows_a)})")
    print(f"B = {label_b} (N={len(rows_b)})")
    print()
    print(f" {'cat':12s} {'n':>3} {'A b_ans':>8} {'B b_ans':>8} {'A o_ans':>8} {'B o_ans':>8} {'A gap':>7} {'B gap':>7}")
    print(f" {'-'*12} {'-'*3} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*7} {'-'*7}")
    for cat in cats:
        # A category may exist in only one run; .get() avoids inserting empty
        # entries into the defaultdicts, and rate() returns 0.0 for [].
        ra = cats_a.get(cat, [])
        rb = cats_b.get(cat, [])
        n = max(len(ra), len(rb))
        a_b = rate(ra, "baseline_answer_hit")
        b_b = rate(rb, "baseline_answer_hit")
        a_o = rate(ra, "optimizer_answer_hit")
        b_o = rate(rb, "optimizer_answer_hit")
        a_gap = a_o - a_b
        b_gap = b_o - b_b
        print(f" {cat:12s} {n:>3} {a_b:7.1f}% {b_b:7.1f}% {a_o:7.1f}% {b_o:7.1f}% {a_gap:+6.1f} {b_gap:+6.1f}")

    print()
    print(" Legend: gap = optimizer − baseline. If a category's gap sign flips between A and B,")
    print(" the routing decision for that category is within LLM-noise range, not robust signal.")
    print()
    # The cost-model summary is optional: only shown when run B has the column.
    if rows_b and "cost_model_answer_hit" in rows_b[0]:
        cm = rate(rows_b, "cost_model_answer_hit")
        bb = rate(rows_b, "baseline_answer_hit")
        bo = rate(rows_b, "optimizer_answer_hit")
        print(f" B has cost_model column: cm_ans={cm:.2f}% (vs B baseline {bb:.2f}%, B optimizer {bo:.2f}%)")
        print(f" Δ cost-model − baseline: {cm-bb:+.2f}")
        print(f" Δ cost-model − optimizer: {cm-bo:+.2f}")


if __name__ == "__main__":
main()
119 changes: 119 additions & 0 deletions eval/cost_model_sanity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Plan C: project what a cost-model planner *would* score, by post-processing
an existing baseline+optimizer eval CSV. No new eval run required.

For each row, pick baseline_* or optimizer_* hits based on a routing table
keyed on planner_classification (the heuristic's predicted category). Report
the projected hit rate alongside baseline-only and optimizer-only rates.

We also report an "oracle" projection routed by the gold `category` column,
to bound how much of any gap is misclassification vs. table miscalibration.
"""
import csv
import collections
import sys

# Empirical routing table from N=116 run (eval/results_v2_n116.csv):
# keyword +10.0 → composite wins → optimizer
# definition +10.0 → optimizer
# procedural +5.0 → optimizer
# other +12.5 → optimizer
# comparison -4.2 → baseline
# explanatory -4.2 → baseline
# NOTE(review): the signed figures look like each category's answer-hit gap
# (optimizer − baseline, in percentage points) — confirm against that run.
#
# Maps a category label to the pipeline whose hit columns project() uses for
# that row. Values must be "baseline" or "optimizer": pick() builds the CSV
# column name as f"{route}_{...}".
ROUTING_TABLE = {
    "keyword": "optimizer",
    "definition": "optimizer",
    "procedural": "optimizer",
    "other": "optimizer",
    "comparison": "baseline",
    "explanatory": "baseline",
}
# Route used by project() when a row's category label is not in ROUTING_TABLE.
DEFAULT_ROUTE = "optimizer"


def pick(row, mode_col_prefix, route):
    """Return, as an int, the hit flag stored in ``row`` for the given route.

    The column read is ``"<route>_<mode_col_prefix>"``, e.g.
    ``"optimizer_answer_hit"``.
    """
    column_name = "{}_{}".format(route, mode_col_prefix)
    raw_flag = row[column_name]
    return int(raw_flag)


def project(rows, route_key):
    """Project cost-model hits for every row, routing on column ``route_key``.

    For each row, the value under ``route_key`` is looked up in
    ``ROUTING_TABLE`` (falling back to ``DEFAULT_ROUTE``) to choose a route.
    Returns a new list in which each element is a copy of the input row
    extended with ``cost_model_retrieval_hit``, ``cost_model_answer_hit``
    (ints taken from the chosen route's columns) and ``cost_model_choice``
    (the route name). The input rows are not modified.
    """
    projected = []
    for row in rows:
        route = ROUTING_TABLE.get(row[route_key], DEFAULT_ROUTE)
        enriched = dict(row)
        enriched["cost_model_retrieval_hit"] = pick(row, "retrieval_hit", route)
        enriched["cost_model_answer_hit"] = pick(row, "answer_hit", route)
        enriched["cost_model_choice"] = route
        projected.append(enriched)
    return projected


def rate(rs, c):
    """Return the percentage (0-100) of rows in ``rs`` with a hit in column ``c``.

    Cell values are converted with ``int()``. An empty ``rs`` yields ``0.0``.
    """
    if not rs:
        return 0.0
    total_hits = 0
    for row in rs:
        total_hits += int(row[c])
    return total_hits / len(rs) * 100


def report(label, rows):
    """Print the overall hit-rate summary for a projected result set.

    Emits a header with ``label`` and the row count, the retrieval and answer
    hit rates for baseline, optimizer and cost model, the cost model's
    answer-rate deltas against baseline and optimizer, then a blank line.
    """
    print(f"=== {label} (N={len(rows)}) ===")
    # Retrieval block first, then answer block; within each: the three sources.
    for stage in ("retrieval_hit", "answer_hit"):
        for source in ("baseline", "optimizer", "cost_model"):
            column = f"{source}_{stage}"
            print(f" {column} {rate(rows, column):6.2f}%")
    cost_model_ans = rate(rows, "cost_model_answer_hit")
    vs_baseline = cost_model_ans - rate(rows, "baseline_answer_hit")
    print(f" delta vs baseline (ans) {vs_baseline:+6.2f}")
    vs_optimizer = cost_model_ans - rate(rows, "optimizer_answer_hit")
    print(f" delta vs optimizer (ans) {vs_optimizer:+6.2f}")
    print()


def per_category(rows):
    """Print one table row per "category" value found in ``rows``.

    Each row shows the group size, the baseline / optimizer / cost-model
    answer-hit rates, and the routing choice that occurs most often among
    that category's rows. Categories are listed in sorted order and the
    table is followed by a blank line.
    """
    grouped = collections.defaultdict(list)
    for row in rows:
        grouped[row["category"]].append(row)

    print(f" {'cat':12s} {'n':>3} {'b_ans':>7} {'o_ans':>7} {'cm_ans':>7} {'choice':>10}")
    for cat in sorted(grouped):
        members = grouped[cat]
        ba = rate(members, "baseline_answer_hit")
        oa = rate(members, "optimizer_answer_hit")
        cma = rate(members, "cost_model_answer_hit")
        # Rows of one category can be routed differently when routing used a
        # different column, so report the most frequent choice.
        tally = collections.Counter(m["cost_model_choice"] for m in members)
        ((common_choice, _count),) = tally.most_common(1)
        print(f" {cat:12s} {len(members):>3} {ba:6.1f}% {oa:6.1f}% {cma:6.1f}% {common_choice:>10}")
    print()


def confusion(rows):
    """Print each (category, planner_classification) pair that disagrees.

    Pairs are listed with their counts, largest count first. When every row's
    two labels agree, a "(no mismatches)" line is printed instead.
    """
    print(" classifier confusion (gold → predicted):")
    tally = collections.Counter()
    for row in rows:
        tally[(row["category"], row["planner_classification"])] += 1

    disagreements = [
        (gold, predicted, count)
        for (gold, predicted), count in tally.items()
        if gold != predicted
    ]
    if disagreements:
        # reverse=True keeps equal counts in first-seen order (stable sort).
        ranked = sorted(disagreements, key=lambda item: item[2], reverse=True)
        for gold, predicted, count in ranked:
            print(f" {gold:12s} → {predicted:12s} n={count}")
        print()
    else:
        print(" (no mismatches)")


def main():
    """Project cost-model scores from an existing eval CSV and print reports.

    The CSV path is ``sys.argv[1]`` when given, else
    ``eval/results_v2_n116.csv``. Prints, in order: the summary and
    per-category table with routing keyed on ``planner_classification``, the
    summary with routing keyed on the gold ``category`` column, and the
    classifier confusion listing.
    """
    path = sys.argv[1] if len(sys.argv) > 1 else "eval/results_v2_n116.csv"
    # newline="" hands line-ending handling to the csv module, so line breaks
    # embedded in quoted fields are read back unchanged.
    with open(path, newline="") as f:
        rows = list(csv.DictReader(f))

    # Production cost model: route by heuristic classification (the planner
    # has no access to the gold category at inference time).
    prod = project(rows, "planner_classification")
    report(f"Cost model (route by planner_classification) — {path}", prod)
    per_category(prod)

    # Oracle: route by gold category. Upper bound; gap to prod = classifier loss.
    oracle = project(rows, "category")
    report(f"Oracle cost model (route by gold category) — {path}", oracle)

    confusion(rows)


if __name__ == "__main__":
main()
31 changes: 31 additions & 0 deletions eval/headroom_analysis.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
=== Cost model (route by planner_classification) — eval/results_corpus1_final.csv (N=116) ===
baseline_retrieval_hit 95.69%
optimizer_retrieval_hit 95.69%
cost_model_retrieval_hit 98.28%
baseline_answer_hit 83.62%
optimizer_answer_hit 87.93%
cost_model_answer_hit 87.93%
delta vs baseline (ans) +4.31
delta vs optimizer (ans) +0.00

cat n b_ans o_ans cm_ans choice
comparison 24 87.5% 87.5% 87.5% baseline
definition 20 85.0% 90.0% 90.0% optimizer
explanatory 24 75.0% 75.0% 75.0% baseline
keyword 20 85.0% 95.0% 95.0% optimizer
other 8 87.5% 100.0% 100.0% optimizer
procedural 20 85.0% 90.0% 90.0% optimizer

=== Oracle cost model (route by gold category) — eval/results_corpus1_final.csv (N=116) ===
baseline_retrieval_hit 95.69%
optimizer_retrieval_hit 95.69%
cost_model_retrieval_hit 98.28%
baseline_answer_hit 83.62%
optimizer_answer_hit 87.93%
cost_model_answer_hit 87.93%
delta vs baseline (ans) +4.31
delta vs optimizer (ans) +0.00

classifier confusion (gold → predicted):
keyword → definition n=1

Loading