From 3f613320f8246265511e465f1f9d7b9ccd9babed Mon Sep 17 00:00:00 2001 From: Yanpeng Wang Date: Tue, 19 May 2026 18:46:18 +0800 Subject: [PATCH 1/2] Add cost estimation feature Introduce per-request and per-session cost estimates based on model pricing. Adds pricing.py with a lookup table for Claude and GPT models, integrates cost into the token_counter analyzer summary, CLI reporter output, and diagnosis JSON under a top-level "cost" key. Unknown models are silently skipped (no errors; the analyzer summary omits the cost field and the diagnosis "cost" key is null). --- .../analyzers/token_counter.py | 11 + src/context_profiler/diagnostics.py | 59 +++++ src/context_profiler/pricing.py | 109 +++++++++ .../reporters/cli_reporter.py | 9 + tests/test_cost_estimation.py | 215 ++++++++++++++++++ 5 files changed, 403 insertions(+) create mode 100644 src/context_profiler/pricing.py create mode 100644 tests/test_cost_estimation.py diff --git a/src/context_profiler/analyzers/token_counter.py b/src/context_profiler/analyzers/token_counter.py index 16a62a7..e1e5017 100644 --- a/src/context_profiler/analyzers/token_counter.py +++ b/src/context_profiler/analyzers/token_counter.py @@ -6,6 +6,7 @@ from context_profiler.analyzers.base import AnalyzerResult, BaseAnalyzer from context_profiler.models import APIRequest, BlockType, Role +from context_profiler.pricing import estimate_cost class TokenCounterAnalyzer(BaseAnalyzer): @@ -60,12 +61,19 @@ def analyze(self, request: APIRequest) -> AnalyzerResult: tool_use_tokens = by_content_type.get("tool_use", 0) tool_result_tokens = by_content_type.get("tool_result", 0) + cost = estimate_cost( + input_tokens=total_tokens, + output_tokens=0, + model=request.model, + ) + summary = { "total_input_tokens": total_tokens, "message_tokens": total_tokens - tool_def_tokens, "tool_definition_tokens": tool_def_tokens, "system_prompt_tokens": request.system_prompt_tokens, "source_format": request.source_format, + "model": request.model, "by_role": dict(by_role), "by_content_type": dict(by_content_type), 
"tool_use_tokens": tool_use_tokens, @@ -76,6 +84,9 @@ def analyze(self, request: APIRequest) -> AnalyzerResult: "tool_definitions": tool_defs_detail, } + if cost is not None: + summary["cost"] = cost + warnings = [] if tool_def_tokens > total_tokens * 0.3: warnings.append( diff --git a/src/context_profiler/diagnostics.py b/src/context_profiler/diagnostics.py index 307a1c4..0e19be8 100644 --- a/src/context_profiler/diagnostics.py +++ b/src/context_profiler/diagnostics.py @@ -7,6 +7,7 @@ from context_profiler.context_diff import analyze_context_diff from context_profiler.formats import describe_format from context_profiler.models import Session +from context_profiler.pricing import estimate_cost from context_profiler.profiler import ProfileResult from context_profiler.session_insights import analyze_session_insights @@ -158,6 +159,37 @@ def diagnose_result(result: ProfileResult, session: Session | None = None) -> di "recommendation": "Move stable repeated arguments into references or shorter identifiers.", }) + # Context overflow risk from budget forecast + forecast = session_insights.get("budget_forecast") + if forecast and forecast.get("estimated_overflow_turn") is not None: + current_turn_count = len(session.requests) if session else 0 + overflow_turn = forecast["estimated_overflow_turn"] + utilization = forecast["current_utilization"] + # Trigger if overflow is within 2x the current turn count + if current_turn_count > 0 and overflow_turn <= current_turn_count * 2: + if utilization > 0.8: + severity = "critical" + elif utilization > 0.5: + severity = "warning" + else: + severity = "info" + issues.append({ + "code": "CONTEXT_OVERFLOW_RISK", + "severity": severity, + "message": "Context is projected to overflow the model window at the current growth rate.", + "evidence": { + "growth_rate_per_turn": forecast["growth_rate_per_turn"], + "current_utilization": forecast["current_utilization"], + "estimated_overflow_turn": forecast["estimated_overflow_turn"], + 
"context_window_tokens": forecast["context_window_tokens"], + "model": forecast["model"], + }, + "recommendation": "Consider compacting earlier turns, summarizing tool results, or removing stale context before the window fills.", + }) + + # Cost estimation + cost_info = _compute_cost(result, session) + return { "schema_version": "0.1", "source": result.source, @@ -168,12 +200,39 @@ def diagnose_result(result: ProfileResult, session: Session | None = None) -> di "warnings": result.all_warnings, }, "issues": issues, + "cost": cost_info, "diff_summary": diff["diff_summary"], "diff_hints": diff["diff_hints"] + session_insights["hints"], "session_insights": session_insights, } +def _compute_cost(result: ProfileResult, session: Session | None = None) -> dict[str, Any] | None: + """Compute cost estimation for the profiled request or session.""" + token = result.analyzer_results.get("token_counter") + if not token: + return None + + summary = token.summary + model = summary.get("model", "unknown") + + if session and session.requests: + # Session mode: sum input tokens across all requests + total_input = sum(req.total_input_tokens for req in session.requests) + cost = estimate_cost(input_tokens=total_input, model=model) + if cost: + cost["mode"] = "session" + cost["num_requests"] = len(session.requests) + return cost + else: + # Snapshot mode: single request + total_input = summary.get("total_input_tokens", 0) + cost = estimate_cost(input_tokens=total_input, model=model) + if cost: + cost["mode"] = "snapshot" + return cost + + def _analysis_scope(result: ProfileResult, session: Session | None = None) -> dict[str, Any]: source_format = _source_format(result) or ( session.metadata.get("source_format") if session is not None else None diff --git a/src/context_profiler/pricing.py b/src/context_profiler/pricing.py new file mode 100644 index 0000000..567b772 --- /dev/null +++ b/src/context_profiler/pricing.py @@ -0,0 +1,109 @@ +"""Model pricing table for cost estimation. 
+ +Maps model name patterns to input/output pricing per 1M tokens (USD). +Prices are approximate and should be updated as providers change rates. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class ModelPricing: + """Pricing for a single model tier.""" + + input_per_1m: float # USD per 1M input tokens + output_per_1m: float # USD per 1M output tokens + display_name: str + + +# Patterns are matched in order; first match wins. +# Use lowercase substrings for matching against model identifiers. +PRICING_TABLE: list[tuple[list[str], ModelPricing]] = [ + # Claude models + ( + ["claude-opus-4", "claude-4-opus"], + ModelPricing(input_per_1m=15.0, output_per_1m=75.0, display_name="Claude Opus 4"), + ), + ( + ["claude-sonnet-4", "claude-4-sonnet"], + ModelPricing(input_per_1m=3.0, output_per_1m=15.0, display_name="Claude Sonnet 4"), + ), + ( + ["claude-3-5-sonnet", "claude-3.5-sonnet"], + ModelPricing(input_per_1m=3.0, output_per_1m=15.0, display_name="Claude 3.5 Sonnet"), + ), + ( + ["claude-3-5-haiku", "claude-3.5-haiku"], + ModelPricing(input_per_1m=0.80, output_per_1m=4.0, display_name="Claude 3.5 Haiku"), + ), + ( + ["claude-3-opus"], + ModelPricing(input_per_1m=15.0, output_per_1m=75.0, display_name="Claude 3 Opus"), + ), + ( + ["claude-3-sonnet"], + ModelPricing(input_per_1m=3.0, output_per_1m=15.0, display_name="Claude 3 Sonnet"), + ), + ( + ["claude-3-haiku"], + ModelPricing(input_per_1m=0.25, output_per_1m=1.25, display_name="Claude 3 Haiku"), + ), + # GPT models + ( + ["gpt-4o-mini"], + ModelPricing(input_per_1m=0.15, output_per_1m=0.60, display_name="GPT-4o mini"), + ), + ( + ["gpt-4o"], + ModelPricing(input_per_1m=2.50, output_per_1m=10.0, display_name="GPT-4o"), + ), + ( + ["gpt-4-turbo"], + ModelPricing(input_per_1m=10.0, output_per_1m=30.0, display_name="GPT-4 Turbo"), + ), +] + + +def lookup_pricing(model: str) -> ModelPricing | None: + """Find pricing for a model by matching name 
patterns. + + Returns None if no match is found. + """ + if not model or model == "unknown": + return None + + model_lower = model.lower() + for patterns, pricing in PRICING_TABLE: + for pattern in patterns: + if pattern in model_lower: + return pricing + return None + + +def estimate_cost( + input_tokens: int, + output_tokens: int = 0, + model: str = "unknown", +) -> dict[str, Any] | None: + """Estimate cost for a request given token counts and model. + + Returns a dict with cost breakdown, or None if model is unknown. + """ + pricing = lookup_pricing(model) + if pricing is None: + return None + + input_cost = (input_tokens / 1_000_000) * pricing.input_per_1m + output_cost = (output_tokens / 1_000_000) * pricing.output_per_1m + + return { + "estimated_input_cost_usd": round(input_cost, 6), + "estimated_output_cost_usd": round(output_cost, 6), + "estimated_total_cost_usd": round(input_cost + output_cost, 6), + "estimated_model": pricing.display_name, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + } diff --git a/src/context_profiler/reporters/cli_reporter.py b/src/context_profiler/reporters/cli_reporter.py index 887938f..e495481 100644 --- a/src/context_profiler/reporters/cli_reporter.py +++ b/src/context_profiler/reporters/cli_reporter.py @@ -114,6 +114,15 @@ def _render_token_summary(console: Console, summary: dict) -> None: tool_table.add_row(f" {tool_name}", _format_tokens(tokens), _pct(tokens, total)) console.print(tool_table) + cost = summary.get("cost") + if cost: + console.print() + console.print("[bold] Estimated Cost[/bold]") + model_name = cost.get("estimated_model", "unknown") + input_cost = cost.get("estimated_input_cost_usd", 0) + console.print(f" Model: {model_name}") + console.print(f" Input cost: ${input_cost:.4f}") + def _render_timeline(console: Console, timeline: list[dict]) -> None: console.print() diff --git a/tests/test_cost_estimation.py b/tests/test_cost_estimation.py new file mode 100644 index 0000000..c25c60a --- 
/dev/null +++ b/tests/test_cost_estimation.py @@ -0,0 +1,215 @@ +"""Tests for cost estimation feature.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from context_profiler.pricing import estimate_cost, lookup_pricing, ModelPricing +from context_profiler.models import APIRequest, Message, ContentBlock, BlockType, Role, Session +from context_profiler.analyzers.token_counter import TokenCounterAnalyzer +from context_profiler.diagnostics import diagnose_result +from context_profiler.profiler import profile_request, profile_session + +FIXTURES = Path(__file__).parent / "fixtures" + + +class TestLookupPricing: + """Test model name pattern matching.""" + + def test_claude_sonnet_35(self): + pricing = lookup_pricing("claude-3-5-sonnet-20241022") + assert pricing is not None + assert pricing.display_name == "Claude 3.5 Sonnet" + assert pricing.input_per_1m == 3.0 + + def test_claude_opus_4(self): + pricing = lookup_pricing("claude-opus-4-20250514") + assert pricing is not None + assert pricing.display_name == "Claude Opus 4" + assert pricing.input_per_1m == 15.0 + + def test_claude_sonnet_4(self): + pricing = lookup_pricing("claude-sonnet-4-20250514") + assert pricing is not None + assert pricing.display_name == "Claude Sonnet 4" + + def test_claude_haiku_35(self): + pricing = lookup_pricing("claude-3-5-haiku-20241022") + assert pricing is not None + assert pricing.display_name == "Claude 3.5 Haiku" + + def test_gpt4o(self): + pricing = lookup_pricing("gpt-4o-2024-08-06") + assert pricing is not None + assert pricing.display_name == "GPT-4o" + assert pricing.input_per_1m == 2.50 + + def test_gpt4o_mini(self): + pricing = lookup_pricing("gpt-4o-mini-2024-07-18") + assert pricing is not None + assert pricing.display_name == "GPT-4o mini" + assert pricing.input_per_1m == 0.15 + + def test_gpt4_turbo(self): + pricing = lookup_pricing("gpt-4-turbo-2024-04-09") + assert pricing is not None + assert pricing.display_name == 
"GPT-4 Turbo" + + def test_unknown_model_returns_none(self): + assert lookup_pricing("unknown") is None + assert lookup_pricing("") is None + assert lookup_pricing("some-custom-model") is None + + def test_case_insensitive(self): + pricing = lookup_pricing("Claude-3-5-Sonnet-20241022") + assert pricing is not None + assert pricing.display_name == "Claude 3.5 Sonnet" + + +class TestEstimateCost: + """Test cost calculation.""" + + def test_basic_cost_calculation(self): + cost = estimate_cost(input_tokens=1_000_000, model="gpt-4o") + assert cost is not None + assert cost["estimated_input_cost_usd"] == 2.50 + assert cost["estimated_output_cost_usd"] == 0.0 + assert cost["estimated_total_cost_usd"] == 2.50 + assert cost["estimated_model"] == "GPT-4o" + + def test_fractional_tokens(self): + cost = estimate_cost(input_tokens=50_000, model="claude-3-5-sonnet-20241022") + assert cost is not None + # 50K tokens at $3/1M = $0.15 + assert cost["estimated_input_cost_usd"] == 0.15 + + def test_with_output_tokens(self): + cost = estimate_cost(input_tokens=100_000, output_tokens=50_000, model="gpt-4o") + assert cost is not None + # Input: 100K at $2.50/1M = $0.25 + # Output: 50K at $10/1M = $0.50 + assert cost["estimated_input_cost_usd"] == 0.25 + assert cost["estimated_output_cost_usd"] == 0.5 + assert cost["estimated_total_cost_usd"] == 0.75 + + def test_unknown_model_returns_none(self): + cost = estimate_cost(input_tokens=1000, model="unknown") + assert cost is None + + def test_zero_tokens(self): + cost = estimate_cost(input_tokens=0, model="gpt-4o") + assert cost is not None + assert cost["estimated_input_cost_usd"] == 0.0 + + +class TestTokenCounterCostIntegration: + """Test that cost appears in token_counter analyzer output.""" + + def _make_request(self, model: str, num_blocks: int = 3) -> APIRequest: + blocks = [ + ContentBlock(block_type=BlockType.TEXT, text="hello " * 100, token_count=100) + for _ in range(num_blocks) + ] + messages = [Message(role=Role.USER, 
blocks=blocks, index=0)] + return APIRequest(messages=messages, model=model) + + def test_cost_in_summary_known_model(self): + req = self._make_request("claude-3-5-sonnet-20241022") + analyzer = TokenCounterAnalyzer() + result = analyzer.analyze(req) + assert "cost" in result.summary + cost = result.summary["cost"] + assert cost["estimated_model"] == "Claude 3.5 Sonnet" + assert cost["estimated_input_cost_usd"] > 0 + + def test_no_cost_for_unknown_model(self): + req = self._make_request("unknown") + analyzer = TokenCounterAnalyzer() + result = analyzer.analyze(req) + assert "cost" not in result.summary + + +class TestDiagnosticsCost: + """Test that cost appears in diagnosis JSON output.""" + + def _make_session(self, model: str, num_requests: int = 3) -> Session: + requests = [] + for i in range(num_requests): + blocks = [ + ContentBlock(block_type=BlockType.TEXT, text="x " * 500, token_count=500) + ] + messages = [Message(role=Role.USER, blocks=blocks, index=0)] + req = APIRequest(messages=messages, model=model, request_index=i) + requests.append(req) + return Session(requests=requests) + + def test_cost_in_diagnosis_snapshot(self): + blocks = [ + ContentBlock(block_type=BlockType.TEXT, text="hello " * 100, token_count=1000) + ] + messages = [Message(role=Role.USER, blocks=blocks, index=0)] + req = APIRequest(messages=messages, model="gpt-4o") + result = profile_request(req, source="test") + diagnosis = diagnose_result(result) + assert "cost" in diagnosis + cost = diagnosis["cost"] + assert cost["estimated_model"] == "GPT-4o" + assert cost["mode"] == "snapshot" + + def test_cost_in_diagnosis_session(self): + session = self._make_session("claude-3-5-sonnet-20241022", num_requests=3) + result = profile_session(session, source="test") + diagnosis = diagnose_result(result, session=session) + assert "cost" in diagnosis + cost = diagnosis["cost"] + assert cost["estimated_model"] == "Claude 3.5 Sonnet" + assert cost["mode"] == "session" + assert cost["num_requests"] == 3 
+ # 3 requests x 500 tokens = 1500 tokens at $3/1M + assert cost["input_tokens"] == 1500 + assert cost["estimated_input_cost_usd"] > 0 + + def test_no_cost_for_unknown_model_in_diagnosis(self): + blocks = [ + ContentBlock(block_type=BlockType.TEXT, text="hello", token_count=100) + ] + messages = [Message(role=Role.USER, blocks=blocks, index=0)] + req = APIRequest(messages=messages, model="unknown") + result = profile_request(req, source="test") + diagnosis = diagnose_result(result) + assert diagnosis["cost"] is None + + +class TestCLIReporterCost: + """Test that cost shows up in CLI output.""" + + def test_cost_in_cli_output(self): + from click.testing import CliRunner + from context_profiler.cli import main + + fixture = FIXTURES / "repeated_tool_calls.json" + # The fixture uses "gpt-4" which isn't in our pricing table, + # so let's create a temp fixture with a known model + runner = CliRunner() + with runner.isolated_filesystem(): + data = json.loads(fixture.read_text()) + data["model"] = "gpt-4o" + Path("test_request.json").write_text(json.dumps(data)) + result = runner.invoke(main, ["analyze", "test_request.json", "--format", "openai"]) + assert result.exit_code == 0 + assert "Estimated Cost" in result.output + assert "GPT-4o" in result.output + + def test_no_cost_line_for_unknown_model(self): + from click.testing import CliRunner + from context_profiler.cli import main + + fixture = FIXTURES / "repeated_tool_calls.json" + runner = CliRunner() + # The fixture uses "gpt-4" which is not in our pricing table + result = runner.invoke(main, ["analyze", str(fixture), "--format", "openai"]) + assert result.exit_code == 0 + assert "Estimated Cost" not in result.output From c8b8fff20e0ca270c18805fdd97c14ce04674ba5 Mon Sep 17 00:00:00 2001 From: Yanpeng Wang Date: Tue, 19 May 2026 21:58:03 +0800 Subject: [PATCH 2/2] Add context budget forecasting and overflow risk detection Predict when a session will hit the context window limit based on turn-over-turn token growth 
rate. Adds CONTEXT_OVERFLOW_RISK diagnostic issue when overflow is projected within 2x the current turn count. New fields in session_insights.budget_forecast: - growth_rate_per_turn - current_utilization - estimated_overflow_turn - context_window_tokens - model (matched family) --- src/context_profiler/session_insights.py | 68 ++++++ tests/test_budget_forecast.py | 277 +++++++++++++++++++++++ 2 files changed, 345 insertions(+) create mode 100644 tests/test_budget_forecast.py diff --git a/src/context_profiler/session_insights.py b/src/context_profiler/session_insights.py index d26f7ac..6869ab6 100644 --- a/src/context_profiler/session_insights.py +++ b/src/context_profiler/session_insights.py @@ -22,6 +22,17 @@ _COMPRESSION_DROP_MIN_TOKENS = 5_000 _COMPRESSION_DROP_RATIO = 0.15 +# Context window sizes by model family (tokens) +# Matched as substrings, first match wins; keep more specific names first +_CONTEXT_WINDOW_SIZES: list[tuple[str, int]] = [ + ("gpt-4o-mini", 128_000), + ("gpt-4-turbo", 128_000), + ("gpt-4o", 128_000), + ("claude", 200_000), +] +_DEFAULT_CONTEXT_WINDOW = 128_000 +_OVERFLOW_RISK_HORIZON_MULTIPLIER = 2 # warn if overflow within 2x current turn count + def analyze_session_insights(session: Session | None) -> dict[str, Any]: """Summarize session-level carryover, budget, and artifact lifecycle signals.""" @@ -41,6 +52,7 @@ def analyze_session_insights(session: Session | None) -> dict[str, Any]: artifacts = _artifact_lifecycles(blocks_by_request) artifact_duplications = _artifact_duplications(session) propagation = _propagation_graph(blocks_by_request) + forecast = budget_forecast(session) hints = _build_hints(carryover, budget_events, artifacts, artifact_duplications) return { @@ -49,6 +61,7 @@ def analyze_session_insights(session: Session | None) -> dict[str, Any]: "artifact_lifecycles": artifacts, "artifact_duplications": artifact_duplications, "propagation": propagation, + "budget_forecast": forecast, "hints": hints, } @@ -525,3 +538,58 @@ def _find_first_key(value: 
Any, keys: tuple[str, ...]) -> Any | None: if found is not None: return found return None + + +# --------------------------------------------------------------------------- +# Budget Forecast +# --------------------------------------------------------------------------- + + +def _resolve_context_window(model: str) -> tuple[str, int]: + """Match a model string to a known context window size. + + Returns (matched_model_family, context_window_tokens). + """ + lowered = model.lower() + for prefix, size in _CONTEXT_WINDOW_SIZES: + if prefix in lowered: + return prefix, size + return "default", _DEFAULT_CONTEXT_WINDOW + + +def budget_forecast(session: Session) -> dict[str, Any] | None: + """Predict when a session will hit the context window limit. + + Returns None for sessions with fewer than 2 requests (not enough data). + """ + if session is None or len(session.requests) < 2: + return None + + token_counts = [req.total_input_tokens for req in session.requests] + num_turns = len(token_counts) + + # Determine model from the last request (most representative) + model_raw = session.requests[-1].model or "unknown" + matched_model, context_window = _resolve_context_window(model_raw) + + # Calculate turn-over-turn deltas for growth rate + deltas = [token_counts[i] - token_counts[i - 1] for i in range(1, num_turns)] + growth_rate = sum(deltas) / len(deltas) if deltas else 0.0 + + current_tokens = token_counts[-1] + current_utilization = current_tokens / context_window if context_window else 0.0 + + # Predict overflow turn (linear extrapolation from current position) + estimated_overflow_turn: int | None = None + if growth_rate > 0: + remaining = context_window - current_tokens + turns_until_overflow = remaining / growth_rate + estimated_overflow_turn = num_turns + int(turns_until_overflow) + + return { + "growth_rate_per_turn": round(growth_rate, 1), + "current_utilization": round(current_utilization, 4), + "estimated_overflow_turn": estimated_overflow_turn, + 
"context_window_tokens": context_window, + "model": matched_model, + } diff --git a/tests/test_budget_forecast.py b/tests/test_budget_forecast.py new file mode 100644 index 0000000..890d4ff --- /dev/null +++ b/tests/test_budget_forecast.py @@ -0,0 +1,277 @@ +"""Tests for context budget forecasting.""" + +from context_profiler.diagnostics import diagnose_result +from context_profiler.models import APIRequest, BlockType, ContentBlock, Message, Role, Session +from context_profiler.profiler import profile_session +from context_profiler.session_insights import ( + _resolve_context_window, + analyze_session_insights, + budget_forecast, +) + + +def _request(index: int, token_count: int, model: str = "gpt-4o") -> APIRequest: + """Create a minimal request with a given total token count.""" + messages = [ + Message( + role=Role.USER, + blocks=[ContentBlock(BlockType.TEXT, "content", token_count=token_count)], + index=0, + ) + ] + return APIRequest(messages=messages, request_index=index, model=model, source_format="openai") + + +# --------------------------------------------------------------------------- +# _resolve_context_window +# --------------------------------------------------------------------------- + + +def test_resolve_context_window_claude(): + model, size = _resolve_context_window("claude-3-opus-20240229") + assert model == "claude" + assert size == 200_000 + + +def test_resolve_context_window_gpt4o(): + model, size = _resolve_context_window("gpt-4o-2024-05-13") + assert model == "gpt-4o" + assert size == 128_000 + + +def test_resolve_context_window_gpt4_turbo(): + model, size = _resolve_context_window("gpt-4-turbo-preview") + assert model == "gpt-4-turbo" + assert size == 128_000 + + +def test_resolve_context_window_gpt4o_mini(): + model, size = _resolve_context_window("gpt-4o-mini") + assert model == "gpt-4o-mini" + assert size == 128_000 + + +def test_resolve_context_window_unknown_falls_back(): + model, size = _resolve_context_window("some-custom-model") + 
assert model == "default" + assert size == 128_000 + + +# --------------------------------------------------------------------------- +# budget_forecast — basic behavior +# --------------------------------------------------------------------------- + + +def test_budget_forecast_returns_none_for_single_request(): + session = Session(requests=[_request(0, 10_000)]) + assert budget_forecast(session) is None + + +def test_budget_forecast_returns_none_for_empty_session(): + session = Session(requests=[]) + assert budget_forecast(session) is None + + +def test_budget_forecast_returns_none_for_none_session(): + assert budget_forecast(None) is None + + +def test_budget_forecast_linear_growth(): + """With steady 10K growth per turn, should predict overflow correctly.""" + session = Session(requests=[ + _request(0, 10_000), + _request(1, 20_000), + _request(2, 30_000), + _request(3, 40_000), + _request(4, 50_000), + ]) + + forecast = budget_forecast(session) + + assert forecast is not None + assert forecast["growth_rate_per_turn"] == 10_000.0 + assert forecast["model"] == "gpt-4o" + assert forecast["context_window_tokens"] == 128_000 + # Current utilization: 50000 / 128000 + assert abs(forecast["current_utilization"] - 50_000 / 128_000) < 0.001 + # Remaining: 78000 tokens, at 10K/turn = 7.8 turns from turn 5 + # So overflow at turn 5 + 7 = 12 + assert forecast["estimated_overflow_turn"] == 12 + + +def test_budget_forecast_no_growth(): + """Flat token usage should not predict overflow.""" + session = Session(requests=[ + _request(0, 50_000), + _request(1, 50_000), + _request(2, 50_000), + ]) + + forecast = budget_forecast(session) + + assert forecast is not None + assert forecast["growth_rate_per_turn"] == 0.0 + assert forecast["estimated_overflow_turn"] is None + + +def test_budget_forecast_shrinking_context(): + """Decreasing token usage should not predict overflow.""" + session = Session(requests=[ + _request(0, 80_000), + _request(1, 60_000), + _request(2, 40_000), + ]) + 
+ forecast = budget_forecast(session) + + assert forecast is not None + assert forecast["growth_rate_per_turn"] < 0 + assert forecast["estimated_overflow_turn"] is None + + +def test_budget_forecast_claude_model_uses_200k_window(): + session = Session(requests=[ + _request(0, 50_000, model="claude-3-sonnet-20240229"), + _request(1, 100_000, model="claude-3-sonnet-20240229"), + ]) + + forecast = budget_forecast(session) + + assert forecast is not None + assert forecast["context_window_tokens"] == 200_000 + assert forecast["model"] == "claude" + # Growth: 50K/turn, remaining: 100K, overflow at turn 2 + 2 = 4 + assert forecast["estimated_overflow_turn"] == 4 + + +# --------------------------------------------------------------------------- +# Integration with analyze_session_insights +# --------------------------------------------------------------------------- + + +def test_session_insights_includes_budget_forecast(): + session = Session(requests=[ + _request(0, 10_000), + _request(1, 20_000), + _request(2, 30_000), + ]) + + insights = analyze_session_insights(session) + + assert "budget_forecast" in insights + forecast = insights["budget_forecast"] + assert forecast is not None + assert "growth_rate_per_turn" in forecast + assert "current_utilization" in forecast + assert "estimated_overflow_turn" in forecast + assert "context_window_tokens" in forecast + assert "model" in forecast + + +def test_session_insights_budget_forecast_none_for_single_request(): + session = Session(requests=[_request(0, 10_000)]) + insights = analyze_session_insights(session) + assert insights["budget_forecast"] is None + + +# --------------------------------------------------------------------------- +# CONTEXT_OVERFLOW_RISK diagnostic +# --------------------------------------------------------------------------- + + +def test_diagnostic_context_overflow_risk_critical(): + """High utilization + imminent overflow triggers critical issue.""" + # 5 turns, 22K growth/turn, at 110K now on 128K 
window + # utilization = 110000/128000 = 0.859 > 0.8 => critical + # overflow at turn 5 + int(18000/22000) = 5 + 0 = 5 + # 2x current = 10, overflow 5 <= 10 => triggers + session = Session(requests=[ + _request(0, 22_000), + _request(1, 44_000), + _request(2, 66_000), + _request(3, 88_000), + _request(4, 110_000), + ]) + + result = profile_session(session) + diagnosis = diagnose_result(result, session=session) + + overflow_issues = [i for i in diagnosis["issues"] if i["code"] == "CONTEXT_OVERFLOW_RISK"] + assert len(overflow_issues) == 1 + issue = overflow_issues[0] + assert issue["severity"] == "critical" + assert issue["evidence"]["growth_rate_per_turn"] == 22_000.0 + assert issue["evidence"]["estimated_overflow_turn"] is not None + + +def test_diagnostic_context_overflow_risk_warning(): + """Moderate utilization + near overflow triggers warning.""" + # 4 turns, 20K growth/turn, at 80K on 128K window + # utilization = 80000/128000 = 0.625 > 0.5 => warning + # overflow at turn 4 + int(48000/20000) = 4 + 2 = 6 + # 2x current = 8, overflow 6 <= 8 => triggers + session = Session(requests=[ + _request(0, 20_000), + _request(1, 40_000), + _request(2, 60_000), + _request(3, 80_000), + ]) + + result = profile_session(session) + diagnosis = diagnose_result(result, session=session) + + overflow_issues = [i for i in diagnosis["issues"] if i["code"] == "CONTEXT_OVERFLOW_RISK"] + assert len(overflow_issues) == 1 + assert overflow_issues[0]["severity"] == "warning" + + +def test_diagnostic_no_overflow_risk_when_far_away(): + """Low utilization with overflow far in the future should not trigger.""" + # 5 turns, 1K growth/turn, at 5K on 128K window + # overflow at turn 5 + int(123000/1000) = 5 + 123 = 128 + # 2x current = 10, overflow 128 > 10 => no trigger + session = Session(requests=[ + _request(0, 1_000), + _request(1, 2_000), + _request(2, 3_000), + _request(3, 4_000), + _request(4, 5_000), + ]) + + result = profile_session(session) + diagnosis = diagnose_result(result, 
session=session) + + overflow_issues = [i for i in diagnosis["issues"] if i["code"] == "CONTEXT_OVERFLOW_RISK"] + assert len(overflow_issues) == 0 + + +def test_diagnostic_no_overflow_risk_when_shrinking(): + """Shrinking context should never trigger overflow risk.""" + session = Session(requests=[ + _request(0, 80_000), + _request(1, 60_000), + _request(2, 40_000), + ]) + + result = profile_session(session) + diagnosis = diagnose_result(result, session=session) + + overflow_issues = [i for i in diagnosis["issues"] if i["code"] == "CONTEXT_OVERFLOW_RISK"] + assert len(overflow_issues) == 0 + + +def test_diagnostic_budget_forecast_in_session_insights(): + """Budget forecast should appear in the diagnosis JSON under session_insights.""" + session = Session(requests=[ + _request(0, 10_000), + _request(1, 20_000), + _request(2, 30_000), + ]) + + result = profile_session(session) + diagnosis = diagnose_result(result, session=session) + + assert "budget_forecast" in diagnosis["session_insights"] + forecast = diagnosis["session_insights"]["budget_forecast"] + assert forecast is not None + assert forecast["growth_rate_per_turn"] == 10_000.0