-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcli.py
More file actions
114 lines (108 loc) · 4.99 KB
/
cli.py
File metadata and controls
114 lines (108 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import glob
import json
from typing import Optional

import typer
from rich.console import Console
from rich.markup import escape

from .config import load_settings
from .core import ProofCite
# Shared Rich console used for all human-readable (non-JSON) output.
console = Console()
# Typer application; add_completion=False turns off Typer's shell-completion options.
app = typer.Typer(add_completion=False)
# Built-in values used when an option is given neither on the command line nor
# in the settings returned by load_settings(). They equal the defaults that were
# previously hard-coded into the option declarations.
_FALLBACKS = {
    "docs": None,
    "threshold": 0.35,
    "rerank": "none",
    "span_max_gap": 0,
    "segment": "line",
    "token_chunk_size": 80,
}


def _resolve(cli_value, settings, name):
    """Return the effective value of option *name*.

    Precedence: the explicit command-line value, then ``settings.<name>`` when
    that attribute exists and is not None, then ``_FALLBACKS[name]``.
    """
    if cli_value is not None:
        return cli_value
    # getattr with a default: nothing visible here guarantees the settings
    # object defines every attribute.
    configured = getattr(settings, name, None)
    return _FALLBACKS[name] if configured is None else configured


def _read_questions(path):
    """Return the stripped, non-blank lines of the UTF-8 text file at *path*.

    Raises typer.Exit(2) with a readable message if the file cannot be read
    or decoded.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]
    except (OSError, UnicodeDecodeError) as err:
        console.print(f"[bold red]Cannot read batch file[/] {escape(path)}: {escape(str(err))}")
        raise typer.Exit(2) from err


def _make_retriever(retriever, segment, token_chunk_size):
    """Instantiate the retriever backend selected by *retriever*.

    Raises typer.BadParameter for an unknown backend name.
    """
    if retriever == "deterministic":
        return ProofCite(segment=segment, token_chunk_size=token_chunk_size)
    if retriever == "embedding":
        # Imported here so .embedding is only loaded when this backend is chosen.
        from .embedding import EmbeddingRetriever

        return EmbeddingRetriever(segment=segment, token_chunk_size=token_chunk_size)
    raise typer.BadParameter("retriever must be deterministic or embedding")


def _answer_to_dict(ans):
    """Convert an answer returned by the retriever's ``ask`` into a plain dict.

    NOTE(review): ``ans.spans`` is passed through unchanged and assumed to be
    JSON-serializable — confirm against the core answer type.
    """
    return {
        "answer": ans.answer,
        "unverifiable": ans.unverifiable,
        "max_score": ans.max_score,
        "threshold": ans.threshold,
        "citations": [
            {"path": c.path, "line_no": c.line_no, "text": c.text, "score": c.score}
            for c in ans.citations
        ],
        "spans": ans.spans,
    }


def _print_citations(citations):
    """Print one ``path:line_no score=...`` bullet per citation dict."""
    for c in citations:
        # Paths are user data: escape them so '[' is not parsed as Rich markup.
        console.print(f" • {escape(str(c['path']))}:{c['line_no']} score={c['score']:.3f}")


@app.command()
def ask(
    docs: Optional[str] = typer.Option(
        None, "--docs", help="Glob for documents, e.g. 'examples/data/*.txt' [default: from settings]"
    ),
    q: Optional[str] = typer.Option(None, "--q", help="Question"),
    batch: Optional[str] = typer.Option(
        None, "--batch", help="Path to newline-delimited questions for batch mode"
    ),
    k: int = typer.Option(5, help="Top-k lines to retrieve"),
    threshold: Optional[float] = typer.Option(
        None, help="Minimum cosine similarity to answer [default: from settings, else 0.35]"
    ),
    rerank: Optional[str] = typer.Option(
        None, help="Reranking: none|bm25|hybrid [default: from settings, else none]"
    ),
    span_max_gap: Optional[int] = typer.Option(
        None,
        help="Merge citations into spans when lines within this gap [default: from settings, else 0]",
    ),
    segment: Optional[str] = typer.Option(
        None,
        help="Segmentation: line|paragraph|sentence|token [default: from settings, else line]",
    ),
    token_chunk_size: Optional[int] = typer.Option(
        None, help="Token chunk size when segment=token [default: from settings, else 80]"
    ),
    retriever: str = typer.Option("deterministic", help="Retriever: deterministic|embedding"),
    allow_paths: Optional[str] = typer.Option(
        None, "--allow-paths", help="Regex of allowed citation paths"
    ),
    deny_paths: Optional[str] = typer.Option(
        None, "--deny-paths", help="Regex of disallowed citation paths"
    ),
    json_out: bool = typer.Option(False, "--json", help="Emit JSON to stdout"),
):
    """Answer a question (--q) or a file of questions (--batch) from documents, with citations.

    Options left unset fall back to the configured settings, then to built-in
    defaults. Exits with status 1 if any answer is unverifiable and 2 on usage
    errors.
    """
    # Validate cheap preconditions before building the (potentially expensive) index.
    if (q is None) == (batch is None):
        console.print("[bold red]Provide exactly one of --q or --batch[/]")
        raise typer.Exit(2)
    questions = _read_questions(batch) if batch is not None else None

    # Option values given on the command line win; unset ones come from settings.
    s = load_settings()
    docs = _resolve(docs, s, "docs")
    threshold = _resolve(threshold, s, "threshold")
    rerank = _resolve(rerank, s, "rerank")
    span_max_gap = _resolve(span_max_gap, s, "span_max_gap")
    segment = _resolve(segment, s, "segment")
    token_chunk_size = _resolve(token_chunk_size, s, "token_chunk_size")

    if not docs:
        console.print("[bold red]No document glob given[/] (pass --docs or set it in settings)")
        raise typer.Exit(2)
    paths = sorted(glob.glob(docs))
    if not paths:
        # Escape: glob character classes like '[ab]' would otherwise be read as markup.
        console.print(f"[bold red]No documents match[/] {escape(docs)}")
        raise typer.Exit(2)

    pc = _make_retriever(retriever, segment, token_chunk_size)
    pc.add_documents(paths)
    pc.build()

    def run(question):
        """Ask *question* against the built index with the resolved options."""
        return pc.ask(
            question,
            k=k,
            threshold=threshold,
            rerank=rerank,
            span_max_gap=span_max_gap,
            allowed_paths_regex=allow_paths,
            denied_paths_regex=deny_paths,
        )

    if questions is not None:
        results = [{"q": qi, **_answer_to_dict(run(qi))} for qi in questions]
        if json_out:
            print(json.dumps({"results": results}, ensure_ascii=False))
        else:
            for r in results:
                if r["unverifiable"]:
                    console.print(f"[bold red]Unverifiable[/] (q={escape(r['q'])})")
                else:
                    console.print(
                        f"[bold]Q:[/] {escape(r['q'])}\n[bold]A:[/] {escape(str(r['answer']))}"
                    )
                    _print_citations(r["citations"])
        # Exit 1 if any answer in the batch is unverifiable.
        if any(r["unverifiable"] for r in results):
            raise typer.Exit(1)
        return

    ans = run(q)
    payload = _answer_to_dict(ans)
    if json_out:
        print(json.dumps(payload, ensure_ascii=False))
        if ans.unverifiable:
            raise typer.Exit(1)
        return
    if ans.unverifiable:
        console.print(
            f"[bold red]Unverifiable[/] (max_score={ans.max_score:.3f} < threshold={ans.threshold})"
        )
        raise typer.Exit(1)
    console.print(f"[bold]Answer:[/] {escape(str(ans.answer))}")
    _print_citations(payload["citations"])
    console.print(f"[dim]max_score={ans.max_score:.3f} threshold={ans.threshold}[/]")
def main():  # console_script entry
    """Run the Typer application; used as the installed console-script entry point."""
    app()


if __name__ == "__main__":
    main()