Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 124 additions & 7 deletions rag/pipelines/_cost_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import json
import logging
import os
from datetime import date as date_type
from typing import Any

Expand All @@ -43,6 +44,72 @@
_COST_BUCKET = "alpha-engine-research"
_COST_PREFIX = "decision_artifacts/_cost_raw"

# Phase 4 #1 — runaway-cost circuit breaker. Shared env var with
# alpha-engine-research's ``llm_cost_tracker.RunBudgetExceededError``
# so a single operator knob ceilings cost across all SF entry points.
_RUN_BUDGET_ENV_VAR = "ALPHA_ENGINE_RUN_BUDGET_USD"
_RUN_BUDGET_DEFAULT_USD = 100.0


def _resolve_run_budget_ceiling() -> float:
"""Read ``ALPHA_ENGINE_RUN_BUDGET_USD`` per-call (allows test toggling).

Mirrors ``alpha-engine-research/graph/llm_cost_tracker._resolve_run_budget_ceiling``
so the news-pipeline + research + executor all share the same operator
knob. Returns 0.0 on parse failure rather than raising — a malformed
env var shouldn't take down RAGIngestion; the parse-failure log is
loud enough that operators notice.

Returns a positive float to enforce the ceiling; zero or negative
disables enforcement entirely. Default $100 reflects the
workstream's "runaway prompt loop should fire well before the monthly
Anthropic bill" intent.
"""
raw = os.environ.get(_RUN_BUDGET_ENV_VAR, "")
if not raw:
return _RUN_BUDGET_DEFAULT_USD
try:
return float(raw)
except (TypeError, ValueError):
logger.warning(
"[cost_telemetry] ALPHA_ENGINE_RUN_BUDGET_USD=%r is not a "
"number; disabling run-budget enforcement (set to a positive "
"float to enable, 0 to explicitly disable)",
raw,
)
return 0.0


class CostBudgetExceededError(RuntimeError):
"""Raised mid-run when cumulative spend exceeds the configured
ceiling.

Per ``[[feedback_no_silent_fails]]`` — a runaway prompt loop should
kill the news pipeline before it bills the org into the next decade.
Surfaces ``run_id`` + cumulative cost + ceiling so operators map the
failure back to the offending SF run. Counterpart to research's
``RunBudgetExceededError`` (same env var, same default, same shape).
"""

def __init__(
self, *, run_id: str, agent_id: str,
cumulative_cost_usd: float, ceiling_usd: float,
) -> None:
self.run_id = run_id
self.agent_id = agent_id
self.cumulative_cost_usd = cumulative_cost_usd
self.ceiling_usd = ceiling_usd
super().__init__(
f"[cost_telemetry] run budget exceeded: "
f"run_id={run_id!r} agent_id={agent_id!r} "
f"cumulative_cost=${cumulative_cost_usd:.4f} > "
f"ceiling=${ceiling_usd:.4f}. Set "
f"ALPHA_ENGINE_RUN_BUDGET_USD=<higher_value> to raise the "
f"cap, or =0 to disable. Investigate the offending agent "
f"before raising the cap — a runaway prompt loop will keep "
f"growing."
)


class CostBufferFlushError(RuntimeError):
"""Raised when the S3 PutObject for the buffered cost rows fails.
Expand Down Expand Up @@ -70,12 +137,27 @@ def __init__(
agent_id: str,
bucket: str = _COST_BUCKET,
s3_client: Any | None = None,
ceiling_usd: float | None = None,
) -> None:
self._run_id = run_id
self._agent_id = agent_id
self._bucket = bucket
self._s3 = s3_client
# None = resolve from env at construction; explicit value =
# tests / operator-managed override. Resolving once at
# construction means a mid-run env-var change doesn't take
# effect until next pipeline invocation (matches research's
# ContextVar-per-run shape).
self._ceiling_usd = (
ceiling_usd if ceiling_usd is not None
else _resolve_run_budget_ceiling()
)
self._rows: list[dict] = []
self._cumulative_cost_usd: float = 0.0

@property
def cumulative_cost_usd(self) -> float:
return self._cumulative_cost_usd

def record(self, msg: Any) -> float:
"""Price ``msg``, append to buffer, return the row's USD cost.
Expand All @@ -84,6 +166,15 @@ def record(self, msg: Any) -> float:
with the buffer's ``run_id`` + ``agent_id`` stamped onto the
record's extra_fields so the daily aggregator's by-agent_id
breakdown surfaces this site's spend.

**Runaway-cost circuit breaker (Phase 4 #1):** raises
:exc:`CostBudgetExceededError` AFTER the row is recorded if
cumulative cost for this run exceeds
``ALPHA_ENGINE_RUN_BUDGET_USD`` (default $100). The row is
recorded first so per-call detail is preserved in the flush — operators
can inspect what broke the budget without re-running. Set
``ALPHA_ENGINE_RUN_BUDGET_USD=0`` (or pass ``ceiling_usd=0``) to
disable enforcement.
"""
record = record_anthropic_call(
msg,
Expand All @@ -93,7 +184,26 @@ def record(self, msg: Any) -> float:
},
)
self._rows.append(record)
return float(record["cost_usd"])
cost = float(record["cost_usd"])
self._cumulative_cost_usd += cost

if self._ceiling_usd > 0 and self._cumulative_cost_usd > self._ceiling_usd:
logger.error(
"[cost_telemetry] run budget exceeded for "
"run_id=%s agent_id=%s: cumulative=$%.4f > "
"ceiling=$%.4f (rows recorded=%d). Raising "
"CostBudgetExceededError to fail the run loud.",
self._run_id, self._agent_id,
self._cumulative_cost_usd, self._ceiling_usd,
len(self._rows),
)
raise CostBudgetExceededError(
run_id=self._run_id,
agent_id=self._agent_id,
cumulative_cost_usd=self._cumulative_cost_usd,
ceiling_usd=self._ceiling_usd,
)
return cost

@property
def row_count(self) -> int:
Expand Down Expand Up @@ -162,13 +272,20 @@ def create(self, *args, **kwargs):
response = self._wrapped.create(*args, **kwargs)
try:
self._buffer.record(response)
except CostBudgetExceededError:
# Runaway-cost circuit breaker fired — propagate. This IS
# the whole point of the breaker; swallowing it would defeat
# the safety net per [[feedback_no_silent_fails]]. The
# pipeline's outer try/finally flushes the buffer so all
# rows up to the breach are preserved on S3.
raise
except Exception as exc:
# Cost-telemetry failure must NOT bring down the producer
# (event extraction is the primary deliverable). Log loud +
# keep going. The flush step at pipeline exit still raises
# on S3 error per the no-silent-fails rule for the artifact
# write itself — per-call recording failures show up at flush
# time as a partial row count.
# Other cost-telemetry failures must NOT bring down the
# producer (event extraction is the primary deliverable).
# Log loud + keep going. The flush step at pipeline exit
# still raises on S3 error per the no-silent-fails rule for
# the artifact write itself — per-call recording failures
# show up at flush time as a partial row count.
logger.warning(
"[cost_telemetry] per-call recording failed: %s "
"(token counts NOT captured for this call; pipeline "
Expand Down
53 changes: 33 additions & 20 deletions rag/pipelines/run_news_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,26 +129,39 @@ def main() -> int:
logger.info("[run_news_pipeline] step 2/4 — NLP pipeline")
from rag.pipelines._cost_telemetry import build_news_cost_buffer
cost_buffer = build_news_cost_buffer(run_date=agg_date)
nlp_output = _run_nlp(articles, cost_buffer=cost_buffer)
logger.info(
"[run_news_pipeline] step 2 — sentiment_scores=%d "
"event_flags=%d entity_mentions=%d (%d/%d articles processed); "
"cost rows buffered=%d",
len(nlp_output.sentiment_scores),
len(nlp_output.event_flags),
len(nlp_output.entity_mentions),
nlp_output.n_articles_processed,
nlp_output.n_articles_processed + nlp_output.n_articles_failed,
cost_buffer.row_count,
)

# Flush cost-telemetry rows to S3. Per [[feedback_no_silent_fails]]
# the flush is hard-fail — a silent miss on the previously-dominant
# untracked cost slice would defeat the Phase 0 visibility goal.
# Pipeline-side dry-run + skip-nlp skip the flush by construction
# (buffer is None / empty).
if cost_buffer is not None and not args.dry_run:
cost_buffer.flush()
# try/finally: if the runaway-cost circuit breaker fires mid-loop
# (or any other exception inside _run_nlp), the flush still runs
# so rows up to the breach are preserved on S3. The breaker then
# re-raises and aborts the pipeline at the natural callsite.
try:
nlp_output = _run_nlp(articles, cost_buffer=cost_buffer)
logger.info(
"[run_news_pipeline] step 2 — sentiment_scores=%d "
"event_flags=%d entity_mentions=%d (%d/%d articles processed); "
"cost rows buffered=%d (cumulative=$%.4f)",
len(nlp_output.sentiment_scores),
len(nlp_output.event_flags),
len(nlp_output.entity_mentions),
nlp_output.n_articles_processed,
nlp_output.n_articles_processed + nlp_output.n_articles_failed,
cost_buffer.row_count,
cost_buffer.cumulative_cost_usd,
)
finally:
if cost_buffer is not None and not args.dry_run:
try:
cost_buffer.flush()
except Exception as flush_exc:
# On flush failure during exception unwind, log loud
# but don't shadow the original exception. Per
# [[feedback_no_silent_fails]] the row loss is
# operator-visible via the WARN; the original
# CostBudgetExceededError stays the failure-of-record.
logger.error(
"[run_news_pipeline] cost buffer flush FAILED "
"during exception unwind — rows LOST: %s",
flush_exc,
)

# ── Step 3: structured aggregates parquet ────────────────────
if args.dry_run:
Expand Down
111 changes: 111 additions & 0 deletions tests/test_news_cost_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import pytest

from rag.pipelines._cost_telemetry import (
CostBudgetExceededError,
CostBufferFlushError,
S3CostBuffer,
_resolve_run_budget_ceiling,
build_news_cost_buffer,
wrap_client_for_cost_telemetry,
)
Expand Down Expand Up @@ -241,3 +243,112 @@ def test_canonical_naming(self):
buf = build_news_cost_buffer(run_date=date(2026, 5, 25))
assert buf._run_id == "2026-05-25"
assert buf._agent_id == "data:news_event_extraction"


# ── Runaway-cost circuit breaker (Phase 4 #1) ────────────────────────────


class TestRunBudgetCeilingResolution:
def test_default_when_env_var_unset(self, monkeypatch):
monkeypatch.delenv("ALPHA_ENGINE_RUN_BUDGET_USD", raising=False)
assert _resolve_run_budget_ceiling() == 100.0

def test_positive_value_from_env(self, monkeypatch):
monkeypatch.setenv("ALPHA_ENGINE_RUN_BUDGET_USD", "5.50")
assert _resolve_run_budget_ceiling() == 5.50

def test_zero_disables_enforcement(self, monkeypatch):
monkeypatch.setenv("ALPHA_ENGINE_RUN_BUDGET_USD", "0")
assert _resolve_run_budget_ceiling() == 0.0

def test_malformed_env_var_returns_zero_not_raises(self, monkeypatch, caplog):
monkeypatch.setenv("ALPHA_ENGINE_RUN_BUDGET_USD", "not-a-number")
result = _resolve_run_budget_ceiling()
assert result == 0.0
assert any(
"is not a number" in r.message for r in caplog.records
)


class TestCostBudgetBreaker:
def test_under_ceiling_no_raise(self):
buf = S3CostBuffer(
run_id="2026-05-25", agent_id="data:news_event_extraction",
ceiling_usd=1.0,
)
# 1000 input + 200 output @ haiku-4-5 = $0.002 — well under $1.
cost = buf.record(_FakeMessage(
model="claude-haiku-4-5",
usage=_FakeUsage(input_tokens=1000, output_tokens=200),
))
assert cost == pytest.approx(0.002, abs=1e-6)
assert buf.cumulative_cost_usd == pytest.approx(0.002, abs=1e-6)

def test_breach_raises_after_recording_row(self):
"""Row is recorded BEFORE the raise so per-call detail is
preserved when the breaker fires. The buffer's flush() can then
write what was captured up to + including the breach call."""
buf = S3CostBuffer(
run_id="2026-05-25", agent_id="data:news_event_extraction",
ceiling_usd=0.001, # 0.1 cent — first call WILL exceed
)
with pytest.raises(CostBudgetExceededError) as exc_info:
buf.record(_FakeMessage(
model="claude-haiku-4-5",
usage=_FakeUsage(input_tokens=1000, output_tokens=200),
))
# Row was recorded (preserved for flush).
assert buf.row_count == 1
# Error carries enough context to map back to the offending run.
assert exc_info.value.run_id == "2026-05-25"
assert exc_info.value.agent_id == "data:news_event_extraction"
assert exc_info.value.cumulative_cost_usd == pytest.approx(0.002, abs=1e-6)
assert exc_info.value.ceiling_usd == 0.001
# Message tells operator how to adjust.
assert "ALPHA_ENGINE_RUN_BUDGET_USD" in str(exc_info.value)

def test_zero_ceiling_disables_enforcement(self):
buf = S3CostBuffer(
run_id="2026-05-25", agent_id="data:news_event_extraction",
ceiling_usd=0,
)
# 1B tokens would be impossible, but enforcement off → no raise.
# Use a plausible large call to keep the test honest.
for _ in range(100):
buf.record(_FakeMessage(
model="claude-haiku-4-5",
usage=_FakeUsage(input_tokens=10_000, output_tokens=2_000),
))
# Cumulative = 100 * (10000 * 1 + 2000 * 5) / 1M = 100 * 0.02 = 2.0
assert buf.cumulative_cost_usd == pytest.approx(2.0, abs=1e-6)
assert buf.row_count == 100

def test_proxy_propagates_breaker_does_not_swallow(self):
"""The proxy swallows generic record errors so event extraction
survives a malformed-response hiccup, but the runaway-cost
breaker MUST propagate so the safety net works."""
buf = S3CostBuffer(
run_id="2026-05-25", agent_id="data:news_event_extraction",
ceiling_usd=0.001,
)
underlying_client = MagicMock()
underlying_client.messages.create.return_value = _FakeMessage(
model="claude-haiku-4-5",
usage=_FakeUsage(input_tokens=1000, output_tokens=200),
)
wrapped = wrap_client_for_cost_telemetry(underlying_client, buf)
with pytest.raises(CostBudgetExceededError):
wrapped.messages.create(model="x", messages=[])

def test_ceiling_defaults_from_env(self, monkeypatch):
monkeypatch.setenv("ALPHA_ENGINE_RUN_BUDGET_USD", "0.0005")
buf = S3CostBuffer(
run_id="2026-05-25", agent_id="data:news_event_extraction",
# ceiling_usd not passed → resolves from env at construction
)
assert buf._ceiling_usd == 0.0005
with pytest.raises(CostBudgetExceededError):
buf.record(_FakeMessage(
model="claude-haiku-4-5",
usage=_FakeUsage(input_tokens=1000, output_tokens=200),
))
Loading