Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 59 additions & 33 deletions scripts/run_audit_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,25 @@ def parse_findings(response_text: str) -> list[dict[str, Any]]:
return findings


def call_model(provider: str, model: str, prompt: str, cache_dir: Path, task_id: str) -> str:
"""Chat-completions call with per-task caching, reusing the analyst adapter's transport."""

analyst = _make_analyst(provider, model, cache_dir)
def call_model(
provider: str,
model: str,
prompt: str,
cache_dir: Path,
task_id: str,
*,
sample: int = 0,
temperature: float = 0.0,
) -> str:
"""Chat-completions call with per-(task, sample) caching, reusing the analyst transport."""

analyst = _make_analyst(provider, model, cache_dir, temperature=temperature)
import hashlib

prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
cache_key = f"audit:{provider}:{model}:{task_id}:{prompt_hash}"
# sample 0 keeps the legacy key so the deterministic main-table cache replays.
suffix = "" if sample == 0 else f":s{sample}"
cache_key = f"audit:{provider}:{model}:{task_id}:{prompt_hash}{suffix}"
cache = analyst._cache()
cached = cache.get(cache_key)
if cached is not None:
Expand All @@ -150,12 +161,13 @@ def call_model(provider: str, model: str, prompt: str, cache_dir: Path, task_id:
"prompt": prompt,
"response_text": response_text,
"task_id": task_id,
"sample": sample,
}
)
return response_text


def _make_analyst(provider: str, model: str, cache_dir: Path) -> DeepSeekLLMAnalyst:
def _make_analyst(provider: str, model: str, cache_dir: Path, *, temperature: float = 0.0) -> DeepSeekLLMAnalyst:
slug = f"{provider}_{model}".replace("-", "_").replace(".", "_").replace(":", "_")
cache_path = str(cache_dir / f"audit_{slug}.jsonl")
if provider == "poe":
Expand All @@ -170,6 +182,7 @@ def _make_analyst(provider: str, model: str, cache_dir: Path) -> DeepSeekLLMAnal
thinking="",
use_response_format=False,
timeout_seconds=180,
temperature=temperature,
)
if provider == "glm":
return DeepSeekLLMAnalyst(
Expand All @@ -183,15 +196,20 @@ def _make_analyst(provider: str, model: str, cache_dir: Path) -> DeepSeekLLMAnal
thinking="disabled",
use_response_format=False,
timeout_seconds=180,
temperature=temperature,
)
return DeepSeekLLMAnalyst(model=model, cache_path=cache_path, provider="deepseek", timeout_seconds=180)
return DeepSeekLLMAnalyst(
model=model, cache_path=cache_path, provider="deepseek", timeout_seconds=180, temperature=temperature
)


def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(description="Evaluate LLM auditors on defect-injected trajectories.")
parser.add_argument("--tasks-dir", default="outputs/audit_tasks")
parser.add_argument("--models", default="deepseek:deepseek-v4-pro", help="Comma-separated provider:model entries.")
parser.add_argument("--max-tasks", type=int, default=0, help="Limit task count (0 = all).")
parser.add_argument("--samples-per-task", type=int, default=1, help="Repeated auditor samples per task (sample 0 is the deterministic main pass).")
parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature; use >0 with --samples-per-task for CI estimation.")
parser.add_argument("--cache-dir", default="outputs/llm_cache/audit_eval")
parser.add_argument("--output-dir", default="outputs/audit_eval")
args = parser.parse_args(argv)
Expand All @@ -212,46 +230,54 @@ def main(argv: list[str] | None = None) -> int:
task_dirs = task_dirs[: args.max_tasks]

results_path = output_dir / "audit_eval_results.jsonl"
done: set[tuple[str, str]] = set()
done: set[tuple[str, str, int]] = set()
if results_path.exists():
with results_path.open(encoding="utf-8") as handle:
for line in handle:
row = json.loads(line)
done.add((row["model"], row["task_id"]))
done.add((row["model"], row["task_id"], int(row.get("sample", 0))))
if done:
print(f"Resuming: {len(done)} (model, task) results already checkpointed", flush=True)
print(f"Resuming: {len(done)} (model, task, sample) results already checkpointed", flush=True)

samples = max(1, int(args.samples_per_task))
models = [item.strip() for item in args.models.split(",") if item.strip()]
with results_path.open("a", encoding="utf-8") as results_handle:
for spec in models:
provider, model = spec.split(":", 1)
for task_dir in task_dirs:
task_id = task_dir.name
if (spec, task_id) in done:
continue
truth = truth_by_task[task_id]
prompt = build_prompt(task_dir)
try:
response = call_model(provider, model, prompt, cache_dir, task_id)
except Exception as exc: # provider failures should not lose the run
print(f"FAILED {spec} {task_id}: {type(exc).__name__}", file=sys.stderr, flush=True)
continue
findings = parse_findings(response)
scores = score_findings(findings, [truth])
record = {
"model": spec,
"task_id": task_id,
"kind": truth["kind"],
"difficulty": truth["difficulty"],
"findings": findings,
"parsed": bool(findings) or "[]" in response,
**{k: scores[k] for k in ("findings", "true_positives", "precision", "recall", "f1") if k in scores},
}
record["finding_count"] = record.pop("findings") if isinstance(record.get("findings"), int) else len(findings)
record["findings"] = findings
results_handle.write(json.dumps(record, sort_keys=True) + "\n")
results_handle.flush()
print(f"OK {spec} {task_id} recall={scores['recall']:.0%} findings={len(findings)}", flush=True)
for sample in range(samples):
if (spec, task_id, sample) in done:
continue
try:
response = call_model(
provider, model, prompt, cache_dir, task_id,
sample=sample, temperature=args.temperature,
)
except Exception as exc: # provider failures should not lose the run
print(f"FAILED {spec} {task_id} s{sample}: {type(exc).__name__}", file=sys.stderr, flush=True)
continue
findings = parse_findings(response)
scores = score_findings(findings, [truth])
record = {
"model": spec,
"task_id": task_id,
"sample": sample,
"kind": truth["kind"],
"difficulty": truth["difficulty"],
"parsed": bool(findings) or "[]" in response,
"true_positives": scores["true_positives"],
"precision": scores["precision"],
"recall": scores["recall"],
"f1": scores["f1"],
"finding_count": len(findings),
"findings": findings,
}
results_handle.write(json.dumps(record, sort_keys=True) + "\n")
results_handle.flush()
print(f"OK {spec} {task_id} s{sample} recall={scores['recall']:.0%}", flush=True)

_write_summary(results_path, output_dir / "audit_eval_summary.csv")
return 0
Expand Down
3 changes: 2 additions & 1 deletion src/tradearena/agents/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class DeepSeekLLMAnalyst:
mask_timestamps: bool = False
anonymize_symbols: bool = False
sample_index: int = 0
temperature: float = 0.0
name: str = "deepseek-llm-analyst"
_cache_entries: dict[str, dict[str, Any]] | None = field(default=None, init=False, repr=False)
_cache_mtime_ns: int | None = field(default=None, init=False, repr=False)
Expand Down Expand Up @@ -222,7 +223,7 @@ def _call_deepseek(self, prompt: str) -> str:
)
request_body = {
"model": self.api_model or self.model,
"temperature": 0,
"temperature": float(self.temperature),
"messages": [
{
"role": "system",
Expand Down
18 changes: 18 additions & 0 deletions src/tradearena/evaluation/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,24 @@
from typing import Any


def wilson_interval(successes: int, trials: int, *, z: float = 1.96) -> tuple[float, float, float]:
    """Wilson score interval for a binomial proportion.

    Unlike the normal-approximation (Wald) interval, the Wilson interval does
    not collapse to zero width when the observed proportion is exactly 0 or 1,
    which is the regime of the near-ceiling and near-floor auditor recalls
    reported here.

    Args:
        successes: Number of successes; must satisfy ``0 <= successes <= trials``.
        trials: Number of trials. A non-positive count is treated as "no data"
            and yields ``(0.0, 0.0, 0.0)`` instead of raising.
        z: Standard-normal quantile for the confidence level (1.96 is roughly
            95%). Only its magnitude matters; the sign is ignored.

    Returns:
        ``(point, low, high)`` where ``point`` is ``successes / trials`` and
        ``low``/``high`` are the interval bounds clamped to ``[0.0, 1.0]``.

    Raises:
        ValueError: If ``successes`` is negative or exceeds ``trials``.
    """

    if trials <= 0:
        return 0.0, 0.0, 0.0
    # An impossible count would otherwise surface as an opaque "math domain
    # error" or, for larger z, as a silently wrong point estimate outside [0, 1].
    if not 0 <= successes <= trials:
        raise ValueError(f"successes must be in [0, {trials}], got {successes}")
    p = successes / trials
    z2 = z * z
    denom = 1.0 + z2 / trials
    center = (p + z2 / (2 * trials)) / denom
    # abs(z) keeps the margin non-negative so that low <= high for any sign of z.
    margin = (abs(z) * math.sqrt(p * (1 - p) / trials + z2 / (4 * trials * trials))) / denom
    return p, max(0.0, center - margin), min(1.0, center + margin)


def mean(values: Iterable[float]) -> float:
    """Return the arithmetic mean of *values*, or 0.0 when there are none.

    Every item is coerced with ``float()`` first, so the input may be any
    iterable (including a one-shot iterator) of float-convertible values.
    """
    coerced = list(map(float, values))
    if not coerced:
        return 0.0
    return sum(coerced) / len(coerced)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,19 @@ def test_paired_bootstrap_difference_reports_effect_sizes():
empty = paired_bootstrap_difference({}, {})
assert empty["cohens_d"] is None
assert empty["cliffs_delta"] is None


def test_wilson_interval_bounds_and_extremes():
    """Check Wilson interval bounds for an interior proportion, p=1, p=0 and n=0."""
    from tradearena.evaluation.statistics import wilson_interval

    # Interior proportion: the bounds strictly bracket the point estimate
    # and stay strictly inside (0, 1).
    point, low, high = wilson_interval(8, 10)
    assert point == 0.8
    assert 0.0 < low
    assert low < 0.8
    assert 0.8 < high
    assert high < 1.0

    # Perfect success: point 1.0, upper capped at 1.0, lower below 1.
    all_point, all_low, all_high = wilson_interval(25, 25)
    assert all_point == 1.0
    assert all_high == 1.0
    assert all_low < 1.0

    # Zero successes: point 0, lower floored at 0, upper still positive.
    none_point, none_low, none_high = wilson_interval(0, 25)
    assert none_point == 0.0
    assert none_low == 0.0
    assert none_high > 0.0

    # Empty: no trials yields the all-zero triple.
    assert wilson_interval(0, 0) == (0.0, 0.0, 0.0)