diff --git a/src/skillspector/llm_analyzer_base.py b/src/skillspector/llm_analyzer_base.py index c5ab9dce..8da8beba 100644 --- a/src/skillspector/llm_analyzer_base.py +++ b/src/skillspector/llm_analyzer_base.py @@ -30,7 +30,7 @@ import asyncio from collections import defaultdict from dataclasses import dataclass, field -from typing import Literal +from typing import Literal, TypedDict from langchain_core.messages import BaseMessage from pydantic import BaseModel, Field, field_validator @@ -114,6 +114,38 @@ class LLMAnalysisResult(BaseModel): findings: list[LLMFinding] = Field(default_factory=list) +class LLMTokenUsage(TypedDict): + """Provider-normalized token usage for LLM calls.""" + + input_tokens: int + output_tokens: int + total_tokens: int + + +def _empty_token_usage() -> LLMTokenUsage: + return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + + +def _extract_token_usage(raw: object) -> LLMTokenUsage: + usage = getattr(raw, "usage_metadata", None) or {} + if not isinstance(usage, dict): + return _empty_token_usage() + input_tokens = int(usage.get("input_tokens") or usage.get("prompt_tokens") or 0) + output_tokens = int(usage.get("output_tokens") or usage.get("completion_tokens") or 0) + total_tokens = int(usage.get("total_tokens") or input_tokens + output_tokens) + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens, + } + + +def _add_token_usage(total: LLMTokenUsage, usage: LLMTokenUsage) -> None: + total["input_tokens"] += usage["input_tokens"] + total["output_tokens"] += usage["output_tokens"] + total["total_tokens"] += usage["total_tokens"] + + def estimate_tokens(text: str) -> int: """Approximate token count from character length.""" return len(text) // CHARS_PER_TOKEN @@ -275,8 +307,36 @@ def __init__(self, base_prompt: str, model: str): self._input_budget = get_max_input_tokens(model) self._llm = get_chat_model(model=model) self._structured_llm = ( - self._llm.with_structured_output(self.response_schema) if self.response_schema else None + self._llm.with_structured_output(self.response_schema, include_raw=True) + if self.response_schema + else None ) + self._llm_usage = _empty_token_usage() + + @property + def llm_usage(self) -> LLMTokenUsage: + """Cumulative token usage from the most recent batch run.""" + return dict(self._llm_usage) # type: ignore[return-value] + + def _reset_llm_usage(self) -> None: + self._llm_usage = _empty_token_usage() + + def _record_usage_from_raw(self, raw: object) -> None: + _add_token_usage(self._llm_usage, _extract_token_usage(raw)) + + def _unwrap_structured_response(self, response: object) -> object: + if not isinstance(response, dict) or not {"raw", "parsed", "parsing_error"} <= set( + response + ): + return response + raw = response.get("raw") + self._record_usage_from_raw(raw) + parsing_error = response.get("parsing_error") + if parsing_error is not None: + if isinstance(parsing_error, BaseException): + raise parsing_error + raise ValueError(str(parsing_error)) + return response.get("parsed") # -- Batching ----------------------------------------------------------- @@ -376,6 +436,7 @@ def run_batches( :meth:`parse_response` returns :class:`Finding` objects; subclasses may return dicts or other types. """ + self._reset_llm_usage() results: list[tuple[Batch, list]] = [] for batch in batches: prompt = self.build_prompt(batch, **kwargs) @@ -386,9 +447,11 @@ def run_batches( len(batch.findings), ) if self._structured_llm: - response = self._structured_llm.invoke(prompt) + response = self._unwrap_structured_response(self._structured_llm.invoke(prompt)) else: - response = _message_text(self._llm.invoke(prompt)) + raw_response = self._llm.invoke(prompt) + self._record_usage_from_raw(raw_response) + response = _message_text(raw_response) logger.debug("LLM response for %s", batch.file_label) parsed = self.parse_response(response, batch) results.append((batch, parsed)) @@ -417,6 +480,7 @@ async def arun_batches( The return type mirrors :meth:`run_batches`. """ + self._reset_llm_usage() sem = asyncio.Semaphore(max_concurrency) async def _process(batch: Batch) -> tuple[Batch, list]: @@ -429,9 +493,13 @@ async def _process(batch: Batch) -> tuple[Batch, list]: len(batch.findings), ) if self._structured_llm: - response = await self._structured_llm.ainvoke(prompt) + response = self._unwrap_structured_response( + await self._structured_llm.ainvoke(prompt) + ) else: - response = _message_text(await self._llm.ainvoke(prompt)) + raw_response = await self._llm.ainvoke(prompt) + self._record_usage_from_raw(raw_response) + response = _message_text(raw_response) logger.debug("LLM response for %s", batch.file_label) return (batch, self.parse_response(response, batch)) diff --git a/src/skillspector/llm_utils.py b/src/skillspector/llm_utils.py index 468e26b0..145bc2df 100644 --- a/src/skillspector/llm_utils.py +++ b/src/skillspector/llm_utils.py @@ -39,6 +39,7 @@ from typing import NoReturn from langchain_core.language_models.chat_models import BaseChatModel +from pydantic import BaseModel from skillspector.model_info import get_max_input_tokens, get_max_output_tokens from skillspector.providers import ( @@ -161,11 +162,19 @@ class _StructuredAgentCLIModel: ``complete()``, then parses and validates the response into *schema*. """ - def __init__(self, provider: object, model: str, max_output_tokens: int, schema: type) -> None: + def __init__( + self, + provider: object, + model: str, + max_output_tokens: int, + schema: type[BaseModel], + include_raw: bool = False, + ) -> None: self._provider = provider self._model = model self._max_output_tokens = max_output_tokens self._schema = schema + self._include_raw = include_raw def _augment(self, prompt: str) -> str: schema_json = json.dumps(self._schema.model_json_schema(), indent=2) @@ -182,7 +191,17 @@ def invoke(self, prompt: str) -> object: model=self._model, max_output_tokens=self._max_output_tokens, ) - return self._schema.model_validate(_extract_json_object(raw)) + try: + parsed = self._schema.model_validate(_extract_json_object(raw)) + parsing_error = None + except Exception as exc: + parsed = None + parsing_error = exc + if not self._include_raw: + raise + if self._include_raw: + return {"raw": _AgentCLIMessage(raw), "parsed": parsed, "parsing_error": parsing_error} + return parsed async def ainvoke(self, prompt: str) -> object: return await asyncio.to_thread(self.invoke, prompt) @@ -227,9 +246,11 @@ def invoke(self, prompt: str) -> _AgentCLIMessage: async def ainvoke(self, prompt: str) -> _AgentCLIMessage: return await asyncio.to_thread(self.invoke, prompt) - def with_structured_output(self, schema: type) -> _StructuredAgentCLIModel: + def with_structured_output( + self, schema: type[BaseModel], *, include_raw: bool = False + ) -> _StructuredAgentCLIModel: return _StructuredAgentCLIModel( - self._provider, self._model, self._max_output_tokens, schema + self._provider, self._model, self._max_output_tokens, schema, include_raw=include_raw ) diff --git a/src/skillspector/nodes/analyzers/semantic_developer_intent.py b/src/skillspector/nodes/analyzers/semantic_developer_intent.py index f51fe8f0..b2279f7a 100644 --- a/src/skillspector/nodes/analyzers/semantic_developer_intent.py +++ b/src/skillspector/nodes/analyzers/semantic_developer_intent.py @@ -32,6 +32,11 @@ ANALYZER_ID = "semantic_developer_intent" logger = get_logger(__name__) + +class _NoUsage: + llm_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + + ANALYZER_PROMPT = """\ You are a developer-intent auditor for AI agent skills. Your job is to detect mismatches between what a skill *claims* to do (its manifest and @@ -179,12 +184,22 @@ def node(state: SkillspectorState) -> AnalyzerNodeResponse: results = asyncio.run(analyzer.arun_batches(batches)) findings = analyzer.collect_findings(results) logger.info("%s: %d findings", ANALYZER_ID, len(findings)) - return {"findings": findings, "llm_call_log": [llm_call_record(ANALYZER_ID, ok=True)]} + return { + "findings": findings, + "llm_call_log": [llm_call_record(ANALYZER_ID, ok=True, **analyzer.llm_usage)], + } except ValueError: raise except Exception as exc: logger.warning("%s failed: %s", ANALYZER_ID, exc) return { "findings": [], - "llm_call_log": [llm_call_record(ANALYZER_ID, ok=False, error=str(exc))], + "llm_call_log": [ + llm_call_record( + ANALYZER_ID, + ok=False, + error=str(exc), + **locals().get("analyzer", _NoUsage()).llm_usage, + ) + ], } diff --git a/src/skillspector/nodes/analyzers/semantic_quality_policy.py b/src/skillspector/nodes/analyzers/semantic_quality_policy.py index 18b48486..16940ff5 100644 --- a/src/skillspector/nodes/analyzers/semantic_quality_policy.py +++ b/src/skillspector/nodes/analyzers/semantic_quality_policy.py @@ -32,6 +32,11 @@ ANALYZER_ID = "semantic_quality_policy" logger = get_logger(__name__) + +class _NoUsage: + llm_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + + ANALYZER_PROMPT = """\ You are a quality and safety auditor for AI agent skills. Your job is to review a single skill file and report findings that fall into the categories @@ -148,12 +153,22 @@ def node(state: SkillspectorState) -> AnalyzerNodeResponse: results = asyncio.run(analyzer.arun_batches(batches)) findings = analyzer.collect_findings(results) logger.info("%s: %d findings", ANALYZER_ID, len(findings)) - return {"findings": findings, "llm_call_log": [llm_call_record(ANALYZER_ID, ok=True)]} + return { + "findings": findings, + "llm_call_log": [llm_call_record(ANALYZER_ID, ok=True, **analyzer.llm_usage)], + } except ValueError: raise except Exception as exc: logger.warning("%s failed: %s", ANALYZER_ID, exc) return { "findings": [], - "llm_call_log": [llm_call_record(ANALYZER_ID, ok=False, error=str(exc))], + "llm_call_log": [ + llm_call_record( + ANALYZER_ID, + ok=False, + error=str(exc), + **locals().get("analyzer", _NoUsage()).llm_usage, + ) + ], } diff --git a/src/skillspector/nodes/analyzers/semantic_security_discovery.py b/src/skillspector/nodes/analyzers/semantic_security_discovery.py index 72a0dde1..8c7f18bd 100644 --- a/src/skillspector/nodes/analyzers/semantic_security_discovery.py +++ b/src/skillspector/nodes/analyzers/semantic_security_discovery.py @@ -27,6 +27,11 @@ ANALYZER_ID = "semantic_security_discovery" logger = get_logger(__name__) + +class _NoUsage: + llm_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + + ANALYZER_PROMPT = """\ You are a security analyzer for AI agent skill files. Your task is to identify \ **intent and attack-phrasing risks** — issues that evade regex/static detection because \ @@ -90,14 +95,22 @@ def node(state: SkillspectorState) -> AnalyzerNodeResponse: results = analyzer.run_batches(batches) findings = analyzer.collect_findings(results) logger.info("%s: %d findings", ANALYZER_ID, len(findings)) - return {"findings": findings, "llm_call_log": [llm_call_record(ANALYZER_ID, ok=True)]} + return { + "findings": findings, + "llm_call_log": [llm_call_record(ANALYZER_ID, ok=True, **analyzer.llm_usage)], + } except ValidationError as exc: # Malformed LLM response — degrade gracefully rather than crashing the graph logger.warning("%s: LLM returned malformed response: %s", ANALYZER_ID, exc) return { "findings": [], "llm_call_log": [ - llm_call_record(ANALYZER_ID, ok=False, error=f"malformed LLM response: {exc}") + llm_call_record( + ANALYZER_ID, + ok=False, + error=f"malformed LLM response: {exc}", + **locals().get("analyzer", _NoUsage()).llm_usage, + ) ], } except ValueError: @@ -106,5 +119,12 @@ def node(state: SkillspectorState) -> AnalyzerNodeResponse: logger.warning("%s failed: %s", ANALYZER_ID, exc) return { "findings": [], - "llm_call_log": [llm_call_record(ANALYZER_ID, ok=False, error=str(exc))], + "llm_call_log": [ + llm_call_record( + ANALYZER_ID, + ok=False, + error=str(exc), + **locals().get("analyzer", _NoUsage()).llm_usage, + ) + ], } diff --git a/src/skillspector/nodes/meta_analyzer.py b/src/skillspector/nodes/meta_analyzer.py index 58c5b634..57a773e8 100644 --- a/src/skillspector/nodes/meta_analyzer.py +++ b/src/skillspector/nodes/meta_analyzer.py @@ -44,6 +44,10 @@ logger = get_logger(__name__) +class _NoUsage: + llm_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + + # --------------------------------------------------------------------------- # Structured output schemas # --------------------------------------------------------------------------- @@ -565,7 +569,7 @@ def meta_analyzer(state: SkillspectorState) -> MetaAnalyzerResponse: ) return { "filtered_findings": filtered, - "llm_call_log": [llm_call_record("meta_analyzer", ok=True)], + "llm_call_log": [llm_call_record("meta_analyzer", ok=True, **analyzer.llm_usage)], } except ValueError: raise @@ -573,5 +577,12 @@ def meta_analyzer(state: SkillspectorState) -> MetaAnalyzerResponse: logger.warning("LLM call failed, passing all findings through (fail-closed): %s", e) return { "filtered_findings": _passthrough_with_defaults(findings), - "llm_call_log": [llm_call_record("meta_analyzer", ok=False, error=str(e))], + "llm_call_log": [ + llm_call_record( + "meta_analyzer", + ok=False, + error=str(e), + **locals().get("analyzer", _NoUsage()).llm_usage, + ) + ], } diff --git a/src/skillspector/nodes/report.py b/src/skillspector/nodes/report.py index 95160398..639b8c0a 100644 --- a/src/skillspector/nodes/report.py +++ b/src/skillspector/nodes/report.py @@ -23,6 +23,7 @@ import json import re +from collections.abc import Mapping, Sequence from dataclasses import replace from datetime import UTC, datetime from io import StringIO @@ -125,7 +126,7 @@ def _severity_to_sarif_level(severity: str) -> Literal["error", "warning", "note def _compute_risk_score( findings: list[Finding], has_executable_scripts: bool, - component_metadata: list[dict[str, object]] | None = None, + component_metadata: Sequence[Mapping[str, object]] | None = None, ) -> tuple[int, str, str]: """ Compute risk score (0-100), severity band, and recommendation. @@ -307,7 +308,7 @@ def _build_sarif( def _format_terminal( findings: list[Finding], - component_metadata: list[dict[str, object]], + component_metadata: Sequence[Mapping[str, object]], manifest: dict[str, object], skill_path: str | None, risk_score: int, @@ -315,7 +316,7 @@ def _format_terminal( risk_recommendation: str, has_executable_scripts: bool, use_llm: bool = True, - llm_call_log: list[dict[str, object]] | None = None, + llm_call_log: Sequence[Mapping[str, object]] | None = None, suppressed: list[SuppressedFinding] | None = None, show_suppressed: bool = False, ) -> str: @@ -422,7 +423,7 @@ def _format_terminal( def _llm_runtime_status( - use_llm: bool, llm_call_log: list[dict[str, object]] + use_llm: bool, llm_call_log: Sequence[Mapping[str, object]] ) -> tuple[int, int, bool]: """Return ``(attempted, succeeded, degraded)`` from the LLM call log. @@ -436,7 +437,9 @@ def _llm_runtime_status( return attempted, succeeded, degraded -def _llm_degradation_notice(use_llm: bool, llm_call_log: list[dict[str, object]]) -> str | None: +def _llm_degradation_notice( + use_llm: bool, llm_call_log: Sequence[Mapping[str, object]] +) -> str | None: """Return a human-readable degraded-scan warning, or None if not degraded.""" attempted, _succeeded, degraded = _llm_runtime_status(use_llm, llm_call_log) if not degraded: @@ -447,10 +450,27 @@ def _llm_degradation_notice(use_llm: bool, llm_call_log: list[dict[str, object]] ) +def _usage_int(value: object) -> int: + if isinstance(value, int): + return value + if isinstance(value, str) and value.isdecimal(): + return int(value) + return 0 + + +def _aggregate_llm_usage(llm_call_log: Sequence[Mapping[str, object]]) -> dict[str, int]: + """Aggregate provider-reported token usage across LLM call records.""" + return { + "input_tokens": sum(_usage_int(r.get("input_tokens")) for r in llm_call_log), + "output_tokens": sum(_usage_int(r.get("output_tokens")) for r in llm_call_log), + "total_tokens": sum(_usage_int(r.get("total_tokens")) for r in llm_call_log), + } + + def _build_metadata( has_executable_scripts: bool, use_llm: bool, - llm_call_log: list[dict[str, object]] | None = None, + llm_call_log: Sequence[Mapping[str, object]] | None = None, ) -> dict[str, object]: """Build the metadata section shared by all output formats.""" llm_call_log = llm_call_log or [] @@ -468,6 +488,7 @@ def _build_metadata( # available AND the stage was not fully degraded (every call failing). "llm_available": llm_available and not degraded, "meta_analysis_applied": meta_analysis_applied, + "llm_usage": _aggregate_llm_usage(llm_call_log), } if not meta_analysis_applied: meta["filtering_mode"] = "heuristic" @@ -537,7 +558,7 @@ def _build_analysis_completeness( def _format_json( findings: list[Finding], - component_metadata: list[dict[str, object]], + component_metadata: Sequence[Mapping[str, object]], manifest: dict[str, object], skill_path: str | None, risk_score: int, @@ -545,7 +566,7 @@ def _format_json( risk_recommendation: str, has_executable_scripts: bool, use_llm: bool = True, - llm_call_log: list[dict[str, object]] | None = None, + llm_call_log: Sequence[Mapping[str, object]] | None = None, analysis_completeness: dict[str, object] | None = None, suppressed: list[SuppressedFinding] | None = None, ) -> str: @@ -585,7 +606,7 @@ def _format_json( def _format_markdown( findings: list[Finding], - component_metadata: list[dict[str, object]], + component_metadata: Sequence[Mapping[str, object]], manifest: dict[str, object], skill_path: str | None, risk_score: int, @@ -593,7 +614,7 @@ def _format_markdown( risk_recommendation: str, has_executable_scripts: bool, use_llm: bool = True, - llm_call_log: list[dict[str, object]] | None = None, + llm_call_log: Sequence[Mapping[str, object]] | None = None, suppressed: list[SuppressedFinding] | None = None, show_suppressed: bool = False, ) -> str: diff --git a/src/skillspector/state.py b/src/skillspector/state.py index 68d41d91..4d403c5b 100644 --- a/src/skillspector/state.py +++ b/src/skillspector/state.py @@ -97,9 +97,20 @@ class LLMCallRecord(TypedDict): node: str ok: bool error: str | None - - -def llm_call_record(node_id: str, *, ok: bool, error: str | None = None) -> LLMCallRecord: + input_tokens: int + output_tokens: int + total_tokens: int + + +def llm_call_record( + node_id: str, + *, + ok: bool, + error: str | None = None, + input_tokens: int = 0, + output_tokens: int = 0, + total_tokens: int = 0, +) -> LLMCallRecord: """Build one telemetry record for ``SkillspectorState['llm_call_log']``. LLM-backed nodes append a record on each run so the report can tell whether @@ -107,7 +118,14 @@ def llm_call_record(node_id: str, *, ok: bool, error: str | None = None) -> LLMC failure where the node fell back to empty/static findings (so the failure is not mistaken for "the LLM ran and found nothing"). """ - return {"node": node_id, "ok": ok, "error": error} + return { + "node": node_id, + "ok": ok, + "error": error, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens, + } class AnalyzerNodeResponse(TypedDict): diff --git a/tests/nodes/analyzers/test_semantic_developer_intent.py b/tests/nodes/analyzers/test_semantic_developer_intent.py index 90180ad0..08579997 100644 --- a/tests/nodes/analyzers/test_semantic_developer_intent.py +++ b/tests/nodes/analyzers/test_semantic_developer_intent.py @@ -30,6 +30,7 @@ _format_manifest, node, ) +from skillspector.state import llm_call_record MOCK_PATCH_TARGET = "skillspector.llm_analyzer_base.get_chat_model" @@ -242,7 +243,7 @@ def test_success_records_ok_true(self) -> None: with patch.object(LLMAnalyzerBase, "arun_batches", new_callable=AsyncMock, return_value=[]): result = node({"file_cache": {"main.py": "import os"}}) - assert result["llm_call_log"] == [{"node": ANALYZER_ID, "ok": True, "error": None}] + assert result["llm_call_log"] == [llm_call_record(ANALYZER_ID, ok=True)] @patch(MOCK_PATCH_TARGET) def test_exception_records_ok_false(self, mock_get_model: MagicMock) -> None: diff --git a/tests/nodes/analyzers/test_semantic_security_discovery.py b/tests/nodes/analyzers/test_semantic_security_discovery.py index ec77aded..30bc1ab6 100644 --- a/tests/nodes/analyzers/test_semantic_security_discovery.py +++ b/tests/nodes/analyzers/test_semantic_security_discovery.py @@ -30,6 +30,7 @@ ANALYZER_PROMPT, node, ) +from skillspector.state import llm_call_record # --------------------------------------------------------------------------- # Shared helpers @@ -318,7 +319,7 @@ def test_success_records_ok_true(self, base_state) -> None: with patch.object(LLMAnalyzerBase, "run_batches", return_value=[]): result = node(base_state) - assert result["llm_call_log"] == [{"node": ANALYZER_ID, "ok": True, "error": None}] + assert result["llm_call_log"] == [llm_call_record(ANALYZER_ID, ok=True)] @patch(MOCK_PATCH_TARGET) def test_generic_exception_records_ok_false(self, mock_get_model: MagicMock) -> None: diff --git a/tests/nodes/test_llm_analyzer_base.py b/tests/nodes/test_llm_analyzer_base.py index e344e654..6e7ed780 100644 --- a/tests/nodes/test_llm_analyzer_base.py +++ b/tests/nodes/test_llm_analyzer_base.py @@ -1567,6 +1567,66 @@ def test_llm_unconfirmed_tag_surfaced_in_to_dict(self) -> None: # --------------------------------------------------------------------------- + + +class _RawWithUsage: + def __init__(self, usage_metadata: dict[str, int] | None = None) -> None: + self.usage_metadata = usage_metadata + + +@patch("skillspector.llm_analyzer_base.get_chat_model") +def test_run_batches_records_token_usage_with_include_raw(mock_get_model: MagicMock) -> None: + mock_llm = MagicMock() + mock_structured = MagicMock() + mock_llm.with_structured_output.return_value = mock_structured + mock_get_model.return_value = mock_llm + response = LLMAnalysisResult(findings=[]) + mock_structured.invoke.return_value = { + "raw": _RawWithUsage({"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), + "parsed": response, + "parsing_error": None, + } + analyzer = LLMAnalyzerBase(base_prompt="Analyze", model="test-model") + results = analyzer.run_batches([Batch(file_path="a.py", content="code")]) + assert results == [(results[0][0], [])] + mock_llm.with_structured_output.assert_called_once_with(LLMAnalysisResult, include_raw=True) + assert analyzer.llm_usage == {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + + +@patch("skillspector.llm_analyzer_base.get_chat_model") +async def test_arun_batches_records_token_usage_with_include_raw(mock_get_model: MagicMock) -> None: + mock_llm = MagicMock() + mock_structured = MagicMock() + mock_structured.ainvoke = AsyncMock( + return_value={ + "raw": _RawWithUsage({"prompt_tokens": 7, "completion_tokens": 3}), + "parsed": LLMAnalysisResult(findings=[]), + "parsing_error": None, + } + ) + mock_llm.with_structured_output.return_value = mock_structured + mock_get_model.return_value = mock_llm + analyzer = LLMAnalyzerBase(base_prompt="Analyze", model="test-model") + await analyzer.arun_batches([Batch(file_path="a.py", content="code")]) + assert analyzer.llm_usage == {"input_tokens": 7, "output_tokens": 3, "total_tokens": 10} + + +@patch("skillspector.llm_analyzer_base.get_chat_model") +def test_run_batches_missing_usage_metadata_defaults_to_zero(mock_get_model: MagicMock) -> None: + mock_llm = MagicMock() + mock_structured = MagicMock() + mock_structured.invoke.return_value = { + "raw": object(), + "parsed": LLMAnalysisResult(findings=[]), + "parsing_error": None, + } + mock_llm.with_structured_output.return_value = mock_structured + mock_get_model.return_value = mock_llm + analyzer = LLMAnalyzerBase(base_prompt="Analyze", model="test-model") + analyzer.run_batches([Batch(file_path="a.py", content="code")]) + assert analyzer.llm_usage == {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + + # LLMMetaAnalyzer.run_batches (mocked LLM) # --------------------------------------------------------------------------- diff --git a/tests/nodes/test_meta_analyzer.py b/tests/nodes/test_meta_analyzer.py index 7eea0448..7da1a2ff 100644 --- a/tests/nodes/test_meta_analyzer.py +++ b/tests/nodes/test_meta_analyzer.py @@ -27,7 +27,7 @@ from skillspector.llm_analyzer_base import Batch from skillspector.models import Finding from skillspector.nodes.meta_analyzer import LLMMetaAnalyzer, meta_analyzer -from skillspector.state import SkillspectorState +from skillspector.state import SkillspectorState, llm_call_record MOCK_PATCH_TARGET = "skillspector.llm_analyzer_base.get_chat_model" @@ -270,7 +270,7 @@ def test_records_ok_true_on_success() -> None: ), ): result = meta_analyzer(_degr_state()) - assert result["llm_call_log"] == [{"node": "meta_analyzer", "ok": True, "error": None}] + assert result["llm_call_log"] == [llm_call_record("meta_analyzer", ok=True)] def test_construction_failure_is_caught_not_raised() -> None: diff --git a/tests/nodes/test_report.py b/tests/nodes/test_report.py index 91195003..4b7e9a12 100644 --- a/tests/nodes/test_report.py +++ b/tests/nodes/test_report.py @@ -26,6 +26,7 @@ _DIMINISHING_WEIGHTS, _MAX_OCCURRENCES_PER_RULE, _SEVERITY_POINTS, + _build_metadata, _compute_risk_score, report, ) @@ -914,3 +915,21 @@ def test_report_doc_findings_no_multiplier() -> None: # Without the multiplier: 2 HIGH = 50, not 65 assert result["risk_score"] == 50 assert result["risk_severity"] == "MEDIUM" + + +def test_build_metadata_aggregates_llm_usage() -> None: + metadata = _build_metadata( + has_executable_scripts=False, + use_llm=True, + llm_call_log=[ + llm_call_record("a", ok=True, input_tokens=10, output_tokens=5, total_tokens=15), + llm_call_record( + "b", ok=False, error="boom", input_tokens=2, output_tokens=1, total_tokens=3 + ), + ], + ) + assert metadata["llm_usage"] == { + "input_tokens": 12, + "output_tokens": 6, + "total_tokens": 18, + } diff --git a/tests/nodes/test_semantic_quality_policy.py b/tests/nodes/test_semantic_quality_policy.py index d0e69cc4..e1d6c432 100644 --- a/tests/nodes/test_semantic_quality_policy.py +++ b/tests/nodes/test_semantic_quality_policy.py @@ -29,6 +29,7 @@ ANALYZER_PROMPT, node, ) +from skillspector.state import llm_call_record # --------------------------------------------------------------------------- # Shared mocks @@ -271,7 +272,7 @@ def test_success_records_ok_true(self) -> None: with patch.object(LLMAnalyzerBase, "arun_batches", new_callable=AsyncMock, return_value=[]): result = node({"file_cache": {"SKILL.md": "# Skill"}}) - assert result["llm_call_log"] == [{"node": ANALYZER_ID, "ok": True, "error": None}] + assert result["llm_call_log"] == [llm_call_record(ANALYZER_ID, ok=True)] @patch(MOCK_PATCH_TARGET) def test_exception_records_ok_false(self, mock_get_model: MagicMock) -> None: diff --git a/tests/test_mcp_tool_poisoning.py b/tests/test_mcp_tool_poisoning.py index 0754666c..d1ba1bb7 100644 --- a/tests/test_mcp_tool_poisoning.py +++ b/tests/test_mcp_tool_poisoning.py @@ -25,6 +25,7 @@ import yaml from skillspector.nodes.analyzers import mcp_tool_poisoning +from skillspector.state import llm_call_record # --------------------------------------------------------------------------- # Fixture directory path @@ -681,7 +682,7 @@ def test_successful_call_records_ok_true(self): return_value='{"is_mismatch": false}', ): result = node(state) - assert result["llm_call_log"] == [{"node": "mcp_tool_poisoning", "ok": True, "error": None}] + assert result["llm_call_log"] == [llm_call_record("mcp_tool_poisoning", ok=True)] def test_failed_call_records_ok_false(self): from unittest.mock import patch diff --git a/tests/test_state.py b/tests/test_state.py new file mode 100644 index 00000000..75adab5e --- /dev/null +++ b/tests/test_state.py @@ -0,0 +1,12 @@ +from skillspector.state import llm_call_record + + +def test_llm_call_record_includes_token_fields() -> None: + assert llm_call_record("node", ok=True) == { + "node": "node", + "ok": True, + "error": None, + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + }