From 4823643e7ecfb324b40a60c9d9a92828b2fa653a Mon Sep 17 00:00:00 2001 From: weich97 <25754285+weich97@users.noreply.github.com> Date: Sat, 13 Jun 2026 19:32:36 +0800 Subject: [PATCH] Add temperature/repeated-sampling support and Wilson intervals - DeepSeekLLMAnalyst gains a configurable temperature (default 0.0, so existing deterministic runs are unchanged). - run_audit_eval.py accepts --samples-per-task and --temperature, keys the cache and checkpoint by sample (sample 0 keeps the legacy key so the deterministic main-table cache replays), enabling confidence intervals over auditor recall without disturbing the temp-0 results. - statistics.wilson_interval: binomial score interval, robust at the 0/1 extremes that the auditor recalls hit, with unit tests. --- scripts/run_audit_eval.py | 92 ++++++++++++++++--------- src/tradearena/agents/llm.py | 3 +- src/tradearena/evaluation/statistics.py | 18 +++++ tests/test_statistics.py | 16 +++++ 4 files changed, 95 insertions(+), 34 deletions(-) diff --git a/scripts/run_audit_eval.py b/scripts/run_audit_eval.py index 26b4921f..c39a396d 100644 --- a/scripts/run_audit_eval.py +++ b/scripts/run_audit_eval.py @@ -128,14 +128,25 @@ def parse_findings(response_text: str) -> list[dict[str, Any]]: return findings -def call_model(provider: str, model: str, prompt: str, cache_dir: Path, task_id: str) -> str: - """Chat-completions call with per-task caching, reusing the analyst adapter's transport.""" - - analyst = _make_analyst(provider, model, cache_dir) +def call_model( + provider: str, + model: str, + prompt: str, + cache_dir: Path, + task_id: str, + *, + sample: int = 0, + temperature: float = 0.0, +) -> str: + """Chat-completions call with per-(task, sample) caching, reusing the analyst transport.""" + + analyst = _make_analyst(provider, model, cache_dir, temperature=temperature) import hashlib prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest() - cache_key = f"audit:{provider}:{model}:{task_id}:{prompt_hash}" + # sample 0 keeps the legacy key so the deterministic main-table cache replays. + suffix = "" if sample == 0 else f":s{sample}" + cache_key = f"audit:{provider}:{model}:{task_id}:{prompt_hash}{suffix}" cache = analyst._cache() cached = cache.get(cache_key) if cached is not None: @@ -150,12 +161,13 @@ def call_model(provider: str, model: str, prompt: str, cache_dir: Path, task_id: "prompt": prompt, "response_text": response_text, "task_id": task_id, + "sample": sample, } ) return response_text -def _make_analyst(provider: str, model: str, cache_dir: Path) -> DeepSeekLLMAnalyst: +def _make_analyst(provider: str, model: str, cache_dir: Path, *, temperature: float = 0.0) -> DeepSeekLLMAnalyst: slug = f"{provider}_{model}".replace("-", "_").replace(".", "_").replace(":", "_") cache_path = str(cache_dir / f"audit_{slug}.jsonl") if provider == "poe": @@ -170,6 +182,7 @@ def _make_analyst(provider: str, model: str, cache_dir: Path) -> DeepSeekLLMAnal thinking="", use_response_format=False, timeout_seconds=180, + temperature=temperature, ) if provider == "glm": return DeepSeekLLMAnalyst( @@ -183,8 +196,11 @@ def _make_analyst(provider: str, model: str, cache_dir: Path) -> DeepSeekLLMAnal thinking="disabled", use_response_format=False, timeout_seconds=180, + temperature=temperature, ) - return DeepSeekLLMAnalyst(model=model, cache_path=cache_path, provider="deepseek", timeout_seconds=180) + return DeepSeekLLMAnalyst( + model=model, cache_path=cache_path, provider="deepseek", timeout_seconds=180, temperature=temperature + ) def main(argv: list[str] | None = None) -> int: @@ -192,6 +208,8 @@ def main(argv: list[str] | None = None) -> int: parser.add_argument("--tasks-dir", default="outputs/audit_tasks") parser.add_argument("--models", default="deepseek:deepseek-v4-pro", help="Comma-separated provider:model entries.") parser.add_argument("--max-tasks", type=int, default=0, help="Limit task count (0 = all).") + parser.add_argument("--samples-per-task", type=int, default=1, help="Repeated auditor samples per task (sample 0 is the deterministic main pass).") + parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature; use >0 with --samples-per-task for CI estimation.") parser.add_argument("--cache-dir", default="outputs/llm_cache/audit_eval") parser.add_argument("--output-dir", default="outputs/audit_eval") args = parser.parse_args(argv) @@ -212,46 +230,54 @@ def main(argv: list[str] | None = None) -> int: task_dirs = task_dirs[: args.max_tasks] results_path = output_dir / "audit_eval_results.jsonl" - done: set[tuple[str, str]] = set() + done: set[tuple[str, str, int]] = set() if results_path.exists(): with results_path.open(encoding="utf-8") as handle: for line in handle: row = json.loads(line) - done.add((row["model"], row["task_id"])) + done.add((row["model"], row["task_id"], int(row.get("sample", 0)))) if done: - print(f"Resuming: {len(done)} (model, task) results already checkpointed", flush=True) + print(f"Resuming: {len(done)} (model, task, sample) results already checkpointed", flush=True) + samples = max(1, int(args.samples_per_task)) models = [item.strip() for item in args.models.split(",") if item.strip()] with results_path.open("a", encoding="utf-8") as results_handle: for spec in models: provider, model = spec.split(":", 1) for task_dir in task_dirs: task_id = task_dir.name - if (spec, task_id) in done: - continue truth = truth_by_task[task_id] prompt = build_prompt(task_dir) - try: - response = call_model(provider, model, prompt, cache_dir, task_id) - except Exception as exc: # provider failures should not lose the run - print(f"FAILED {spec} {task_id}: {type(exc).__name__}", file=sys.stderr, flush=True) - continue - findings = parse_findings(response) - scores = score_findings(findings, [truth]) - record = { - "model": spec, - "task_id": task_id, - "kind": truth["kind"], - "difficulty": truth["difficulty"], - "findings": findings, - "parsed": bool(findings) or "[]" in response, - **{k: scores[k] for k in ("findings", "true_positives", "precision", "recall", "f1") if k in scores}, - } - record["finding_count"] = record.pop("findings") if isinstance(record.get("findings"), int) else len(findings) - record["findings"] = findings - results_handle.write(json.dumps(record, sort_keys=True) + "\n") - results_handle.flush() - print(f"OK {spec} {task_id} recall={scores['recall']:.0%} findings={len(findings)}", flush=True) + for sample in range(samples): + if (spec, task_id, sample) in done: + continue + try: + response = call_model( + provider, model, prompt, cache_dir, task_id, + sample=sample, temperature=args.temperature, + ) + except Exception as exc: # provider failures should not lose the run + print(f"FAILED {spec} {task_id} s{sample}: {type(exc).__name__}", file=sys.stderr, flush=True) + continue + findings = parse_findings(response) + scores = score_findings(findings, [truth]) + record = { + "model": spec, + "task_id": task_id, + "sample": sample, + "kind": truth["kind"], + "difficulty": truth["difficulty"], + "parsed": bool(findings) or "[]" in response, + "true_positives": scores["true_positives"], + "precision": scores["precision"], + "recall": scores["recall"], + "f1": scores["f1"], + "finding_count": len(findings), + "findings": findings, + } + results_handle.write(json.dumps(record, sort_keys=True) + "\n") + results_handle.flush() + print(f"OK {spec} {task_id} s{sample} recall={scores['recall']:.0%}", flush=True) _write_summary(results_path, output_dir / "audit_eval_summary.csv") return 0 diff --git a/src/tradearena/agents/llm.py b/src/tradearena/agents/llm.py index 9b985b15..2165c346 100644 --- a/src/tradearena/agents/llm.py +++ b/src/tradearena/agents/llm.py @@ -37,6 +37,7 @@ class DeepSeekLLMAnalyst: mask_timestamps: bool = False anonymize_symbols: bool = False sample_index: int = 0 + temperature: float = 0.0 name: str = "deepseek-llm-analyst" _cache_entries: dict[str, dict[str, Any]] | None = field(default=None, init=False, repr=False) _cache_mtime_ns: int | None = field(default=None, init=False, repr=False) @@ -222,7 +223,7 @@ def _call_deepseek(self, prompt: str) -> str: ) request_body = { "model": self.api_model or self.model, - "temperature": 0, + "temperature": float(self.temperature), "messages": [ { "role": "system", diff --git a/src/tradearena/evaluation/statistics.py b/src/tradearena/evaluation/statistics.py index 01a28bab..70f35420 100644 --- a/src/tradearena/evaluation/statistics.py +++ b/src/tradearena/evaluation/statistics.py @@ -7,6 +7,24 @@ from typing import Any +def wilson_interval(successes: int, trials: int, *, z: float = 1.96) -> tuple[float, float, float]: + """Wilson score interval for a binomial proportion. + + Returns (point, low, high). Robust at the extremes (0 or 1 successes) + where the normal approximation degenerates, which is the regime of the + near-ceiling and near-floor auditor recalls reported here. + """ + + if trials <= 0: + return 0.0, 0.0, 0.0 + p = successes / trials + z2 = z * z + denom = 1.0 + z2 / trials + center = (p + z2 / (2 * trials)) / denom + margin = (z * math.sqrt(p * (1 - p) / trials + z2 / (4 * trials * trials))) / denom + return p, max(0.0, center - margin), min(1.0, center + margin) + + def mean(values: Iterable[float]) -> float: numbers = [float(value) for value in values] return sum(numbers) / len(numbers) if numbers else 0.0 diff --git a/tests/test_statistics.py b/tests/test_statistics.py index 36d3ebe2..ff876da8 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -182,3 +182,19 @@ def test_paired_bootstrap_difference_reports_effect_sizes(): empty = paired_bootstrap_difference({}, {}) assert empty["cohens_d"] is None assert empty["cliffs_delta"] is None + + +def test_wilson_interval_bounds_and_extremes(): + from tradearena.evaluation.statistics import wilson_interval + + p, lo, hi = wilson_interval(8, 10) + assert p == 0.8 + assert 0.0 < lo < 0.8 < hi < 1.0 + # Perfect success: point 1.0, upper capped at 1.0, lower below 1. + p1, lo1, hi1 = wilson_interval(25, 25) + assert p1 == 1.0 and hi1 == 1.0 and lo1 < 1.0 + # Zero successes: point 0, lower floored at 0. + p0, lo0, hi0 = wilson_interval(0, 25) + assert p0 == 0.0 and lo0 == 0.0 and hi0 > 0.0 + # Empty. + assert wilson_interval(0, 0) == (0.0, 0.0, 0.0)