Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 63 additions & 8 deletions evals/rolling_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,59 @@ def _list_metric_combos(
return combos


_GET_METRIC_DATA_MAX_QUERIES = 500
"""AWS hard cap on ``MetricDataQueries`` per ``GetMetricData`` call.

Exceeding this raises ``ValidationError: The collection
MetricDataQueries must not have a size greater than 500.`` As the
rubric × dimension matrix grew (more sector teams / sub-agents /
criteria / judge models) ``len(queries)`` crossed 500, which made the
Saturday-SF ``EvalRollingMean`` state return ``{"status":"ERROR"}``
(observed on the 2026-05-17 weekend run). ``_get_metric_data_all``
chunks below this cap and paginates each chunk."""


def _get_metric_data_all(
cw: Any,
queries: list[dict[str, Any]],
start: datetime,
end: datetime,
) -> list[dict[str, Any]]:
"""Run ``GetMetricData`` over arbitrarily many queries.

AWS caps ``MetricDataQueries`` at ``_GET_METRIC_DATA_MAX_QUERIES``
(500) per call, and each call may itself paginate via ``NextToken``.
This chunks ``queries`` into ≤500-query batches, follows
``NextToken`` to exhaustion within every chunk, and returns the
flat-merged ``MetricDataResults`` across all chunks/pages.

The flat merge is correct because every query Id (``m{idx}``, idx
over the whole ``combos`` list) is globally unique, so the
downstream ``by_id = {r["Id"]: r for r in ...}`` mapping stays
intact regardless of which chunk/page a result came back on. The
≤500 case is unchanged behaviourally: a single chunk, paginated
only if AWS returns a ``NextToken``.
"""
merged: list[dict[str, Any]] = []
for chunk_start in range(0, len(queries), _GET_METRIC_DATA_MAX_QUERIES):
chunk = queries[chunk_start:chunk_start + _GET_METRIC_DATA_MAX_QUERIES]
next_token: Optional[str] = None
while True:
kwargs: dict[str, Any] = {
"MetricDataQueries": chunk,
"StartTime": start,
"EndTime": end,
}
if next_token is not None:
kwargs["NextToken"] = next_token
response = cw.get_metric_data(**kwargs)
merged.extend(response.get("MetricDataResults", []))
next_token = response.get("NextToken")
if not next_token:
break
return merged


def _build_metric_data_queries(
combos: list[list[dict[str, str]]],
*,
Expand All @@ -105,8 +158,8 @@ def _build_metric_data_queries(
) -> list[dict[str, Any]]:
"""One GetMetricData query per combo. The query Id is the combo
index — used to map results back to dimensions on the response side.
CloudWatch caps queries at 500 per call; with ~6 sector teams ×
3 sub-agents × 4-6 criteria × 2 judge models = ~290 we're well below.
CloudWatch caps queries at 500 per call; ``_get_metric_data_all``
chunks + paginates so the matrix can grow past 500 combos safely.
"""
return [
{
Expand Down Expand Up @@ -193,18 +246,20 @@ def compute_and_emit_4w_mean(
len(combos), start.isoformat(), end.isoformat(),
)

response = cw.get_metric_data(
MetricDataQueries=queries,
StartTime=start,
EndTime=end,
)
# AWS caps MetricDataQueries at 500 per GetMetricData call; chunk +
# paginate so the rubric × dimension matrix can grow past 500
# combos without raising ValidationError (the 2026-05-17 Sat-SF
# EvalRollingMean ERROR root cause).
metric_data_results = _get_metric_data_all(cw, queries, start, end)

derived_data: list[dict[str, Any]] = []
skipped_no_data = 0
failed: list[dict[str, str]] = []

# Map response Id back to combo dimensions for the derived emission.
by_id = {r["Id"]: r for r in response.get("MetricDataResults", [])}
# Ids (m{idx}) are unique across the whole combos list, so a flat
# merge across chunks/pages is correct.
by_id = {r["Id"]: r for r in metric_data_results}

for idx, dims in enumerate(combos):
result = by_id.get(f"m{idx}")
Expand Down
128 changes: 128 additions & 0 deletions tests/test_eval_rolling_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,134 @@ def test_one_query_per_combo_with_indexed_id(self):
assert q["ReturnData"] is True


# ── _get_metric_data_all (chunk + paginate, AWS 500 cap) ──────────────────


class TestGetMetricDataAll:
    """Regression coverage for ``_get_metric_data_all``.

    AWS rejects a ``GetMetricData`` request whose ``MetricDataQueries``
    collection has more than 500 entries (``ValidationError: The
    collection MetricDataQueries must not have a size greater than
    500.``). Once the rubric × dimension matrix outgrew 500 combos the
    single un-chunked call started failing, which put the Sat-SF
    EvalRollingMean state into ERROR on 2026-05-17. The helper under
    test splits requests into ≤500-query chunks and follows
    ``NextToken`` pagination."""

    @staticmethod
    def _window():
        """Return the fixed (start, end) UTC window used by every test."""
        return (
            datetime(2026, 5, 1, tzinfo=timezone.utc),
            datetime(2026, 5, 29, tzinfo=timezone.utc),
        )

    @staticmethod
    def _echoing_client(value):
        """Build a mock client answering each call with one result per
        submitted query, every result carrying ``[value]``."""

        def _answer(**kw):
            return {
                "MetricDataResults": [
                    {"Id": q["Id"], "Values": [value]}
                    for q in kw["MetricDataQueries"]
                ],
            }

        client = MagicMock()
        client.get_metric_data.side_effect = _answer
        return client

    def test_cap_constant_is_500(self):
        from evals.rolling_mean import _GET_METRIC_DATA_MAX_QUERIES

        assert _GET_METRIC_DATA_MAX_QUERIES == 500

    def test_over_500_queries_split_into_multiple_calls_each_le_500(self):
        from evals.rolling_mean import _get_metric_data_all

        # 1201 queries should be sent as three chunks: 500 + 500 + 201.
        expected_ids = [f"m{i}" for i in range(1201)]
        batch = [{"Id": qid} for qid in expected_ids]
        client = self._echoing_client(1.0)
        window_start, window_end = self._window()

        merged = _get_metric_data_all(client, batch, window_start, window_end)

        # Exactly three requests, none above the cap.
        assert client.get_metric_data.call_count == 3
        recorded = client.get_metric_data.call_args_list
        sizes = [len(call.kwargs["MetricDataQueries"]) for call in recorded]
        assert sizes == [500, 500, 201]
        assert all(size <= 500 for size in sizes)
        # Each query's result shows up once, in submission order.
        assert [item["Id"] for item in merged] == expected_ids
        # The time window is forwarded verbatim and no call paginates.
        for call in recorded:
            assert call.kwargs["StartTime"] == window_start
            assert call.kwargs["EndTime"] == window_end
            assert "NextToken" not in call.kwargs

    def test_next_token_pagination_followed_within_chunk(self):
        from evals.rolling_mean import _get_metric_data_all

        # Three queries fit in a single chunk; the client pages thrice.
        batch = [{"Id": f"m{i}"} for i in range(3)]
        first_page = {
            "MetricDataResults": [{"Id": "m0", "Values": [1.0]}],
            "NextToken": "pg2",
        }
        second_page = {
            "MetricDataResults": [{"Id": "m1", "Values": [2.0]}],
            "NextToken": "pg3",
        }
        # The last page carries no token, which ends the loop.
        last_page = {"MetricDataResults": [{"Id": "m2", "Values": [3.0]}]}
        client = MagicMock()
        client.get_metric_data.side_effect = [first_page, second_page, last_page]
        window_start, window_end = self._window()

        merged = _get_metric_data_all(client, batch, window_start, window_end)

        assert client.get_metric_data.call_count == 3
        # The opening request is token-free; each follow-up echoes the
        # token handed back by the page before it.
        recorded = client.get_metric_data.call_args_list
        assert "NextToken" not in recorded[0].kwargs
        assert recorded[1].kwargs["NextToken"] == "pg2"
        assert recorded[2].kwargs["NextToken"] == "pg3"
        assert [item["Id"] for item in merged] == ["m0", "m1", "m2"]

    def test_le_500_single_call_unchanged_path(self):
        from evals.rolling_mean import _get_metric_data_all

        batch = [{"Id": f"m{i}"} for i in range(7)]
        canned = {
            "MetricDataResults": [
                {"Id": q["Id"], "Values": [4.0]} for q in batch
            ],
        }
        client = MagicMock()
        client.get_metric_data.return_value = canned
        window_start, window_end = self._window()

        merged = _get_metric_data_all(client, batch, window_start, window_end)

        # One request, no pagination token, every result handed back.
        assert client.get_metric_data.call_count == 1
        assert "NextToken" not in client.get_metric_data.call_args.kwargs
        assert len(merged) == 7

    def test_over_500_end_to_end_results_merged_back_to_combos(self):
        """With more than 500 combos, compute_and_emit_4w_mean has to
        split its GetMetricData calls and still match every result to
        its combo through the unique m{idx} Id scheme."""
        from evals.rolling_mean import compute_and_emit_4w_mean

        combo_count = 1100
        all_dims = [
            _dims(f"agent{i}", f"c{i}") for i in range(combo_count)
        ]
        # GetMetricData echoes one result per query in each chunk.
        client = self._echoing_client(3.5)
        lister = MagicMock()
        lister.paginate.return_value = [
            {"Metrics": [{"Dimensions": d} for d in all_dims]},
        ]
        client.get_paginator.return_value = lister
        bucket = MagicMock()

        summary = compute_and_emit_4w_mean(
            cloudwatch_client=client, s3_client=bucket,
        )

        # 1100 combos → 500 + 500 + 100, i.e. three capped requests.
        assert client.get_metric_data.call_count == 3
        for call in client.get_metric_data.call_args_list:
            assert len(call.kwargs["MetricDataQueries"]) <= 500
        # Every combo is matched and emitted; nothing reported missing.
        assert summary["combos_discovered"] == combo_count
        assert summary["datapoints_emitted"] == combo_count
        assert summary["failed"] == []


# ── compute_and_emit_4w_mean ──────────────────────────────────────────────


Expand Down
Loading