Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 74 additions & 6 deletions src/skillspector/llm_analyzer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Literal
from typing import Literal, TypedDict

from langchain_core.messages import BaseMessage
from pydantic import BaseModel, Field, field_validator
Expand Down Expand Up @@ -114,6 +114,38 @@ class LLMAnalysisResult(BaseModel):
findings: list[LLMFinding] = Field(default_factory=list)


class LLMTokenUsage(TypedDict):
    """Provider-normalized token usage for LLM calls."""

    # Tokens sent to the model (some providers call this ``prompt_tokens``).
    input_tokens: int
    # Tokens produced by the model (some providers call this ``completion_tokens``).
    output_tokens: int
    # Combined count; derived as input + output when the provider omits it.
    total_tokens: int


def _empty_token_usage() -> LLMTokenUsage:
    """Return a fresh usage record with every counter set to zero."""
    return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}


def _coerce_token_count(value: object) -> int:
    """Convert a provider-supplied count to a non-negative int, or 0 if unusable.

    Usage accounting is best-effort: a malformed value (non-numeric string,
    container, NaN/inf, negative number) is treated as "no data" rather than
    raising, so a bookkeeping problem cannot surface as an analysis failure.
    """
    if value is None:
        return 0
    try:
        count = int(value)  # type: ignore[call-overload]
    except (TypeError, ValueError, OverflowError):
        return 0
    # A negative count is never meaningful and would shrink cumulative totals.
    return max(count, 0)


def _first_token_count(usage: dict, *keys: str) -> int:
    """Return the first positive count found under *keys*, or 0 if there is none."""
    for key in keys:
        count = _coerce_token_count(usage.get(key))
        if count:
            return count
    return 0


def _extract_token_usage(raw: object) -> LLMTokenUsage:
    """Read token usage from the ``usage_metadata`` attribute of *raw*.

    Args:
        raw: Any object; typically a chat-model message. Objects without a
            dict-valued ``usage_metadata`` attribute yield an all-zero record.

    Returns:
        A normalized usage record. ``prompt_tokens`` / ``completion_tokens``
        are accepted as aliases for the input / output counts, and the total
        falls back to ``input + output`` when no usable total is reported.
        Unusable counts are reported as 0 instead of raising.
    """
    usage = getattr(raw, "usage_metadata", None)
    if not isinstance(usage, dict):
        return _empty_token_usage()

    input_tokens = _first_token_count(usage, "input_tokens", "prompt_tokens")
    output_tokens = _first_token_count(usage, "output_tokens", "completion_tokens")
    total_tokens = _first_token_count(usage, "total_tokens") or input_tokens + output_tokens
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
    }


def _add_token_usage(total: LLMTokenUsage, usage: LLMTokenUsage) -> None:
    """Add each counter of *usage* onto *total*, mutating *total* in place."""
    total["input_tokens"] += usage["input_tokens"]
    total["output_tokens"] += usage["output_tokens"]
    total["total_tokens"] += usage["total_tokens"]


def estimate_tokens(text: str) -> int:
    """Return a rough token estimate: the length of *text* floor-divided by ``CHARS_PER_TOKEN``."""
    char_count = len(text)
    return char_count // CHARS_PER_TOKEN
Expand Down Expand Up @@ -275,8 +307,36 @@ def __init__(self, base_prompt: str, model: str):
self._input_budget = get_max_input_tokens(model)
self._llm = get_chat_model(model=model)
self._structured_llm = (
self._llm.with_structured_output(self.response_schema) if self.response_schema else None
self._llm.with_structured_output(self.response_schema, include_raw=True)
if self.response_schema
else None
)
self._llm_usage = _empty_token_usage()

@property
def llm_usage(self) -> LLMTokenUsage:
    """Cumulative token usage from the most recent batch run.

    A copy is returned, so callers cannot mutate the internal counters.
    """
    snapshot = self._llm_usage.copy()
    return snapshot

def _reset_llm_usage(self) -> None:
    """Discard accumulated usage by installing a fresh all-zero record."""
    fresh = _empty_token_usage()
    self._llm_usage = fresh

def _record_usage_from_raw(self, raw: object) -> None:
    """Extract token usage from *raw* and add it to the running totals."""
    call_usage = _extract_token_usage(raw)
    _add_token_usage(self._llm_usage, call_usage)

def _unwrap_structured_response(self, response: object) -> object:
    """Record usage from an ``include_raw`` envelope and return its parsed payload.

    A response that is not a dict carrying all of ``raw``, ``parsed`` and
    ``parsing_error`` is returned untouched. Otherwise usage is recorded from
    ``raw`` before the error check, so tokens are counted even when parsing
    failed.

    Raises:
        BaseException: The envelope's ``parsing_error`` itself, when it is an
            exception instance.
        ValueError: Wrapping ``str(parsing_error)`` for any other non-``None``
            error value.
    """
    if not isinstance(response, dict):
        return response
    envelope_keys = ("raw", "parsed", "parsing_error")
    if any(key not in response for key in envelope_keys):
        return response

    self._record_usage_from_raw(response.get("raw"))

    error = response.get("parsing_error")
    if error is None:
        return response.get("parsed")
    if isinstance(error, BaseException):
        raise error
    raise ValueError(str(error))

# -- Batching -----------------------------------------------------------

Expand Down Expand Up @@ -376,6 +436,7 @@ def run_batches(
:meth:`parse_response` returns :class:`Finding` objects; subclasses may
return dicts or other types.
"""
self._reset_llm_usage()
results: list[tuple[Batch, list]] = []
for batch in batches:
prompt = self.build_prompt(batch, **kwargs)
Expand All @@ -386,9 +447,11 @@ def run_batches(
len(batch.findings),
)
if self._structured_llm:
response = self._structured_llm.invoke(prompt)
response = self._unwrap_structured_response(self._structured_llm.invoke(prompt))
else:
response = _message_text(self._llm.invoke(prompt))
raw_response = self._llm.invoke(prompt)
self._record_usage_from_raw(raw_response)
response = _message_text(raw_response)
logger.debug("LLM response for %s", batch.file_label)
parsed = self.parse_response(response, batch)
results.append((batch, parsed))
Expand Down Expand Up @@ -417,6 +480,7 @@ async def arun_batches(

The return type mirrors :meth:`run_batches`.
"""
self._reset_llm_usage()
sem = asyncio.Semaphore(max_concurrency)

async def _process(batch: Batch) -> tuple[Batch, list]:
Expand All @@ -429,9 +493,13 @@ async def _process(batch: Batch) -> tuple[Batch, list]:
len(batch.findings),
)
if self._structured_llm:
response = await self._structured_llm.ainvoke(prompt)
response = self._unwrap_structured_response(
await self._structured_llm.ainvoke(prompt)
)
else:
response = _message_text(await self._llm.ainvoke(prompt))
raw_response = await self._llm.ainvoke(prompt)
self._record_usage_from_raw(raw_response)
response = _message_text(raw_response)
logger.debug("LLM response for %s", batch.file_label)
return (batch, self.parse_response(response, batch))

Expand Down
29 changes: 25 additions & 4 deletions src/skillspector/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from typing import NoReturn

from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel

from skillspector.model_info import get_max_input_tokens, get_max_output_tokens
from skillspector.providers import (
Expand Down Expand Up @@ -161,11 +162,19 @@ class _StructuredAgentCLIModel:
``complete()``, then parses and validates the response into *schema*.
"""

def __init__(self, provider: object, model: str, max_output_tokens: int, schema: type) -> None:
def __init__(
self,
provider: object,
model: str,
max_output_tokens: int,
schema: type[BaseModel],
include_raw: bool = False,
) -> None:
self._provider = provider
self._model = model
self._max_output_tokens = max_output_tokens
self._schema = schema
self._include_raw = include_raw

def _augment(self, prompt: str) -> str:
schema_json = json.dumps(self._schema.model_json_schema(), indent=2)
Expand All @@ -182,7 +191,17 @@ def invoke(self, prompt: str) -> object:
model=self._model,
max_output_tokens=self._max_output_tokens,
)
return self._schema.model_validate(_extract_json_object(raw))
try:
parsed = self._schema.model_validate(_extract_json_object(raw))
parsing_error = None
except Exception as exc:
parsed = None
parsing_error = exc
if not self._include_raw:
raise
if self._include_raw:
return {"raw": _AgentCLIMessage(raw), "parsed": parsed, "parsing_error": parsing_error}
return parsed

async def ainvoke(self, prompt: str) -> object:
    """Async counterpart of ``invoke``: runs the blocking call via ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.invoke, prompt)
    return outcome
Expand Down Expand Up @@ -227,9 +246,11 @@ def invoke(self, prompt: str) -> _AgentCLIMessage:
async def ainvoke(self, prompt: str) -> _AgentCLIMessage:
    """Async counterpart of ``invoke``: runs the blocking call via ``asyncio.to_thread``."""
    message = await asyncio.to_thread(self.invoke, prompt)
    return message

def with_structured_output(self, schema: type) -> _StructuredAgentCLIModel:
def with_structured_output(
self, schema: type[BaseModel], *, include_raw: bool = False
) -> _StructuredAgentCLIModel:
return _StructuredAgentCLIModel(
self._provider, self._model, self._max_output_tokens, schema
self._provider, self._model, self._max_output_tokens, schema, include_raw=include_raw
)


Expand Down
19 changes: 17 additions & 2 deletions src/skillspector/nodes/analyzers/semantic_developer_intent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
ANALYZER_ID = "semantic_developer_intent"
logger = get_logger(__name__)


class _NoUsage:
    """Zero-usage stand-in for when no ``analyzer`` local exists to report usage from."""

    llm_usage = dict.fromkeys(("input_tokens", "output_tokens", "total_tokens"), 0)


ANALYZER_PROMPT = """\
You are a developer-intent auditor for AI agent skills. Your job is to
detect mismatches between what a skill *claims* to do (its manifest and
Expand Down Expand Up @@ -179,12 +184,22 @@ def node(state: SkillspectorState) -> AnalyzerNodeResponse:
results = asyncio.run(analyzer.arun_batches(batches))
findings = analyzer.collect_findings(results)
logger.info("%s: %d findings", ANALYZER_ID, len(findings))
return {"findings": findings, "llm_call_log": [llm_call_record(ANALYZER_ID, ok=True)]}
return {
"findings": findings,
"llm_call_log": [llm_call_record(ANALYZER_ID, ok=True, **analyzer.llm_usage)],
}
except ValueError:
raise
except Exception as exc:
logger.warning("%s failed: %s", ANALYZER_ID, exc)
return {
"findings": [],
"llm_call_log": [llm_call_record(ANALYZER_ID, ok=False, error=str(exc))],
"llm_call_log": [
llm_call_record(
ANALYZER_ID,
ok=False,
error=str(exc),
**locals().get("analyzer", _NoUsage()).llm_usage,
)
],
}
19 changes: 17 additions & 2 deletions src/skillspector/nodes/analyzers/semantic_quality_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
ANALYZER_ID = "semantic_quality_policy"
logger = get_logger(__name__)


class _NoUsage:
    """Zero-usage stand-in for when no ``analyzer`` local exists to report usage from."""

    llm_usage = dict.fromkeys(("input_tokens", "output_tokens", "total_tokens"), 0)


ANALYZER_PROMPT = """\
You are a quality and safety auditor for AI agent skills. Your job is to
review a single skill file and report findings that fall into the categories
Expand Down Expand Up @@ -148,12 +153,22 @@ def node(state: SkillspectorState) -> AnalyzerNodeResponse:
results = asyncio.run(analyzer.arun_batches(batches))
findings = analyzer.collect_findings(results)
logger.info("%s: %d findings", ANALYZER_ID, len(findings))
return {"findings": findings, "llm_call_log": [llm_call_record(ANALYZER_ID, ok=True)]}
return {
"findings": findings,
"llm_call_log": [llm_call_record(ANALYZER_ID, ok=True, **analyzer.llm_usage)],
}
except ValueError:
raise
except Exception as exc:
logger.warning("%s failed: %s", ANALYZER_ID, exc)
return {
"findings": [],
"llm_call_log": [llm_call_record(ANALYZER_ID, ok=False, error=str(exc))],
"llm_call_log": [
llm_call_record(
ANALYZER_ID,
ok=False,
error=str(exc),
**locals().get("analyzer", _NoUsage()).llm_usage,
)
],
}
26 changes: 23 additions & 3 deletions src/skillspector/nodes/analyzers/semantic_security_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
ANALYZER_ID = "semantic_security_discovery"
logger = get_logger(__name__)


class _NoUsage:
    """Zero-usage stand-in for when no ``analyzer`` local exists to report usage from."""

    llm_usage = dict.fromkeys(("input_tokens", "output_tokens", "total_tokens"), 0)


ANALYZER_PROMPT = """\
You are a security analyzer for AI agent skill files. Your task is to identify \
**intent and attack-phrasing risks** — issues that evade regex/static detection because \
Expand Down Expand Up @@ -90,14 +95,22 @@ def node(state: SkillspectorState) -> AnalyzerNodeResponse:
results = analyzer.run_batches(batches)
findings = analyzer.collect_findings(results)
logger.info("%s: %d findings", ANALYZER_ID, len(findings))
return {"findings": findings, "llm_call_log": [llm_call_record(ANALYZER_ID, ok=True)]}
return {
"findings": findings,
"llm_call_log": [llm_call_record(ANALYZER_ID, ok=True, **analyzer.llm_usage)],
}
except ValidationError as exc:
# Malformed LLM response — degrade gracefully rather than crashing the graph
logger.warning("%s: LLM returned malformed response: %s", ANALYZER_ID, exc)
return {
"findings": [],
"llm_call_log": [
llm_call_record(ANALYZER_ID, ok=False, error=f"malformed LLM response: {exc}")
llm_call_record(
ANALYZER_ID,
ok=False,
error=f"malformed LLM response: {exc}",
**locals().get("analyzer", _NoUsage()).llm_usage,
)
],
}
except ValueError:
Expand All @@ -106,5 +119,12 @@ def node(state: SkillspectorState) -> AnalyzerNodeResponse:
logger.warning("%s failed: %s", ANALYZER_ID, exc)
return {
"findings": [],
"llm_call_log": [llm_call_record(ANALYZER_ID, ok=False, error=str(exc))],
"llm_call_log": [
llm_call_record(
ANALYZER_ID,
ok=False,
error=str(exc),
**locals().get("analyzer", _NoUsage()).llm_usage,
)
],
}
15 changes: 13 additions & 2 deletions src/skillspector/nodes/meta_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
logger = get_logger(__name__)


class _NoUsage:
    """Zero-usage stand-in for when no ``analyzer`` local exists to report usage from."""

    llm_usage = dict.fromkeys(("input_tokens", "output_tokens", "total_tokens"), 0)


# ---------------------------------------------------------------------------
# Structured output schemas
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -565,13 +569,20 @@ def meta_analyzer(state: SkillspectorState) -> MetaAnalyzerResponse:
)
return {
"filtered_findings": filtered,
"llm_call_log": [llm_call_record("meta_analyzer", ok=True)],
"llm_call_log": [llm_call_record("meta_analyzer", ok=True, **analyzer.llm_usage)],
}
except ValueError:
raise
except Exception as e:
logger.warning("LLM call failed, passing all findings through (fail-closed): %s", e)
return {
"filtered_findings": _passthrough_with_defaults(findings),
"llm_call_log": [llm_call_record("meta_analyzer", ok=False, error=str(e))],
"llm_call_log": [
llm_call_record(
"meta_analyzer",
ok=False,
error=str(e),
**locals().get("analyzer", _NoUsage()).llm_usage,
)
],
}
Loading
Loading