From 295fdb8fedbc06471b6fa2188edd4266f4af5572 Mon Sep 17 00:00:00 2001 From: Rod Boev Date: Fri, 3 Jul 2026 22:23:37 -0400 Subject: [PATCH] feat(provider): allow scoped LLM provider injection (#243) Signed-off-by: Rod Boev --- src/skillspector/constants.py | 13 ++- src/skillspector/llm_utils.py | 15 ++++ src/skillspector/mcp_server.py | 4 +- src/skillspector/nodes/build_context.py | 4 +- src/skillspector/providers/__init__.py | 34 +++++++ tests/nodes/test_build_context.py | 31 +++++++ tests/unit/test_llm_utils.py | 86 +++++++++++++++++- tests/unit/test_mcp_server.py | 29 ++++++ tests/unit/test_providers.py | 113 ++++++++++++++++++++++++ 9 files changed, 321 insertions(+), 8 deletions(-) diff --git a/src/skillspector/constants.py b/src/skillspector/constants.py index 375992c7..1446b181 100644 --- a/src/skillspector/constants.py +++ b/src/skillspector/constants.py @@ -50,21 +50,28 @@ ) -def _resolve_slot_model(slot: str) -> str: +def _resolve_slot_model(slot: str, provider=None) -> str: """Resolve the model for *slot* with per-slot env var override support. Precedence: ``SKILLSPECTOR_MODEL_{SLOT}`` env var > provider ``resolve_model(slot)`` (which itself runs ``SKILLSPECTOR_MODEL`` env > provider slot default > provider ``DEFAULT_MODEL``). """ + provider = provider or get_metadata_provider() env_key = f"SKILLSPECTOR_MODEL_{slot.upper()}" env_val = os.environ.get(env_key, "").strip() if env_val: return env_val - return _provider.resolve_model(slot) + return provider.resolve_model(slot) -MODEL_CONFIG: dict[str, str] = {slot: _resolve_slot_model(slot) for slot in _MODEL_SLOTS} +def build_model_config() -> dict[str, str]: + """Resolve the model map for the currently active provider.""" + provider = get_metadata_provider() + return {slot: _resolve_slot_model(slot, provider) for slot in _MODEL_SLOTS} + + +MODEL_CONFIG: dict[str, str] = {slot: _resolve_slot_model(slot, _provider) for slot in _MODEL_SLOTS} def _validate_model_config() -> None: diff --git a/src/skillspector/llm_utils.py b/src/skillspector/llm_utils.py index 468e26b0..f6f8c4fd 100644 --- a/src/skillspector/llm_utils.py +++ b/src/skillspector/llm_utils.py @@ -46,6 +46,7 @@ get_active_provider, get_metadata_provider, has_cli_capability, + has_provider_binding, raise_no_llm_api_key_configured, resolve_chat_model_credentials, resolve_provider_credentials, @@ -71,6 +72,9 @@ def _resolve_llm_credentials() -> tuple[str, str | None]: def _resolve_default_chat_model() -> str: """Return the default chat model for the endpoint that will be used.""" + if has_provider_binding(): + return get_metadata_provider().resolve_model() + if resolve_provider_credentials() is not None: return get_metadata_provider().resolve_model() @@ -89,6 +93,17 @@ def is_llm_available() -> tuple[bool, str | None]: auth). For HTTP providers, it falls back to credential resolution. """ provider = get_active_provider() + if has_provider_binding(): + try: + model = provider.resolve_model() + create_chat_model( + model=model, + max_tokens=get_max_output_tokens(model), + timeout=120, + ) + except ValueError as exc: + return False, str(exc) + return True, None if has_cli_capability(provider): return provider.is_available() # type: ignore[attr-defined] try: diff --git a/src/skillspector/mcp_server.py b/src/skillspector/mcp_server.py index 444b75fc..dcf78880 100644 --- a/src/skillspector/mcp_server.py +++ b/src/skillspector/mcp_server.py @@ -33,7 +33,7 @@ from skillspector import __version__ from skillspector.graph import graph from skillspector.logging_config import get_logger -from skillspector.providers import resolve_provider_credentials +from skillspector.providers import has_provider_binding, resolve_provider_credentials if TYPE_CHECKING: from mcp.server.fastmcp import FastMCP @@ -74,7 +74,7 @@ async def run_scan( if output_format not in VALID_FORMATS: raise ValueError(f"output_format must be one of {VALID_FORMATS}, got {output_format!r}") - llm_available = resolve_provider_credentials() is not None + llm_available = has_provider_binding() or resolve_provider_credentials() is not None llm_used = use_llm and llm_available state: dict[str, Any] = { diff --git a/src/skillspector/nodes/build_context.py b/src/skillspector/nodes/build_context.py index a905844a..d72a7407 100644 --- a/src/skillspector/nodes/build_context.py +++ b/src/skillspector/nodes/build_context.py @@ -26,7 +26,7 @@ import yaml -from skillspector.constants import MODEL_CONFIG +from skillspector.constants import build_model_config from skillspector.logging_config import get_logger from skillspector.state import SkillspectorState @@ -246,7 +246,7 @@ def build_context(state: SkillspectorState) -> dict[str, object]: "ast_cache": {}, "manifest": manifest, "previous_manifest": None, - "model_config": MODEL_CONFIG, + "model_config": build_model_config(), "component_metadata": component_metadata, "has_executable_scripts": has_executable_scripts, } diff --git a/src/skillspector/providers/__init__.py b/src/skillspector/providers/__init__.py index 809884dc..a4c0d709 100644 --- a/src/skillspector/providers/__init__.py +++ b/src/skillspector/providers/__init__.py @@ -46,6 +46,7 @@ from __future__ import annotations import os +from contextvars import ContextVar, Token from typing import NoReturn from langchain_core.language_models.chat_models import BaseChatModel @@ -67,14 +68,38 @@ "Use --no-llm to skip LLM analysis and run static checks only." ) +_INJECTED_PROVIDER: ContextVar[LLMProvider | None] = ContextVar( + "skillspector_injected_provider", + default=None, +) + def raise_no_llm_api_key_configured() -> NoReturn: """Raise the shared no-LLM-credentials error.""" raise ValueError(NO_LLM_API_KEY_MESSAGE) +def use_provider(provider: LLMProvider) -> Token[LLMProvider | None]: + """Bind *provider* for the current context.""" + return _INJECTED_PROVIDER.set(provider) + + +def reset_provider(token: Token[LLMProvider | None]) -> None: + """Restore the provider binding represented by *token*.""" + _INJECTED_PROVIDER.reset(token) + + +def has_provider_binding() -> bool: + """Return whether the current context has an injected provider.""" + return _INJECTED_PROVIDER.get() is not None + + def _select_active_provider() -> LLMProvider: """Construct the active provider based on ``SKILLSPECTOR_PROVIDER``.""" + injected_provider = _INJECTED_PROVIDER.get() + if injected_provider is not None: + return injected_provider + name = os.environ.get("SKILLSPECTOR_PROVIDER", "").strip().lower() if name == "openai": @@ -166,6 +191,9 @@ def resolve_chat_model_credentials() -> tuple[str, str | None] | None: if creds is not None: return creds + if has_provider_binding(): + return None + return _openai_fallback_provider().resolve_credentials() @@ -194,6 +222,9 @@ def create_chat_model( if llm is not None: return llm + if has_provider_binding(): + raise_no_llm_api_key_configured() + from .openai import OpenAIProvider if not isinstance(provider, OpenAIProvider): @@ -219,7 +250,10 @@ def create_chat_model( "get_active_provider", "get_metadata_provider", "has_cli_capability", + "has_provider_binding", + "reset_provider", "raise_no_llm_api_key_configured", "resolve_chat_model_credentials", "resolve_provider_credentials", + "use_provider", ] diff --git a/tests/nodes/test_build_context.py b/tests/nodes/test_build_context.py index d9daca67..6d857efd 100644 --- a/tests/nodes/test_build_context.py +++ b/tests/nodes/test_build_context.py @@ -26,6 +26,7 @@ from skillspector.constants import MODEL_CONFIG from skillspector.nodes.build_context import build_context +from skillspector.providers import reset_provider, use_provider from skillspector.state import SkillspectorState @@ -131,6 +132,36 @@ def test_build_context_empty_directory_is_valid_empty_scan(tmp_path: Path) -> No assert result["model_config"] == MODEL_CONFIG +def test_build_context_model_config_uses_bound_provider(tmp_path: Path) -> None: + class _BoundProvider: + DEFAULT_MODEL = "bound-default" + SLOT_DEFAULTS = {"meta_analyzer": "bound-meta"} + + def get_context_length(self, model: str) -> int | None: + return 4096 + + def get_max_output_tokens(self, model: str) -> int | None: + return 128 + + def resolve_model(self, slot: str = "default") -> str: + return self.SLOT_DEFAULTS.get(slot, self.DEFAULT_MODEL) + + def resolve_credentials(self) -> tuple[str, str | None] | None: + return None + + def create_chat_model(self, model: str, *, max_tokens: int, timeout: float | None = 120): + return object() + + token = use_provider(_BoundProvider()) + try: + result = build_context({"skill_path": str(tmp_path)}) + finally: + reset_provider(token) + + assert result["model_config"]["default"] == "bound-default" + assert result["model_config"]["meta_analyzer"] == "bound-meta" + + def test_build_context_skips_skip_dirs(tmp_path: Path) -> None: """Skip dirs like __pycache__ and node_modules are not included in components.""" _make_skill_spec_dir(tmp_path) diff --git a/tests/unit/test_llm_utils.py b/tests/unit/test_llm_utils.py index 91b09726..6bd26004 100644 --- a/tests/unit/test_llm_utils.py +++ b/tests/unit/test_llm_utils.py @@ -39,7 +39,13 @@ get_chat_model, is_llm_available, ) -from skillspector.providers import NO_LLM_API_KEY_MESSAGE, resolve_provider_credentials +from skillspector.providers import ( + NO_LLM_API_KEY_MESSAGE, + reset_provider, + resolve_chat_model_credentials, + resolve_provider_credentials, + use_provider, +) from skillspector.providers.nv_build import NvBuildProvider from skillspector.providers.openai import OpenAIProvider @@ -120,6 +126,84 @@ def test_get_chat_model_returns_native_anthropic_client( assert isinstance(llm, ChatAnthropic) assert llm.model == "claude-opus-4-6" + def test_injected_provider_without_credentials_builds_native_chat_model(self) -> None: + chat_model = object() + + class _InjectedProvider: + DEFAULT_MODEL = "injected-default" + SLOT_DEFAULTS = {"meta_analyzer": "injected-default"} + + def get_context_length(self, model: str) -> int | None: + return 4096 if model == "injected-default" else None + + def get_max_output_tokens(self, model: str) -> int | None: + return 128 if model == "injected-default" else None + + def resolve_model(self, slot: str = "default") -> str: + return "injected-default" + + def resolve_credentials(self) -> tuple[str, str | None] | None: + return None + + def create_chat_model( + self, + model: str, + *, + max_tokens: int, + timeout: float | None = 120, + ) -> object: + assert model == "injected-default" + assert max_tokens == 128 + assert timeout == 120 + return chat_model + + token = use_provider(_InjectedProvider()) + try: + assert is_llm_available() == (True, None) + assert get_chat_model() is chat_model + finally: + reset_provider(token) + + def test_injected_provider_without_native_model_does_not_fall_back_to_openai( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-fallback") + + class _InjectedProvider: + DEFAULT_MODEL = "injected-default" + SLOT_DEFAULTS = {} + + def get_context_length(self, model: str) -> int | None: + return 4096 + + def get_max_output_tokens(self, model: str) -> int | None: + return 128 + + def resolve_model(self, slot: str = "default") -> str: + return "injected-default" + + def resolve_credentials(self) -> tuple[str, str | None] | None: + return None + + def create_chat_model( + self, + model: str, + *, + max_tokens: int, + timeout: float | None = 120, + ) -> object | None: + return None + + token = use_provider(_InjectedProvider()) + try: + assert resolve_chat_model_credentials() is None + assert is_llm_available() == (False, NO_LLM_API_KEY_MESSAGE) + with pytest.raises(ValueError) as exc_info: + get_chat_model() + assert str(exc_info.value) == NO_LLM_API_KEY_MESSAGE + finally: + reset_provider(token) + class TestFetchModelTokenLimits: def test_returns_input_and_output_token_pair(self) -> None: diff --git a/tests/unit/test_mcp_server.py b/tests/unit/test_mcp_server.py index 10c5596b..e3cf5db0 100644 --- a/tests/unit/test_mcp_server.py +++ b/tests/unit/test_mcp_server.py @@ -77,6 +77,35 @@ async def test_run_scan_reports_llm_available_with_credentials( assert result["scan_mode"] == "static-only" +async def test_run_scan_uses_bound_provider_without_credentials( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """An injected provider can own the LLM client without exposing raw credentials.""" + + class _Graph: + async def ainvoke(self, state, config): + assert state["use_llm"] is True + return { + "filtered_findings": [], + "risk_score": 0, + "risk_severity": "LOW", + "risk_recommendation": "OK", + "report_body": "report", + } + + monkeypatch.setattr(mcp_server, "graph", _Graph()) + monkeypatch.setattr(mcp_server, "has_provider_binding", lambda: True) + monkeypatch.setattr(mcp_server, "resolve_provider_credentials", lambda: None) + _write_skill(tmp_path) + + result = await run_scan(str(tmp_path), use_llm=True, output_format="json") + + assert result["llm_available"] is True + assert result["llm_requested"] is True + assert result["llm_used"] is True + assert result["scan_mode"] == "static+llm" + + async def test_run_scan_rejects_invalid_format(tmp_path: Path) -> None: """An unsupported output_format is rejected before any scan runs.""" with pytest.raises(ValueError): diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py index 61937409..fae2572f 100644 --- a/tests/unit/test_providers.py +++ b/tests/unit/test_providers.py @@ -28,14 +28,19 @@ from langchain_anthropic import ChatAnthropic from langchain_openai import ChatOpenAI +import skillspector.providers as providers_module from skillspector.providers import ( NO_LLM_API_KEY_MESSAGE, create_chat_model, + get_active_provider, get_metadata_provider, has_cli_capability, + has_provider_binding, registry, + reset_provider, resolve_chat_model_credentials, resolve_provider_credentials, + use_provider, ) from skillspector.providers.anthropic import AnthropicProvider from skillspector.providers.antigravity_cli import AntigravityCLIProvider @@ -62,6 +67,44 @@ ) +class FakeProvider: + DEFAULT_MODEL = "fake-default" + SLOT_DEFAULTS = {"meta_analyzer": "fake-meta"} + + def __init__( + self, + name: str, + *, + credentials: tuple[str, str | None] | None = None, + chat_model: object | None = None, + ) -> None: + self.name = name + self._credentials = credentials + self.chat_model = chat_model if chat_model is not None else object() + + def get_context_length(self, model: str) -> int | None: + return 111 if model == self.name else None + + def get_max_output_tokens(self, model: str) -> int | None: + return 222 if model == self.name else None + + def resolve_model(self, slot: str = "default") -> str: + return f"{self.name}:{slot}" + + def resolve_credentials(self) -> tuple[str, str | None] | None: + return self._credentials + + def create_chat_model( + self, + model: str, + *, + max_tokens: int, + timeout: float | None = 120, + ) -> object: + self.last_chat_model_request = (model, max_tokens, timeout) + return self.chat_model + + @pytest.fixture(autouse=True) def _clean_provider_env(monkeypatch: pytest.MonkeyPatch): """Isolate provider-related env vars and the YAML cache for each test.""" @@ -74,8 +117,10 @@ def _clean_provider_env(monkeypatch: pytest.MonkeyPatch): monkeypatch.delenv("SKILLSPECTOR_MODEL", raising=False) monkeypatch.delenv("SKILLSPECTOR_MODEL_REGISTRY", raising=False) monkeypatch.delenv("SKILLSPECTOR_PROVIDER", raising=False) + providers_module._INJECTED_PROVIDER.set(None) registry._load.cache_clear() yield + providers_module._INJECTED_PROVIDER.set(None) registry._load.cache_clear() @@ -428,6 +473,74 @@ def test_select_antigravity_cli(self, monkeypatch: pytest.MonkeyPatch) -> None: assert isinstance(provider, AntigravityCLIProvider) assert resolve_provider_credentials() is None + def test_injected_provider_routes_metadata_and_active_helpers(self) -> None: + provider = FakeProvider("injected") + token = use_provider(provider) + try: + assert has_provider_binding() is True + assert get_metadata_provider() is provider + assert get_active_provider() is provider + finally: + reset_provider(token) + assert has_provider_binding() is False + + def test_injected_provider_routes_credentials_and_chat_model( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("SKILLSPECTOR_PROVIDER", "openai") + monkeypatch.setenv("OPENAI_API_KEY", "sk-x") + chat_model = object() + provider = FakeProvider( + "injected", + credentials=("injected-key", "injected-base-url"), + chat_model=chat_model, + ) + token = use_provider(provider) + try: + assert resolve_provider_credentials() == ("injected-key", "injected-base-url") + assert create_chat_model("model-x", max_tokens=42) is chat_model + finally: + reset_provider(token) + + def test_provider_token_reset_restores_env_dispatch( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("SKILLSPECTOR_PROVIDER", "openai") + monkeypatch.setenv("OPENAI_API_KEY", "sk-x") + provider = FakeProvider("injected", credentials=("injected-key", None)) + token = use_provider(provider) + reset_provider(token) + assert isinstance(get_metadata_provider(), OpenAIProvider) + assert resolve_provider_credentials() == ("sk-x", None) + + def test_provider_token_nested_restores_previous_binding( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("SKILLSPECTOR_PROVIDER", "openai") + monkeypatch.setenv("OPENAI_API_KEY", "sk-x") + outer_provider = FakeProvider( + "outer", + credentials=("outer-key", "outer-base-url"), + ) + inner_provider = FakeProvider( + "inner", + credentials=("inner-key", "inner-base-url"), + ) + outer_token = use_provider(outer_provider) + try: + inner_token = use_provider(inner_provider) + try: + assert get_metadata_provider() is inner_provider + assert resolve_provider_credentials() == ("inner-key", "inner-base-url") + finally: + reset_provider(inner_token) + assert get_metadata_provider() is outer_provider + assert resolve_provider_credentials() == ("outer-key", "outer-base-url") + finally: + reset_provider(outer_token) + assert isinstance(get_metadata_provider(), OpenAIProvider) + assert resolve_provider_credentials() == ("sk-x", None) + class TestAntigravityCLIProvider: """Antigravity CLI provider — registered but disabled; must fail closed."""