Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/skillspector/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,28 @@
)


def _resolve_slot_model(slot: str) -> str:
def _resolve_slot_model(slot: str, provider=None) -> str:
"""Resolve the model for *slot* with per-slot env var override support.

Precedence: ``SKILLSPECTOR_MODEL_{SLOT}`` env var > provider
``resolve_model(slot)`` (which itself runs ``SKILLSPECTOR_MODEL`` env >
provider slot default > provider ``DEFAULT_MODEL``).
"""
provider = provider or get_metadata_provider()
env_key = f"SKILLSPECTOR_MODEL_{slot.upper()}"
env_val = os.environ.get(env_key, "").strip()
if env_val:
return env_val
return _provider.resolve_model(slot)
return provider.resolve_model(slot)


MODEL_CONFIG: dict[str, str] = {slot: _resolve_slot_model(slot) for slot in _MODEL_SLOTS}
def build_model_config() -> dict[str, str]:
"""Resolve the model map for the currently active provider."""
provider = get_metadata_provider()
return {slot: _resolve_slot_model(slot, provider) for slot in _MODEL_SLOTS}


MODEL_CONFIG: dict[str, str] = {slot: _resolve_slot_model(slot, _provider) for slot in _MODEL_SLOTS}


def _validate_model_config() -> None:
Expand Down
15 changes: 15 additions & 0 deletions src/skillspector/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
get_active_provider,
get_metadata_provider,
has_cli_capability,
has_provider_binding,
raise_no_llm_api_key_configured,
resolve_chat_model_credentials,
resolve_provider_credentials,
Expand All @@ -71,6 +72,9 @@ def _resolve_llm_credentials() -> tuple[str, str | None]:

def _resolve_default_chat_model() -> str:
"""Return the default chat model for the endpoint that will be used."""
if has_provider_binding():
return get_metadata_provider().resolve_model()

if resolve_provider_credentials() is not None:
return get_metadata_provider().resolve_model()

Expand All @@ -89,6 +93,17 @@ def is_llm_available() -> tuple[bool, str | None]:
auth). For HTTP providers, it falls back to credential resolution.
"""
provider = get_active_provider()
if has_provider_binding():
try:
model = provider.resolve_model()
create_chat_model(
model=model,
max_tokens=get_max_output_tokens(model),
timeout=120,
)
except ValueError as exc:
return False, str(exc)
return True, None
if has_cli_capability(provider):
return provider.is_available() # type: ignore[attr-defined]
try:
Expand Down
4 changes: 2 additions & 2 deletions src/skillspector/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from skillspector import __version__
from skillspector.graph import graph
from skillspector.logging_config import get_logger
from skillspector.providers import resolve_provider_credentials
from skillspector.providers import has_provider_binding, resolve_provider_credentials

if TYPE_CHECKING:
from mcp.server.fastmcp import FastMCP
Expand Down Expand Up @@ -74,7 +74,7 @@ async def run_scan(
if output_format not in VALID_FORMATS:
raise ValueError(f"output_format must be one of {VALID_FORMATS}, got {output_format!r}")

llm_available = resolve_provider_credentials() is not None
llm_available = has_provider_binding() or resolve_provider_credentials() is not None
llm_used = use_llm and llm_available

state: dict[str, Any] = {
Expand Down
4 changes: 2 additions & 2 deletions src/skillspector/nodes/build_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import yaml

from skillspector.constants import MODEL_CONFIG
from skillspector.constants import build_model_config
from skillspector.logging_config import get_logger
from skillspector.state import SkillspectorState

Expand Down Expand Up @@ -246,7 +246,7 @@ def build_context(state: SkillspectorState) -> dict[str, object]:
"ast_cache": {},
"manifest": manifest,
"previous_manifest": None,
"model_config": MODEL_CONFIG,
"model_config": build_model_config(),
"component_metadata": component_metadata,
"has_executable_scripts": has_executable_scripts,
}
34 changes: 34 additions & 0 deletions src/skillspector/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from __future__ import annotations

import os
from contextvars import ContextVar, Token
from typing import NoReturn

from langchain_core.language_models.chat_models import BaseChatModel
Expand All @@ -67,14 +68,38 @@
"Use --no-llm to skip LLM analysis and run static checks only."
)

_INJECTED_PROVIDER: ContextVar[LLMProvider | None] = ContextVar(
"skillspector_injected_provider",
default=None,
)


def raise_no_llm_api_key_configured() -> NoReturn:
"""Raise the shared no-LLM-credentials error."""
raise ValueError(NO_LLM_API_KEY_MESSAGE)


def use_provider(provider: LLMProvider) -> Token[LLMProvider | None]:
"""Bind *provider* for the current context."""
return _INJECTED_PROVIDER.set(provider)


def reset_provider(token: Token[LLMProvider | None]) -> None:
"""Restore the provider binding represented by *token*."""
_INJECTED_PROVIDER.reset(token)


def has_provider_binding() -> bool:
"""Return whether the current context has an injected provider."""
return _INJECTED_PROVIDER.get() is not None


def _select_active_provider() -> LLMProvider:
"""Construct the active provider based on ``SKILLSPECTOR_PROVIDER``."""
injected_provider = _INJECTED_PROVIDER.get()
if injected_provider is not None:
return injected_provider

name = os.environ.get("SKILLSPECTOR_PROVIDER", "").strip().lower()

if name == "openai":
Expand Down Expand Up @@ -166,6 +191,9 @@ def resolve_chat_model_credentials() -> tuple[str, str | None] | None:
if creds is not None:
return creds

if has_provider_binding():
return None

return _openai_fallback_provider().resolve_credentials()


Expand Down Expand Up @@ -194,6 +222,9 @@ def create_chat_model(
if llm is not None:
return llm

if has_provider_binding():
raise_no_llm_api_key_configured()

from .openai import OpenAIProvider

if not isinstance(provider, OpenAIProvider):
Expand All @@ -219,7 +250,10 @@ def create_chat_model(
"get_active_provider",
"get_metadata_provider",
"has_cli_capability",
"has_provider_binding",
"reset_provider",
"raise_no_llm_api_key_configured",
"resolve_chat_model_credentials",
"resolve_provider_credentials",
"use_provider",
]
31 changes: 31 additions & 0 deletions tests/nodes/test_build_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from skillspector.constants import MODEL_CONFIG
from skillspector.nodes.build_context import build_context
from skillspector.providers import reset_provider, use_provider
from skillspector.state import SkillspectorState


Expand Down Expand Up @@ -131,6 +132,36 @@ def test_build_context_empty_directory_is_valid_empty_scan(tmp_path: Path) -> No
assert result["model_config"] == MODEL_CONFIG


def test_build_context_model_config_uses_bound_provider(tmp_path: Path) -> None:
class _BoundProvider:
DEFAULT_MODEL = "bound-default"
SLOT_DEFAULTS = {"meta_analyzer": "bound-meta"}

def get_context_length(self, model: str) -> int | None:
return 4096

def get_max_output_tokens(self, model: str) -> int | None:
return 128

def resolve_model(self, slot: str = "default") -> str:
return self.SLOT_DEFAULTS.get(slot, self.DEFAULT_MODEL)

def resolve_credentials(self) -> tuple[str, str | None] | None:
return None

def create_chat_model(self, model: str, *, max_tokens: int, timeout: float | None = 120):
return object()

token = use_provider(_BoundProvider())
try:
result = build_context({"skill_path": str(tmp_path)})
finally:
reset_provider(token)

assert result["model_config"]["default"] == "bound-default"
assert result["model_config"]["meta_analyzer"] == "bound-meta"


def test_build_context_skips_skip_dirs(tmp_path: Path) -> None:
"""Skip dirs like __pycache__ and node_modules are not included in components."""
_make_skill_spec_dir(tmp_path)
Expand Down
86 changes: 85 additions & 1 deletion tests/unit/test_llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@
get_chat_model,
is_llm_available,
)
from skillspector.providers import NO_LLM_API_KEY_MESSAGE, resolve_provider_credentials
from skillspector.providers import (
NO_LLM_API_KEY_MESSAGE,
reset_provider,
resolve_chat_model_credentials,
resolve_provider_credentials,
use_provider,
)
from skillspector.providers.nv_build import NvBuildProvider
from skillspector.providers.openai import OpenAIProvider

Expand Down Expand Up @@ -120,6 +126,84 @@ def test_get_chat_model_returns_native_anthropic_client(
assert isinstance(llm, ChatAnthropic)
assert llm.model == "claude-opus-4-6"

def test_injected_provider_without_credentials_builds_native_chat_model(self) -> None:
chat_model = object()

class _InjectedProvider:
DEFAULT_MODEL = "injected-default"
SLOT_DEFAULTS = {"meta_analyzer": "injected-default"}

def get_context_length(self, model: str) -> int | None:
return 4096 if model == "injected-default" else None

def get_max_output_tokens(self, model: str) -> int | None:
return 128 if model == "injected-default" else None

def resolve_model(self, slot: str = "default") -> str:
return "injected-default"

def resolve_credentials(self) -> tuple[str, str | None] | None:
return None

def create_chat_model(
self,
model: str,
*,
max_tokens: int,
timeout: float | None = 120,
) -> object:
assert model == "injected-default"
assert max_tokens == 128
assert timeout == 120
return chat_model

token = use_provider(_InjectedProvider())
try:
assert is_llm_available() == (True, None)
assert get_chat_model() is chat_model
finally:
reset_provider(token)

def test_injected_provider_without_native_model_does_not_fall_back_to_openai(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setenv("OPENAI_API_KEY", "sk-test-fallback")

class _InjectedProvider:
DEFAULT_MODEL = "injected-default"
SLOT_DEFAULTS = {}

def get_context_length(self, model: str) -> int | None:
return 4096

def get_max_output_tokens(self, model: str) -> int | None:
return 128

def resolve_model(self, slot: str = "default") -> str:
return "injected-default"

def resolve_credentials(self) -> tuple[str, str | None] | None:
return None

def create_chat_model(
self,
model: str,
*,
max_tokens: int,
timeout: float | None = 120,
) -> object | None:
return None

token = use_provider(_InjectedProvider())
try:
assert resolve_chat_model_credentials() is None
assert is_llm_available() == (False, NO_LLM_API_KEY_MESSAGE)
with pytest.raises(ValueError) as exc_info:
get_chat_model()
assert str(exc_info.value) == NO_LLM_API_KEY_MESSAGE
finally:
reset_provider(token)


class TestFetchModelTokenLimits:
def test_returns_input_and_output_token_pair(self) -> None:
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/test_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,35 @@ async def test_run_scan_reports_llm_available_with_credentials(
assert result["scan_mode"] == "static-only"


async def test_run_scan_uses_bound_provider_without_credentials(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""An injected provider can own the LLM client without exposing raw credentials."""

class _Graph:
async def ainvoke(self, state, config):
assert state["use_llm"] is True
return {
"filtered_findings": [],
"risk_score": 0,
"risk_severity": "LOW",
"risk_recommendation": "OK",
"report_body": "report",
}

monkeypatch.setattr(mcp_server, "graph", _Graph())
monkeypatch.setattr(mcp_server, "has_provider_binding", lambda: True)
monkeypatch.setattr(mcp_server, "resolve_provider_credentials", lambda: None)
_write_skill(tmp_path)

result = await run_scan(str(tmp_path), use_llm=True, output_format="json")

assert result["llm_available"] is True
assert result["llm_requested"] is True
assert result["llm_used"] is True
assert result["scan_mode"] == "static+llm"


async def test_run_scan_rejects_invalid_format(tmp_path: Path) -> None:
"""An unsupported output_format is rejected before any scan runs."""
with pytest.raises(ValueError):
Expand Down
Loading
Loading