diff --git a/.agents/rules/amendments.md b/.agents/rules/amendments.md index 6552644..72890f0 100644 --- a/.agents/rules/amendments.md +++ b/.agents/rules/amendments.md @@ -35,3 +35,37 @@ Sessionfriction identified during prune-repo + cleanup work: - **Logging config** (`vis/logging_config.py` is canonical; never configure in `page.py`) - **Schema field removal checklist** (grep callers, confirm never populated, no `list[dict]` placeholders) - `python.md` sync-guard test bullet was vague. Expanded with the concrete pattern: compare `model_fields` against the reactive registry at test time. + +## 2026-03-12 — Always use vis/plotting.py for paper figures + +`vis/plotting.py` is the single source of truth for all chart functions. When generating +figures for papers, scripts, or exports, **always call functions from there** — never write +custom matplotlib from scratch in agent_workspace scripts. + +Available functions to reach for first: +- `plot_sweep_curve(SweepResult)` — 1D sweep line chart with tipping point annotation +- `plot_mc_trajectory(MonteCarloResult)` — compliance over steps, mean ± SD +- `plot_mc_violator_trajectory`, `plot_mc_audit_trajectory`, `plot_mc_payoff_comparison` + +If a needed figure type does not exist in `vis/plotting.py` (e.g. a 2D heatmap), **add it +there** following the `create_figure()` style, then use it from both the UI and scripts. +Do not create ad-hoc matplotlib code in agent_workspace when an equivalent function +already exists or could be added once and shared. + +## 2026-03-14 — Reflect: plotting discipline and scripting infrastructure + +Three friction sources identified, all patched this session: + +1. **`project.md` Plots section was too sparse** — 2 lines with no function inventory. + Replaced with the full table of all 12 public functions and a mandatory "check before + writing any matplotlib" gate. The agent cannot now claim ignorance of what exists. + +2. 
**`researcher.md` step 4 had zero mention of `vis/plotting.py`** — meaning every + research visualisation session was allowed to invent ad-hoc matplotlib. Added an + `[!IMPORTANT]` callout before the visualise step enforcing the same gate. + +3. **No `/gen-figures` workflow existed** — figure generation for paper sections was + improvised each time. Created `.agents/workflows/gen-figures.md` with a step-by-step + thin-caller checklist, a copy-paste script template, and a `// turbo` run step. + Also created `scripts/README.md` as the cross-session script index so existing + scripts are discoverable rather than silently re-invented. diff --git a/.agents/rules/project.md b/.agents/rules/project.md index a7cc495..63ea6bf 100644 --- a/.agents/rules/project.md +++ b/.agents/rules/project.md @@ -100,7 +100,28 @@ All export functions return `bytes` for Solara's `FileDownload`. Key functions: ## Plots (`vis/plotting.py`) -Accept typed result objects, return `matplotlib.Figure`, never import Solara. Use `fig_to_png(fig)` from `results.py` to convert to bytes for downloads. Standard figsize `(7, 4)`. +**Before writing any matplotlib code**, check this inventory. If the function you need exists here, call it. If it doesn't exist, add it here following the `create_figure()` style — then use it from both scripts and the UI. + +All functions accept typed result objects, return `matplotlib.Figure`, never import Solara. Use `fig_to_png(fig)` from `results.py` to convert to bytes for downloads. 
+ +| Function | Input | Use for | +|---|---|---| +| `plot_sweep_curve(result, metric, reference_lines)` | `SweepResult` | 1D sweep line chart with tipping point + optional scenario markers | +| `plot_sweep_heatmap(grid, x_values, y_values, ...)` | 2D `list[list[float]]` | 2D compliance heatmap (joint sensitivity) | +| `plot_mc_trajectory(result)` | `MonteCarloResult` | Compliance mean ± SD over steps | +| `plot_mc_violator_trajectory(result)` | `MonteCarloResult` | Violator count mean ± SD over steps | +| `plot_mc_audit_trajectory(result)` | `MonteCarloResult` | Audit rate band over steps | +| `plot_mc_payoff_comparison(result)` | `MonteCarloResult` | Compliant vs. violating lab payoff bar chart | +| `plot_compliance_distribution(df)` | agents DataFrame | Bar chart: Compliant / Uncaught / Caught-by-source | +| `plot_audit_source_distribution(df)` | agents DataFrame | Bar chart: labs caught per AuditSource channel | +| `plot_audit_targeting(rates, counts, ...)` | scalar rates | Compliant vs. non-compliant audit rate bar | +| `plot_audit_coefficient_distribution(df)` | agents DataFrame | Histogram of per-lab audit coefficients | +| `plot_time_series(data, label, color_key)` | `pd.Series` | Generic single-series step chart | +| `plot_scatter(df, x_col, y_col, ...)` | DataFrame | Scatter with compliance coloring | + +All figures are created via `create_figure()` (standardized style, `Agg` backend). Never call `plt.figure()` or `plt.subplots()` in scripts. + +**Committed figure scripts** — see `scripts/README.md` for an index of existing scripts. Always check there before re-creating a script. ## Testing diff --git a/.agents/workflows/gen-figures.md b/.agents/workflows/gen-figures.md new file mode 100644 index 0000000..c52733d --- /dev/null +++ b/.agents/workflows/gen-figures.md @@ -0,0 +1,112 @@ +--- +description: Generate one or more figures for the paper or a report — enforces the thin-caller pattern where all plot logic lives in vis/plotting.py. 
+--- + +# Gen-Figures Workflow + +Use this workflow whenever you need to produce `.png` figures for the paper, a report, +or any committed output. Do **not** improvise — follow these steps in order. + +## Step 1 — Check `vis/plotting.py` first + +Open `project.md` and read the **Plots** section inventory table. +Find the function that matches the figure you need. + +- **Exists?** → go to Step 3. +- **Doesn't exist?** → you must add it to `vis/plotting.py` first (Step 2), then proceed. + +**Never write raw `plt.figure()` or `plt.subplots()` in a script or agent_workspace file.** +Use `create_figure()` from `vis/plotting.py` at minimum, and prefer a proper named function. + +## Step 2 — Add a missing function to `vis/plotting.py` (if needed) + +1. Follow the `create_figure()` style exactly — see existing functions for the pattern. +2. Accept typed result objects (`SweepResult`, `MonteCarloResult`, `pd.DataFrame`) — no raw dicts. +3. Return `matplotlib.Figure` (never call `plt.show()` or `plt.savefig()` inside the function). +4. Add it to the inventory table in `project.md` → Plots section. +5. Run `uv run ruff check . --fix && uv run mypy .` — fix any issues before proceeding. + +## Step 3 — Check `scripts/README.md` for an existing script + +Open `scripts/README.md`. +If a script already generates the figures you need (or close to it), **run that script** rather than writing a new one. + +```bash +uv run python scripts/<script_name>.py --out-dir agent_workspace/figures +``` + +If the existing script's parameters or scenarios need adjustment, edit it in place — don't create a duplicate. + +## Step 4 — Write a thin-caller script (if no existing script covers it) + +Create a new script in `scripts/` following the naming convention `gen_<topic>_figs.py`. 
+ +The script must follow the **thin-caller pattern**: +- All imports from `vis.plotting`, `services.*`, `schemas.*` +- No matplotlib setup — no `plt.figure()`, `plt.subplots()`, `matplotlib.use()` +- Each figure: call the `vis/plotting.py` function → `fig.savefig(out_dir / "name.png", dpi=150, bbox_inches="tight")` +- Accept `--out-dir` as a CLI argument (default: `agent_workspace/figures`) +- Print progress lines so it's easy to monitor + +Minimal template: +```python +"""Generate figures for the paper. + +Thin caller only — all plot logic lives in vis/plotting.py. +Output: agent_workspace/figures/<figure_name>.png + +Usage: + uv run python scripts/gen_<topic>_figs.py [--out-dir PATH] +""" +from __future__ import annotations +import argparse +from pathlib import Path + + +def main(out_dir: Path) -> None: + out_dir.mkdir(parents=True, exist_ok=True) + + from compute_permit_sim.services.config_manager import load_scenario + from compute_permit_sim.services.sweep import run_sweep + from compute_permit_sim.vis.plotting import plot_sweep_curve # add as needed + + base = load_scenario("basic/<scenario_name>.json") + result = run_sweep(base, "audit.base_prob", [...], n_runs=50) + fig = plot_sweep_curve(result) + fig.savefig(out_dir / "fig_<figure_name>.png", dpi=150, bbox_inches="tight") + print("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", type=Path, default=Path("agent_workspace/figures")) + args = parser.parse_args() + main(args.out_dir) +``` + +## Step 5 — Update `scripts/README.md` + +After writing or modifying a script, update the index in `scripts/README.md`: + +``` +| gen_<topic>_figs.py | Generates <figures> for Section X. Scenarios: <...>. 
| +``` + +## Step 6 — Run and verify + +// turbo +```bash +uv run python scripts/<script_name>.py --out-dir agent_workspace/figures +``` + +Check that: +- All expected `.png` files are created in `out_dir` +- No matplotlib warnings or errors in output +- Figures look correct (open them and inspect) + +## Step 7 — Commit the script + +```bash +git add scripts/<script_name>.py scripts/README.md +git commit -m "scripts: add figure generator" +``` diff --git a/pyproject.toml b/pyproject.toml index 01d9a80..742d309 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,3 +51,4 @@ python_files = ["test_*.py"] ignore_missing_imports = true check_untyped_defs = true plugins = ["pydantic.mypy"] +exclude = ["agent_workspace"] diff --git a/scenarios/basic/scenario_2_strict.json b/scenarios/basic/scenario_2_strict.json index 9375a1f..d2bb04d 100644 --- a/scenarios/basic/scenario_2_strict.json +++ b/scenarios/basic/scenario_2_strict.json @@ -12,9 +12,10 @@ }, "lab": { "capability_value": 40.0, - "racing_factor": 2.0 + "racing_factor": 2.0, + "audit_coefficient": 0.1 }, - "collateral_amount": 100.0, + "collateral_amount": 15.75, "market": { "fixed_price": 70.0 } diff --git a/scenarios/basic/scenario_3_smart.json b/scenarios/basic/scenario_3_smart.json index 12617d3..797c393 100644 --- a/scenarios/basic/scenario_3_smart.json +++ b/scenarios/basic/scenario_3_smart.json @@ -4,10 +4,13 @@ "steps": 10, "n_agents": 20, "audit": { - "base_prob": 0.2, + "base_prob": 0.1, "monitoring_prob": 0.2, "signal_dependent": true }, + "lab": { + "audit_coefficient": 0.5 + }, "collateral_amount": 15.75, "market": { "fixed_price": 2.0, diff --git a/scenarios/batch_test.json b/scenarios/batch_test.json index f303c95..31270c1 100644 --- a/scenarios/batch_test.json +++ b/scenarios/batch_test.json @@ -3,21 +3,21 @@ "description": "Demonstrates that feedback mechanisms (reputation, audit escalation) can drive compliance even under moderate enforcement.", "notes": "", "n_agents": 20, - "steps": 40, + "steps": 50, 
"flop_threshold": 1e25, "collateral_amount": 0.0, "audit": { - "base_prob": 0.3, + "base_prob": 0.20, "signal_dependent": false, "signal_exponent": 1.0, "false_positive_rate": 0.0, "false_negative_rate": 0.05, - "penalty_amount": 100.0, + "penalty_amount": 50.0, "backcheck_prob": 0.0, "whistleblower_prob": 0.0, "monitoring_prob": 0.0, "max_audits_per_step": null, - "audit_escalation": 1.5, + "audit_escalation": 0.5, "audit_decay_rate": 0.1 }, "market": { diff --git a/src/compute_permit_sim/schemas/batch.py b/src/compute_permit_sim/schemas/batch.py index 7f13e4e..f7afbde 100644 --- a/src/compute_permit_sim/schemas/batch.py +++ b/src/compute_permit_sim/schemas/batch.py @@ -28,6 +28,10 @@ class BatchColumnNames: STEP = "step" PARAM_PATH = "param_path" PARAM_VALUE = "param_value" + PARAM_X_PATH = "param_x_path" + PARAM_X_VALUE = "param_x_value" + PARAM_Y_PATH = "param_y_path" + PARAM_Y_VALUE = "param_y_value" N_RUNS = "n_runs" # Compliance @@ -215,3 +219,48 @@ def tipping_point(self, threshold: float = 0.95) -> float | None: if pt.result.avg_compliance.mean >= threshold: return pt.param_value return None + + +@dataclass(frozen=True) +class GridSweepResult: + """Results of a 2D joint-sensitivity parameter sweep over a scenario. + + Stores mean compliance at every (x, y) grid cell. + + Attributes: + grid: ``grid[y_idx][x_idx]`` = mean compliance fraction (0–1) + over ``n_runs`` seeds at parameter values + ``(x_values[x_idx], y_values[y_idx])``. + """ + + scenario_name: str + param_x_path: str # e.g. "audit.base_prob" + param_x_label: str # human-readable, e.g. "Base Audit Probability" + param_y_path: str # e.g. "collateral_amount" + param_y_label: str # human-readable, e.g. 
"Collateral K (M$)" + config: ScenarioConfig + x_values: list[float] # ordered x-axis values + y_values: list[float] # ordered y-axis values + grid: list[list[float]] # [y_idx][x_idx] = mean compliance in [0, 1] + n_runs: int + # Short unique identifier matching SimulationRun.sim_id convention + id: str = field(default_factory=lambda: str(uuid4())[:8]) + + def compliance_at(self, x: float, y: float) -> float | None: + """Return mean compliance for an exact (x, y) cell, or None if not found.""" + try: + x_idx = self.x_values.index(x) + y_idx = self.y_values.index(y) + except ValueError: + return None + return self.grid[y_idx][x_idx] + + @property + def compliance_min(self) -> float: + """Minimum mean compliance across all grid cells.""" + return min(v for row in self.grid for v in row) + + @property + def compliance_max(self) -> float: + """Maximum mean compliance across all grid cells.""" + return max(v for row in self.grid for v in row) diff --git a/src/compute_permit_sim/schemas/defaults.py b/src/compute_permit_sim/schemas/defaults.py index d6de624..4c2aba2 100644 --- a/src/compute_permit_sim/schemas/defaults.py +++ b/src/compute_permit_sim/schemas/defaults.py @@ -77,7 +77,7 @@ DEFAULT_SIGNAL_EXPONENT = 1.0 # # Stage 2: AUDIT OUTCOME — given audit, does it find a violation? 
-# p_catch_if_audited = (1 - FNR) + FNR × backcheck_prob +# p_catch_if_audited = 1 - FNR × (1 - backcheck_prob) × (1 - p_w) × (1 - p_m) DEFAULT_AUDIT_FALSE_POS_RATE = 0.0 # alpha: P(false alarm | compliant firm audited) DEFAULT_AUDIT_FALSE_NEG_RATE = 0.40 # beta: 40% miss rate in Minimal env # Penalty structure: diff --git a/src/compute_permit_sim/schemas/sweep_params.py b/src/compute_permit_sim/schemas/sweep_params.py index ec59138..d5fa652 100644 --- a/src/compute_permit_sim/schemas/sweep_params.py +++ b/src/compute_permit_sim/schemas/sweep_params.py @@ -170,6 +170,46 @@ class SweepParam: description="Upper bound of risk appetite multiplier (>1 = risk-seeking).", category="Agents", ), + SweepParam( + path="lab.capability_value", + label="Capability Race Premium V_b", + unit="M$", + default_min=0.0, + default_max=300.0, + default_step=20.0, + description="Strategic value of model capabilities from training (arms-race premium added to gain from cheating).", + category="Agents", + ), + SweepParam( + path="lab.racing_factor", + label="Racing Factor c_r", + unit="", + default_min=0.0, + default_max=5.0, + default_step=0.25, + description="Urgency multiplier on capability value; higher = stronger competitive pressure to cheat.", + category="Agents", + ), + SweepParam( + path="lab.reputation_escalation_factor", + label="Reputation Escalation Factor", + unit="", + default_min=0.0, + default_max=5.0, + default_step=0.25, + description="Per-violation multiplier on reputation cost: rep_t = base × (1+factor)^n_caught. 0 = no escalation.", + category="Agents", + ), + SweepParam( + path="lab.reputation_sensitivity", + label="Reputation Sensitivity R", + unit="M$", + default_min=0.0, + default_max=100.0, + default_step=5.0, + description="Base reputation cost per violation (M$). 
Compounds with reputation_escalation_factor.", + category="Agents", + ), # --- Dynamics --- SweepParam( path="audit.signal_exponent", diff --git a/src/compute_permit_sim/services/sweep.py b/src/compute_permit_sim/services/sweep.py index 5d558f3..7d4ef07 100644 --- a/src/compute_permit_sim/services/sweep.py +++ b/src/compute_permit_sim/services/sweep.py @@ -18,7 +18,7 @@ from __future__ import annotations -from compute_permit_sim.schemas.batch import SweepPoint, SweepResult +from compute_permit_sim.schemas.batch import GridSweepResult, SweepPoint, SweepResult from compute_permit_sim.schemas.config import ScenarioConfig from compute_permit_sim.services.monte_carlo import run_monte_carlo @@ -143,3 +143,118 @@ def run_sweep_from_registry( n_runs=n_runs, seeds=seeds, ) + + +def run_grid_sweep( + base_config: ScenarioConfig, + param_x_path: str, + param_y_path: str, + x_values: list[float], + y_values: list[float], + param_x_label: str | None = None, + param_y_label: str | None = None, + n_runs: int = 20, + seeds: list[int] | None = None, +) -> GridSweepResult: + """Run a 2D joint-sensitivity sweep over two parameters. + + Each (x, y) cell is evaluated with ``n_runs`` Monte Carlo replications. + Results are stored as ``grid[y_idx][x_idx] = mean_compliance``. + + All seeds are shared across all cells so that parameter variation, not + noise, drives differences between cells. + + Args: + base_config: Base scenario configuration. + param_x_path: Dot-path for the x-axis parameter, e.g. ``"audit.base_prob"``. + param_y_path: Dot-path for the y-axis parameter, e.g. ``"collateral_amount"``. + x_values: Ordered x-axis values. + y_values: Ordered y-axis values. + param_x_label: Human-readable x-axis label; defaults to ``param_x_path``. + param_y_label: Human-readable y-axis label; defaults to ``param_y_path``. + n_runs: MC replications per cell. Ignored if ``seeds`` is provided. + seeds: Explicit seeds; overrides ``n_runs`` if given. 
+ + Returns: + :class:`~compute_permit_sim.schemas.batch.GridSweepResult` with the 2D + compliance grid and axis metadata. + """ + label_x = param_x_label or param_x_path + label_y = param_y_label or param_y_path + run_seeds = seeds if seeds is not None else list(range(n_runs)) + + # grid[y_idx][x_idx] = mean compliance + grid: list[list[float]] = [] + for y in y_values: + row: list[float] = [] + for x in x_values: + cfg = override_config(base_config, param_x_path, x) + cfg = override_config(cfg, param_y_path, y) + mc = run_monte_carlo(cfg, seeds=run_seeds) + row.append(mc.avg_compliance.mean) + grid.append(row) + + return GridSweepResult( + scenario_name=base_config.name, + param_x_path=param_x_path, + param_x_label=label_x, + param_y_path=param_y_path, + param_y_label=label_y, + config=base_config, + x_values=list(x_values), + y_values=list(y_values), + grid=grid, + n_runs=len(run_seeds), + ) + + +def run_grid_sweep_from_registry( + base_config: ScenarioConfig, + param_x_path: str, + param_y_path: str, + x_min: float | None = None, + x_max: float | None = None, + x_step: float | None = None, + y_min: float | None = None, + y_max: float | None = None, + y_step: float | None = None, + n_runs: int = 20, + seeds: list[int] | None = None, +) -> GridSweepResult: + """Run a 2D grid sweep using registry defaults for both axis ranges. + + Looks up each path in ``SWEEPABLE_PARAMS`` to fill in default + min/max/step values. Any supplied arguments override those defaults. + + Args: + base_config: Base scenario configuration. + param_x_path: Dot-path registered in ``SWEEPABLE_PARAMS`` for x-axis. + param_y_path: Dot-path registered in ``SWEEPABLE_PARAMS`` for y-axis. + x_min/x_max/x_step: Override registry defaults for x-axis. + y_min/y_max/y_step: Override registry defaults for y-axis. + n_runs: MC replications per cell. + seeds: Explicit seeds (overrides n_runs if provided). + + Returns: + :class:`~compute_permit_sim.schemas.batch.GridSweepResult`. 
+ + Raises: + KeyError: If either path is not in the registry. + """ + from compute_permit_sim.schemas.sweep_params import generate_values, get_param + + px = get_param(param_x_path) + py = get_param(param_y_path) + x_values = generate_values(px, min_val=x_min, max_val=x_max, step=x_step) + y_values = generate_values(py, min_val=y_min, max_val=y_max, step=y_step) + return run_grid_sweep( + base_config, + param_x_path=px.path, + param_y_path=py.path, + x_values=x_values, + y_values=y_values, + param_x_label=px.label, + param_y_label=py.label, + n_runs=n_runs, + seeds=seeds, + ) diff --git a/src/compute_permit_sim/vis/components/history.py b/src/compute_permit_sim/vis/components/history.py index f8b81c6..30c2fcb 100644 --- a/src/compute_permit_sim/vis/components/history.py +++ b/src/compute_permit_sim/vis/components/history.py @@ -27,12 +27,13 @@ def UnifiedHistoryList() -> None: (BatchHistoryList + RunHistoryList) pattern which caused double-nested ``run-history-compact`` for batch items and mismatched styling. """ - from compute_permit_sim.vis.state.run_state import mc_run, sweep_run + from compute_permit_sim.vis.state.run_state import grid_run, mc_run, sweep_run batch_results = session_history.batch_results.value run_history = session_history.run_history.value mc_current = mc_run.value.result sweep_current = sweep_run.value.result + grid_current = grid_run.value.result # Use Markdown for empty state — matches RunHistoryList convention and avoids # alternating root container types (Column A vs Column B) which reacton rejects. 
@@ -42,7 +43,11 @@ def UnifiedHistoryList() -> None: with solara.Column(classes=["run-history-compact"]): for result in batch_results: - is_current = (result is mc_current) or (result is sweep_current) + is_current = ( + (result is mc_current) + or (result is sweep_current) + or (result is grid_current) + ) BatchHistoryItem(result, is_current) for run in run_history: is_selected = (session_history.selected_run.value is not None) and ( diff --git a/src/compute_permit_sim/vis/components/history_items.py b/src/compute_permit_sim/vis/components/history_items.py index 2ff8eb4..b810677 100644 --- a/src/compute_permit_sim/vis/components/history_items.py +++ b/src/compute_permit_sim/vis/components/history_items.py @@ -11,7 +11,11 @@ import solara.lab from compute_permit_sim.schemas import SimulationRun -from compute_permit_sim.schemas.batch import MonteCarloResult, SweepResult +from compute_permit_sim.schemas.batch import ( + GridSweepResult, + MonteCarloResult, + SweepResult, +) from compute_permit_sim.services.config_manager import save_scenario from compute_permit_sim.vis.components.dialogs import RunConfigDialog from compute_permit_sim.vis.components.results import DownloadCSV, DownloadExcel @@ -143,12 +147,17 @@ def perform_save() -> None: @solara.component def BatchHistoryItem(result: BatchResult, is_current: bool) -> None: - """One-line history row for an MC or Sweep batch result. + """One-line history row for an MC, Sweep, or Grid Sweep batch result. - Mirrors ``RunHistoryItem`` exactly: type-icon | ⓘ | id-label | save | Excel | CSV | JSON. - The short ``result.id`` is displayed as the label; full details are in the ⓘ dialog. + Mirrors ``RunHistoryItem`` exactly: type-icon | \u24d8 | id-label | save | Excel | CSV | JSON. + The short ``result.id`` is displayed as the label; full details are in the \u24d8 dialog. 
""" - from compute_permit_sim.vis.state.run_state import RunState, mc_run, sweep_run + from compute_permit_sim.vis.state.run_state import ( + RunState, + grid_run, + mc_run, + sweep_run, + ) if isinstance(result, MonteCarloResult): type_icon = "mdi-chart-bell-curve-cumulative" @@ -181,7 +190,7 @@ def dl_excel() -> bytes | str: return export_monte_carlo_to_excel(result, output_path="") - else: # SweepResult + elif isinstance(result, SweepResult): type_icon = "mdi-trending-up" dialog_title = f"Sweep Run: {result.id}" tp = result.tipping_point() @@ -217,6 +226,42 @@ def dl_excel() -> bytes | str: return export_sweep_to_excel(result, output_path="") + else: # GridSweepResult + type_icon = "mdi-view-grid" + dialog_title = f"Grid Sweep: {result.id}" + safe_s = result.scenario_name.lower().replace(" ", "_") + safe_x = result.param_x_path.replace(".", "_") + safe_y = result.param_y_path.replace(".", "_") + batch_summary = ( + f"**{len(result.x_values)}×{len(result.y_values)} grid sweep**" + f" · {result.scenario_name} \n" + f"X: **{result.param_x_label}** \n" + f"Y: **{result.param_y_label}** \n" + f"Compliance: {result.compliance_min:.1%}–{result.compliance_max:.1%}" + ) + csv_fname = f"grid_{safe_s}_{safe_x}_x_{safe_y}_{result.id}.csv" + xlsx_fname = f"grid_{safe_s}_{safe_x}_x_{safe_y}_{result.id}.xlsx" + + def view() -> None: + session_history.selected_run.value = None # clear basic run highlight + mc_run.set(RunState[MonteCarloResult](phase="idle")) + sweep_run.set(RunState[SweepResult](phase="idle")) + grid_run.set(RunState[GridSweepResult](phase="ready", result=result)) + + def dl_csv() -> bytes | str: + from compute_permit_sim.vis.export import ( + export_grid_sweep_to_csv, # noqa: PLC0415 + ) + + return export_grid_sweep_to_csv(result, output_path="") + + def dl_excel() -> bytes | str: + from compute_permit_sim.vis.export import ( + export_grid_sweep_to_excel, # noqa: PLC0415 + ) + + return export_grid_sweep_to_excel(result, output_path="") + # use_state calls must 
be unconditional (Solara hook rules) — always before any return show_save, set_show_save = solara.use_state(False) save_name, set_save_name = solara.use_state(f"scenario_{result.id}") diff --git a/src/compute_permit_sim/vis/export.py b/src/compute_permit_sim/vis/export.py index 06d91a7..23b27a4 100644 --- a/src/compute_permit_sim/vis/export.py +++ b/src/compute_permit_sim/vis/export.py @@ -21,6 +21,10 @@ Sweep: export_sweep_to_excel — Excel with Config/Sweep/Graphs export_sweep_to_csv — one row per parameter value + +Grid Sweep (2D heatmap): + export_grid_sweep_to_csv — long-format CSV, one row per grid cell + export_grid_sweep_to_excel — Excel with Config/Grid pivot/Heatmap image """ import io @@ -35,12 +39,13 @@ BatchColumnNames as _BCN, ) from compute_permit_sim.schemas.batch import ( - MetricStats as _MetricStats, -) -from compute_permit_sim.schemas.batch import ( + GridSweepResult, MonteCarloResult, SweepResult, ) +from compute_permit_sim.schemas.batch import ( + MetricStats as _MetricStats, +) from compute_permit_sim.schemas.columns import ColumnNames from compute_permit_sim.services.metrics import calculate_compliance from compute_permit_sim.vis.plotting import ( @@ -1143,3 +1148,138 @@ def _write_sweep_graphs_sheet(sheet, result: SweepResult, workbook) -> None: sheet.write(0, 9, f"Audit Rate vs {param_label}") sheet.insert_image(1, 9, "sweep_audit.png", {"image_data": _fig_to_bytes(fig2)}) plt.close(fig2) + + +# ============================================================================= +# Grid Sweep (2D Heatmap) Exports +# ============================================================================= + + +def export_grid_sweep_to_csv( + result: GridSweepResult, + output_path: str | None = None, +) -> "str | bytes": + """Export a GridSweepResult to long-format CSV with one row per grid cell. + + Columns: scenario, param_x_path, param_x_value, param_y_path, + param_y_value, n_runs, compliance. + + Args: + result: A ``GridSweepResult`` instance. 
+ output_path: ``None`` = auto-generate, ``""`` = return bytes. + """ + rows = [ + { + _BCN.SCENARIO: result.scenario_name, + _BCN.PARAM_X_PATH: result.param_x_path, + _BCN.PARAM_X_VALUE: x, + _BCN.PARAM_Y_PATH: result.param_y_path, + _BCN.PARAM_Y_VALUE: y, + _BCN.N_RUNS: result.n_runs, + _BCN.COMPLIANCE_RATE: result.grid[y_idx][x_idx], + } + for y_idx, y in enumerate(result.y_values) + for x_idx, x in enumerate(result.x_values) + ] + + df = _pd.DataFrame(rows) + if output_path == "": + return df.to_csv(index=False).encode("utf-8") + if output_path is None: + _os.makedirs("outputs", exist_ok=True) + safe_s = result.scenario_name.lower().replace(" ", "_") + safe_x = result.param_x_path.replace(".", "_") + safe_y = result.param_y_path.replace(".", "_") + output_path = f"outputs/grid_{safe_s}_{safe_x}_x_{safe_y}.csv" + df.to_csv(output_path, index=False) + return output_path + + +def export_grid_sweep_to_excel( + result: GridSweepResult, + output_path: str | None = None, +) -> "str | bytes": + """Export a GridSweepResult to a formatted Excel workbook. + + Sheets: + ``Config`` — base scenario configuration + ``Grid`` — pivot table: rows=y_values, cols=x_values, cells=compliance% + ``Heatmap`` — embedded PNG of the compliance heatmap + + Args: + result: A ``GridSweepResult`` instance. + output_path: ``None`` = auto-generate, ``""`` = return bytes. 
+ """ + import io as _io + + import xlsxwriter as _xlsxwriter + + from compute_permit_sim.vis.plotting import plot_sweep_heatmap + + return_bytes = output_path == "" + output: _io.BytesIO | str + if return_bytes: + output = _io.BytesIO() + elif output_path is None: + _os.makedirs("outputs", exist_ok=True) + safe_s = result.scenario_name.lower().replace(" ", "_") + safe_x = result.param_x_path.replace(".", "_") + safe_y = result.param_y_path.replace(".", "_") + output_path = f"outputs/grid_{safe_s}_{safe_x}_x_{safe_y}.xlsx" + output = output_path + else: + output = output_path + + workbook = _xlsxwriter.Workbook(output) + header_fmt = workbook.add_format( + {"bold": True, "bg_color": "#2196F3", "font_color": "white", "border": 1} + ) + data_fmt = workbook.add_format({"border": 1}) + pct_fmt = workbook.add_format({"border": 1, "num_format": "0.0%"}) + + try: + # === Config sheet === + if result.config is not None: + cfg_sheet = workbook.add_worksheet("Config") + _write_config_sheet(cfg_sheet, result.config, header_fmt, data_fmt) + + # === Grid (pivot) sheet === + grid_sheet = workbook.add_worksheet("Grid") + grid_sheet.set_column("A:A", 20) + grid_sheet.write( + 0, + 0, + f"{result.param_x_label} \u2192 / {result.param_y_label} \u2193", + header_fmt, + ) + for x_idx, x in enumerate(result.x_values): + grid_sheet.write(0, x_idx + 1, x, header_fmt) + for y_idx, y in enumerate(result.y_values): + grid_sheet.write(y_idx + 1, 0, y, data_fmt) + for x_idx, compliance in enumerate(result.grid[y_idx]): + grid_sheet.write(y_idx + 1, x_idx + 1, compliance, pct_fmt) + + # === Heatmap sheet === + heatmap_sheet = workbook.add_worksheet("Heatmap") + fig = plot_sweep_heatmap( + compliance_grid=result.grid, + x_values=result.x_values, + y_values=result.y_values, + x_param_label=result.param_x_label, + y_param_label=result.param_y_label, + title=f"Compliance Heatmap \u2014 {result.scenario_name}", + ) + heatmap_sheet.insert_image( + 0, 0, "heatmap.png", {"image_data": 
_fig_to_bytes(fig)} + ) + + finally: + workbook.close() + + if return_bytes: + assert isinstance(output, _io.BytesIO) + output.seek(0) + return output.read() + + assert output_path is not None + return output_path diff --git a/src/compute_permit_sim/vis/page.py b/src/compute_permit_sim/vis/page.py index f01dc29..92a53e5 100644 --- a/src/compute_permit_sim/vis/page.py +++ b/src/compute_permit_sim/vis/page.py @@ -4,8 +4,10 @@ basic_run.phase == "running" → RunSpinner (basic sim) mc_run.phase == "running" → RunSpinner (Monte Carlo) sweep_run.phase == "running" → RunSpinner (Sweep) + grid_run.phase == "running" → RunSpinner (Grid Sweep) mc_run.phase == "ready" → BatchResultsPanel sweep_run.phase == "ready" → BatchResultsPanel + grid_run.phase == "ready" → BatchResultsPanel basic_run.phase == "ready" OR history → AnalysisPanel else → EmptyState """ @@ -27,7 +29,12 @@ from compute_permit_sim.vis.panels.batch_results import BatchResultsPanel from compute_permit_sim.vis.panels.config import ConfigPanel from compute_permit_sim.vis.state.history import session_history -from compute_permit_sim.vis.state.run_state import basic_run, mc_run, sweep_run +from compute_permit_sim.vis.state.run_state import ( + basic_run, + grid_run, + mc_run, + sweep_run, +) configure_logging() logger = logging.getLogger(__name__) @@ -109,12 +116,13 @@ def toggle_theme(): basic = basic_run.value mc = mc_run.value sw = sweep_run.value + gr = grid_run.value if basic.is_running: RunSpinner("Simulating\u2026") - elif mc.is_running or sw.is_running: + elif mc.is_running or sw.is_running or gr.is_running: RunSpinner("Running batch analysis\u2026") - elif mc.is_ready or sw.is_ready: + elif mc.is_ready or sw.is_ready or gr.is_ready: BatchResultsPanel() elif basic.is_ready or session_history.selected_run.value is not None: AnalysisPanel() diff --git a/src/compute_permit_sim/vis/panels/batch.py b/src/compute_permit_sim/vis/panels/batch.py index e906ed9..ad958e2 100644 --- 
a/src/compute_permit_sim/vis/panels/batch.py +++ b/src/compute_permit_sim/vis/panels/batch.py @@ -28,7 +28,7 @@ ) from compute_permit_sim.vis.components.history import UnifiedHistoryList from compute_permit_sim.vis.components.results import SidebarLabel -from compute_permit_sim.vis.state.run_state import RunState, mc_run, sweep_run +from compute_permit_sim.vis.state.run_state import RunState, grid_run, mc_run, sweep_run # Pre-built lookup map (module-level constant — registry never changes at runtime) _PARAM_MAP: dict[str, SweepParam] = {p.path: p for p in SWEEPABLE_PARAMS} @@ -38,6 +38,7 @@ # --------------------------------------------------------------------------- _mc_status = solara.reactive("") _sweep_status = solara.reactive("") +_grid_status = solara.reactive("") # --------------------------------------------------------------------------- # Background workers @@ -140,6 +141,59 @@ def _run_sweep_background( _sweep_status.set(f"Error: {e}") +def _run_grid_background( + scenario_name: str, + param_x: SweepParam, + param_y: SweepParam, + x_values: list[float], + y_values: list[float], + n_runs: int, +) -> None: + """Run 2D grid sweep off the event loop thread and update grid_run reactive.""" + from compute_permit_sim.schemas.batch import GridSweepResult + from compute_permit_sim.services.sweep import run_grid_sweep + + try: + config = _load_scenario_by_name(scenario_name) + + if config is None: + _grid_status.set(f"Scenario '{scenario_name}' not found.") + grid_run.set(RunState[GridSweepResult](phase="idle")) + return + + n_cells = len(x_values) * len(y_values) + _grid_status.set( + f"Grid {len(x_values)}×{len(y_values)} = {n_cells} cells × {n_runs} runs..." 
+ ) + + result = run_grid_sweep( + config, + param_x_path=param_x.path, + param_y_path=param_y.path, + x_values=x_values, + y_values=y_values, + param_x_label=param_x.label, + param_y_label=param_y.label, + n_runs=n_runs, + ) + + from compute_permit_sim.vis.state.history import ( + session_history, # noqa: PLC0415 + ) + + session_history.add_batch_result(result) + grid_run.set(RunState[GridSweepResult](phase="ready", result=result)) + _grid_status.set( + f"Done: {len(x_values)}×{len(y_values)} grid — " + f"compliance {result.compliance_min:.1%}–{result.compliance_max:.1%}" + ) + except Exception as e: # noqa: BLE001 + from compute_permit_sim.schemas.batch import GridSweepResult + + grid_run.set(RunState[GridSweepResult](phase="idle")) + _grid_status.set(f"Error: {e}") + + # --------------------------------------------------------------------------- # Sub-components # --------------------------------------------------------------------------- @@ -391,9 +445,220 @@ def on_param_label_change(label: str) -> None: # --------------------------------------------------------------------------- +@solara.component +def _GridSweepCard(scenario_names: list[str]) -> Any: + """Sidebar card for configuring and launching a 2D grid sweep.""" + selected_scenario, set_selected_scenario = solara.use_state( + scenario_names[0] if scenario_names else "" + ) + + all_categories = categories() + + # --- X-axis param --- + cat_x, set_cat_x = solara.use_state(all_categories[0] if all_categories else "") + params_x = params_for_category(cat_x) + path_x, set_path_x = solara.use_state(params_x[0].path if params_x else "") + param_x = _PARAM_MAP.get(path_x) + min_x, set_min_x = solara.use_state(param_x.default_min if param_x else 0.0) + max_x, set_max_x = solara.use_state(param_x.default_max if param_x else 1.0) + step_x, set_step_x = solara.use_state(param_x.default_step if param_x else 0.1) + + # --- Y-axis param --- + cat_y, set_cat_y = solara.use_state(all_categories[0] if all_categories 
else "") + params_y = params_for_category(cat_y) + path_y, set_path_y = solara.use_state(params_y[0].path if params_y else "") + param_y = _PARAM_MAP.get(path_y) + min_y, set_min_y = solara.use_state(param_y.default_min if param_y else 0.0) + max_y, set_max_y = solara.use_state(param_y.default_max if param_y else 1.0) + step_y, set_step_y = solara.use_state(param_y.default_step if param_y else 0.1) + + n_runs, set_n_runs = solara.use_state(20) + + is_running = grid_run.value.is_running + status = _grid_status.value + + # Compute preview — both axes must be valid + n_pts_x, n_pts_y = 0, 0 + preview_error = "" + if param_x and step_x > 0 and min_x <= max_x: + try: + n_pts_x = len(generate_values(param_x, min_x, max_x, step_x)) + except Exception: + preview_error = "Invalid X range" + if param_y and step_y > 0 and min_y <= max_y: + try: + n_pts_y = len(generate_values(param_y, min_y, max_y, step_y)) + except Exception: + preview_error = "Invalid Y range" + + def _on_cat_x(cat: str) -> None: + set_cat_x(cat) + ps = params_for_category(cat) + if ps: + set_path_x(ps[0].path) + set_min_x(ps[0].default_min) + set_max_x(ps[0].default_max) + set_step_x(ps[0].default_step) + + def _on_path_x(label: str) -> None: + path = {p.label: p.path for p in params_for_category(cat_x)}.get(label, "") + set_path_x(path) + p = _PARAM_MAP.get(path) + if p: + set_min_x(p.default_min) + set_max_x(p.default_max) + set_step_x(p.default_step) + + def _on_cat_y(cat: str) -> None: + set_cat_y(cat) + ps = params_for_category(cat) + if ps: + set_path_y(ps[0].path) + set_min_y(ps[0].default_min) + set_max_y(ps[0].default_max) + set_step_y(ps[0].default_step) + + def _on_path_y(label: str) -> None: + path = {p.label: p.path for p in params_for_category(cat_y)}.get(label, "") + set_path_y(path) + p = _PARAM_MAP.get(path) + if p: + set_min_y(p.default_min) + set_max_y(p.default_max) + set_step_y(p.default_step) + + def on_run() -> None: + if not param_x or not param_y: + return + try: + x_vals = 
generate_values(param_x, min_x, max_x, step_x) + y_vals = generate_values(param_y, min_y, max_y, step_y) + except ValueError: + _grid_status.set("Invalid range — check min/max/step for both axes.") + return + from compute_permit_sim.schemas.batch import ( # noqa: PLC0415 + GridSweepResult, + MonteCarloResult, + SweepResult, + ) + + grid_run.set(RunState[GridSweepResult](phase="running")) + mc_run.set(RunState[MonteCarloResult](phase="idle")) + sweep_run.set(RunState[SweepResult](phase="idle")) + _grid_status.set("Starting...") + threading.Thread( + target=_run_grid_background, + args=(selected_scenario, param_x, param_y, x_vals, y_vals, n_runs), + daemon=True, + ).start() + + with solara.Card(title="Grid Sweep"): + if not scenario_names: + with solara.Column(classes=["sidebar-empty-text"]): + solara.Text("No scenarios found.") + return + + solara.Select( + label="Scenario", + values=scenario_names, + value=selected_scenario, + on_value=set_selected_scenario, + dense=True, + ) + + # ── X-axis ────────────────────────────────────────────────────────── + with solara.Column(classes=["sidebar-hint-text"]): + solara.Text("X-axis parameter") + with solara.Row(style="gap: 4px;"): + solara.Select( + label="Category", + values=all_categories, + value=cat_x, + on_value=_on_cat_x, + dense=True, + ) + labels_x = [p.label for p in params_for_category(cat_x)] + solara.Select( + label="Parameter", + values=labels_x, + value=param_x.label if param_x else (labels_x[0] if labels_x else ""), + on_value=_on_path_x, + dense=True, + ) + with solara.Row(style="gap: 4px;"): + unit_x = param_x.unit if param_x else "" + solara.InputFloat(label=f"Min ({unit_x})", value=min_x, on_value=set_min_x) + solara.InputFloat(label=f"Max ({unit_x})", value=max_x, on_value=set_max_x) + solara.InputFloat(label="Step", value=step_x, on_value=set_step_x) + + # ── Y-axis ────────────────────────────────────────────────────────── + with solara.Column(classes=["sidebar-hint-text"]): + solara.Text("Y-axis 
parameter") + with solara.Row(style="gap: 4px;"): + solara.Select( + label="Category", + values=all_categories, + value=cat_y, + on_value=_on_cat_y, + dense=True, + ) + labels_y = [p.label for p in params_for_category(cat_y)] + solara.Select( + label="Parameter", + values=labels_y, + value=param_y.label if param_y else (labels_y[0] if labels_y else ""), + on_value=_on_path_y, + dense=True, + ) + with solara.Row(style="gap: 4px;"): + unit_y = param_y.unit if param_y else "" + solara.InputFloat(label=f"Min ({unit_y})", value=min_y, on_value=set_min_y) + solara.InputFloat(label=f"Max ({unit_y})", value=max_y, on_value=set_max_y) + solara.InputFloat(label="Step", value=step_y, on_value=set_step_y) + + # ── Replications + simulation count preview ────────────────────────── + solara.SliderInt( + label=f"Runs per cell: {n_runs}", + value=n_runs, + on_value=set_n_runs, + min=5, + max=100, + step=5, + ) + if preview_error: + with solara.Column(classes=["sidebar-error-text"]): + solara.Text(preview_error) + elif n_pts_x > 0 and n_pts_y > 0: + total = n_pts_x * n_pts_y * n_runs + with solara.Column(classes=["sidebar-hint-text"]): + solara.Text( + f"{n_pts_x}\u00d7{n_pts_y} = {n_pts_x * n_pts_y} cells" + f" \u00d7 {n_runs} = {total:,} total simulations" + ) + + solara.Button( + "Running..." 
if is_running else "Run Grid Sweep", + on_click=on_run, + color="primary", + block=True, + disabled=is_running + or not selected_scenario + or not param_x + or not param_y + or n_pts_x == 0 + or n_pts_y == 0, + small=True, + ) + if status and ( + "Error" in status or "not found" in status or "Invalid" in status + ): + with solara.Column(classes=["sidebar-error-text"]): + solara.Text(status) + + @solara.component def BatchPanel() -> Any: - """Sidebar panel with Monte Carlo and Parameter Sweep configurators.""" + """Sidebar panel with Monte Carlo, Parameter Sweep, and Grid Sweep configurators.""" from compute_permit_sim.vis.state.history import session_history # noqa: PLC0415 # Use the same name map as LoadScenarioDialog for consistency @@ -403,6 +668,7 @@ def BatchPanel() -> Any: SidebarLabel("**BATCH ANALYSIS**") _MonteCarloCard(scenario_names=scenario_names) _SweepCard(scenario_names=scenario_names) + _GridSweepCard(scenario_names=scenario_names) # ── History — batch results + individual runs in one stream ──────── solara.Markdown("---") diff --git a/src/compute_permit_sim/vis/plotting.py b/src/compute_permit_sim/vis/plotting.py index ba5d4cc..d0ca874 100644 --- a/src/compute_permit_sim/vis/plotting.py +++ b/src/compute_permit_sim/vis/plotting.py @@ -15,6 +15,7 @@ import textwrap import matplotlib +import numpy as np import pandas as pd from matplotlib.axes import Axes from matplotlib.figure import Figure @@ -642,15 +643,22 @@ def plot_mc_payoff_comparison(result) -> "Figure": return fig -def plot_sweep_curve(result, metric: str = "avg_compliance") -> "Figure": +def plot_sweep_curve( + result, + metric: str = "avg_compliance", + reference_lines: list[tuple[float, str, str]] | None = None, +) -> "Figure": """Plot a 1D parameter sweep curve: param value on X, metric on Y. Renders the mean as a line with ± 1 SD shading. Annotates the tipping - point (first value where compliance ≥ 95 %) if present. + point (first value where compliance ≥ 95 %) if present. 
Args: result: A ``SweepResult`` instance. metric: Attribute name on ``MonteCarloResult`` to plot (default: avg_compliance). + reference_lines: Optional list of ``(x_value, label, color)`` tuples + for annotating known calibration points (e.g. scenario pa values). + Each draws a vertical dotted line with a small text label. Returns: Matplotlib Figure. @@ -670,19 +678,7 @@ def plot_sweep_curve(result, metric: str = "avg_compliance") -> "Figure": color = CHART_COLOR_MAP.get("compliant", "#42A5F5") ax.plot(xs, means, color=color, linewidth=2, marker="o", markersize=5, label="Mean") - ax.fill_between(xs, lows, highs, alpha=0.18, color=color, label="± 1 SD") - - tp = result.tipping_point(threshold=0.95) - if tp is not None: - ax.axvline(tp, color="#FFA726", linewidth=1.5, linestyle="--") - ax.annotate( - f"Tipping ≈ {tp:.3f}", - xy=(tp, 0.95), - xytext=(tp, 0.70), - fontsize=8, - color="#FFA726", - arrowprops={"arrowstyle": "->", "color": "#FFA726"}, - ) + ax.fill_between(xs, lows, highs, alpha=0.18, color=color, label="\u00b1 1 SD") is_compliance = "compliance" in metric if is_compliance: @@ -694,10 +690,145 @@ def plot_sweep_curve(result, metric: str = "avg_compliance") -> "Figure": ax.set_xlabel(_wrap(result.param_label, width=40)) ax.set_title( - _wrap(f"Sensitivity: {result.param_label} — {result.scenario_name}"), + _wrap(f"Sensitivity: {result.param_label} \u2014 {result.scenario_name}"), fontsize=11, fontweight="600", ) ax.legend(fontsize=9) fig.tight_layout() return fig + + +def plot_sweep_heatmap( + compliance_grid: list[list[float]], + x_values: list[float], + y_values: list[float], + x_param_label: str = "Base Audit Rate \u03c0\u2080", + y_param_label: str = "Collateral K (M$)", + x_tick_labels: list[str] | None = None, + y_tick_labels: list[str] | None = None, + title: str | None = None, + highlight: tuple[float, float] | None = None, + highlight_label: str = "Calibration", +) -> "Figure": + """Heatmap of average compliance over a 2D parameter grid. 
+ + Renders each cell with its mean compliance rate as a shaded colour and an + inline percentage annotation. Designed for joint-sensitivity analysis + (e.g. pa x K grid) and re-usable for any two-parameter sweep. + + Args: + compliance_grid: 2D list ``[y_idx][x_idx]`` of mean compliance fractions. + x_values: Parameter values along the x-axis (e.g. audit rates). + y_values: Parameter values along the y-axis (e.g. collateral amounts). + x_param_label: Human-readable x-axis label. + y_param_label: Human-readable y-axis label. + x_tick_labels: Optional custom tick labels for x-axis; defaults to + auto-formatted ``x_values`` as percentages. + y_tick_labels: Optional custom tick labels for y-axis; defaults to + auto-formatted ``y_values`` as dollar amounts. + title: Optional chart title. + highlight: Optional ``(x_val, y_val)`` calibration point to outline + with a red border. + highlight_label: Label shown adjacent to the highlighted cell. + + Returns: + Matplotlib Figure. + """ + import matplotlib.patches as mpatches + + fig, ax = create_figure(figsize=(7, 5)) + data = np.array(compliance_grid) # shape: (n_y, n_x) + + im = ax.imshow( + data, + aspect="auto", + origin="lower", + cmap="Blues", + vmin=0.0, + vmax=1.0, + interpolation="nearest", + ) + + # Colorbar with shared percent formatter + cbar = fig.colorbar( + im, ax=ax, format=matplotlib.ticker.PercentFormatter(xmax=1), shrink=0.85 + ) + cbar.set_label("Mean Compliance Rate", fontsize=10) + + # Tick labels — default to % for x (audit rate) and $M for y (collateral) + xt_labels = x_tick_labels or [f"{v:.0%}" for v in x_values] + yt_labels = y_tick_labels or [f"${v:.0f}M" for v in y_values] + ax.set_xticks(range(len(x_values))) + ax.set_xticklabels(xt_labels, fontsize=8, rotation=45, ha="right") + ax.set_yticks(range(len(y_values))) + ax.set_yticklabels(yt_labels, fontsize=8) + + # Per-cell compliance annotation + for yi in range(len(y_values)): + for xi in range(len(x_values)): + val = float(data[yi, xi]) + 
text_color = "white" if val > 0.65 else "#333333" + ax.text( + xi, + yi, + f"{val:.0%}", + ha="center", + va="center", + fontsize=7, + color=text_color, + fontweight="500", + ) + + # Optional highlight: red border around a calibration cell + if highlight is not None: + hx_val, hy_val = highlight + hx_idx = min(range(len(x_values)), key=lambda i: abs(x_values[i] - hx_val)) + hy_idx = min(range(len(y_values)), key=lambda i: abs(y_values[i] - hy_val)) + rect = mpatches.FancyBboxPatch( + (hx_idx - 0.45, hy_idx - 0.45), + 0.9, + 0.9, + boxstyle="square,pad=0", + linewidth=2.5, + edgecolor=CHART_COLOR_MAP.get("violator", "#EF5350"), + facecolor="none", + zorder=3, + ) + ax.add_patch(rect) + ax.text( + hx_idx, + hy_idx + 0.52, + highlight_label, + ha="center", + va="bottom", + fontsize=7, + color=CHART_COLOR_MAP.get("violator", "#EF5350"), + fontweight="bold", + zorder=4, + ) + + ax.set_xlabel(_wrap(x_param_label, width=40), fontsize=11, fontweight="500") + ax.set_ylabel(_wrap(y_param_label, width=30), fontsize=11, fontweight="500") + if title: + ax.set_title(_wrap(title), fontsize=11, fontweight="600") + + # Suppress grid — imshow cells provide visual separation + ax.grid(False) + fig.tight_layout() + return fig + + +def save_figure(fig: Figure, path: str, dpi: int = 150) -> None: + """Save a Figure to *path* using canonical export settings. + + Single source of truth for dpi and bbox behaviour across all scripts and + agent_workspace callers. Never call ``fig.savefig(...)`` directly in + workspace scripts — use this instead. + + Args: + fig: A ``matplotlib.figure.Figure`` returned by any plotting function. + path: Destination file path (PNG recommended). + dpi: Resolution; default 150 for paper-quality output. 
+ """ + fig.savefig(path, dpi=dpi, bbox_inches="tight") diff --git a/src/compute_permit_sim/vis/state/history.py b/src/compute_permit_sim/vis/state/history.py index 30e5c85..fd91502 100644 --- a/src/compute_permit_sim/vis/state/history.py +++ b/src/compute_permit_sim/vis/state/history.py @@ -5,10 +5,14 @@ import solara from compute_permit_sim.schemas import SimulationRun -from compute_permit_sim.schemas.batch import MonteCarloResult, SweepResult - -# Union type for batch results — MC aggregate or sweep aggregate. -BatchResult = MonteCarloResult | SweepResult +from compute_permit_sim.schemas.batch import ( + GridSweepResult, + MonteCarloResult, + SweepResult, +) + +# Union type for batch results — MC aggregate, 1D sweep, or 2D grid sweep. +BatchResult = MonteCarloResult | SweepResult | GridSweepResult class SessionHistory: diff --git a/src/compute_permit_sim/vis/state/run_state.py b/src/compute_permit_sim/vis/state/run_state.py index aad1c45..126fe28 100644 --- a/src/compute_permit_sim/vis/state/run_state.py +++ b/src/compute_permit_sim/vis/state/run_state.py @@ -25,7 +25,11 @@ from pydantic import BaseModel, ConfigDict from compute_permit_sim.schemas import SimulationRun -from compute_permit_sim.schemas.batch import MonteCarloResult, SweepResult +from compute_permit_sim.schemas.batch import ( + GridSweepResult, + MonteCarloResult, + SweepResult, +) T = TypeVar("T") @@ -69,3 +73,8 @@ def is_ready(self) -> bool: sweep_run: solara.Reactive[RunState[SweepResult]] = solara.reactive( RunState[SweepResult]() ) + +#: 2D grid sweep batch run state +grid_run: solara.Reactive[RunState[GridSweepResult]] = solara.reactive( + RunState[GridSweepResult]() +) diff --git a/tests/factories.py b/tests/factories.py index 6ba5739..804ff05 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,6 +1,11 @@ """Test data factories for generating valid schema objects.""" -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if 
TYPE_CHECKING: + from compute_permit_sim.schemas.batch import GridSweepResult from compute_permit_sim.schemas import ( AgentSnapshot, @@ -60,3 +65,32 @@ def create_scenario_config( } data = {**defaults, **kwargs} return ScenarioConfig(name=name, **data) + + +def create_grid_sweep_result( + n_x: int = 3, + n_y: int = 2, + scenario_name: str = "Test Scenario", +) -> "GridSweepResult": + """Create a minimal GridSweepResult for testing — no simulation run needed.""" + from compute_permit_sim.schemas.batch import GridSweepResult + + x_values = [float(i) * 0.1 for i in range(1, n_x + 1)] + y_values = [float(j) * 10.0 for j in range(1, n_y + 1)] + # grid[y_idx][x_idx] = synthetic compliance value in [0, 1] + grid = [ + [float(y_idx * n_x + x_idx) / (n_x * n_y) for x_idx in range(n_x)] + for y_idx in range(n_y) + ] + return GridSweepResult( + scenario_name=scenario_name, + param_x_path="audit.base_prob", + param_x_label="Base Audit Probability", + param_y_path="collateral_amount", + param_y_label="Collateral K (M$)", + config=create_scenario_config(name=scenario_name), + x_values=x_values, + y_values=y_values, + grid=grid, + n_runs=5, + ) diff --git a/tests/services/test_sweep.py b/tests/services/test_sweep.py index 8597bfb..3f7dd26 100644 --- a/tests/services/test_sweep.py +++ b/tests/services/test_sweep.py @@ -86,3 +86,103 @@ def test_default_param_label(self) -> None: cfg = self._base() result = run_sweep(cfg, "audit.base_prob", [0.1], n_runs=2) assert result.param_label == "audit.base_prob" + + +class TestRunGridSweep: + def _base(self) -> ScenarioConfig: + return ScenarioConfig(n_agents=4, steps=3) + + def test_grid_shape(self) -> None: + from compute_permit_sim.services.sweep import run_grid_sweep + + x_values = [0.05, 0.10, 0.15] + y_values = [0.0, 10.0] + result = run_grid_sweep( + self._base(), + "audit.base_prob", + "collateral_amount", + x_values, + y_values, + n_runs=2, + ) + assert len(result.grid) == len(y_values) + assert all(len(row) == len(x_values) for 
row in result.grid) + + def test_grid_values_in_range(self) -> None: + from compute_permit_sim.services.sweep import run_grid_sweep + + result = run_grid_sweep( + self._base(), + "audit.base_prob", + "collateral_amount", + [0.05, 0.20], + [0.0, 5.0], + n_runs=2, + ) + for row in result.grid: + for v in row: + assert 0.0 <= v <= 1.0 + + def test_metadata(self) -> None: + from compute_permit_sim.services.sweep import run_grid_sweep + + cfg = ScenarioConfig(name="GridTest", n_agents=2, steps=2) + result = run_grid_sweep( + cfg, + "audit.base_prob", + "collateral_amount", + [0.1], + [0.0], + param_x_label="X Label", + param_y_label="Y Label", + n_runs=2, + ) + assert result.scenario_name == "GridTest" + assert result.param_x_path == "audit.base_prob" + assert result.param_y_path == "collateral_amount" + assert result.param_x_label == "X Label" + assert result.param_y_label == "Y Label" + assert result.n_runs == 2 + + def test_compliance_at(self) -> None: + from compute_permit_sim.services.sweep import run_grid_sweep + + x_vals = [0.05, 0.20] + y_vals = [0.0, 10.0] + result = run_grid_sweep( + self._base(), + "audit.base_prob", + "collateral_amount", + x_vals, + y_vals, + n_runs=2, + ) + for x in x_vals: + for y in y_vals: + val = result.compliance_at(x, y) + assert val is not None + assert 0.0 <= val <= 1.0 + # Non-existent cell returns None + assert result.compliance_at(0.99, 99.0) is None + + def test_reproducible(self) -> None: + from compute_permit_sim.services.sweep import run_grid_sweep + + seeds = [0, 1, 2] + r1 = run_grid_sweep( + self._base(), + "audit.base_prob", + "collateral_amount", + [0.05], + [0.0], + seeds=seeds, + ) + r2 = run_grid_sweep( + self._base(), + "audit.base_prob", + "collateral_amount", + [0.05], + [0.0], + seeds=seeds, + ) + assert abs(r1.grid[0][0] - r2.grid[0][0]) < 1e-10 diff --git a/tests/vis/test_export.py b/tests/vis/test_export.py index 2870237..5395d12 100644 --- a/tests/vis/test_export.py +++ b/tests/vis/test_export.py @@ -109,3 
+109,51 @@ def test_export_run_to_excel_creates_file(sample_run: SimulationRun) -> None: # Header row is parsed, we expect 2 agents assert len(df_agents) == 2 assert "Agent's base economic value (v_i)" in df_agents.columns + + +# --------------------------------------------------------------------------- +# Grid sweep export tests +# --------------------------------------------------------------------------- + + +def test_export_grid_sweep_to_csv_bytes() -> None: + """CSV export returns bytes with n_x * n_y rows and expected columns.""" + from compute_permit_sim.vis.export import export_grid_sweep_to_csv + from tests.factories import create_grid_sweep_result + + n_x, n_y = 3, 2 + result = create_grid_sweep_result(n_x=n_x, n_y=n_y) + csv_bytes = export_grid_sweep_to_csv(result, output_path="") + assert isinstance(csv_bytes, bytes) + + import io + + df = pd.read_csv(io.BytesIO(csv_bytes)) + assert len(df) == n_x * n_y + required_cols = { + "param_x_path", + "param_x_value", + "param_y_path", + "param_y_value", + "n_runs", + "compliance_rate", + } + assert required_cols.issubset(set(df.columns)) + + +def test_export_grid_sweep_to_excel_bytes() -> None: + """Excel export returns non-empty bytes with Config, Grid, and Heatmap sheets.""" + from compute_permit_sim.vis.export import export_grid_sweep_to_excel + from tests.factories import create_grid_sweep_result + + result = create_grid_sweep_result(n_x=2, n_y=2) + xlsx_bytes = export_grid_sweep_to_excel(result, output_path="") + assert isinstance(xlsx_bytes, bytes) + assert len(xlsx_bytes) > 0 + + import io + + with pd.ExcelFile(io.BytesIO(xlsx_bytes)) as xl: + assert "Config" in xl.sheet_names + assert "Grid" in xl.sheet_names + assert "Heatmap" in xl.sheet_names