Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions packages/agent/src/argus_agent/llm/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,39 @@ def get_provider() -> LLMProvider:
return _providers[provider_name](settings.llm)


async def get_provider_for_tenant(tenant_id: str) -> LLMProvider:
"""Get an LLM provider using the tenant's BYOK keys if configured,
otherwise fall back to the platform default."""
from argus_agent.api.llm_keys import get_tenant_llm_key

tenant_config = await get_tenant_llm_key(tenant_id)
if tenant_config and tenant_config.get("api_key"):
provider_name = tenant_config.get("provider", "openai")
if provider_name not in _providers:
_discover_providers()
if provider_name not in _providers:
raise ValueError(
f"Unknown LLM provider: {provider_name}. Available: {list(_providers.keys())}"
)

settings = get_settings()
# Build a temporary LLMConfig with tenant's keys
from argus_agent.config import LLMConfig

tenant_llm = LLMConfig(
provider=provider_name,
api_key=tenant_config["api_key"],
model=tenant_config.get("model") or settings.llm.model,
base_url=tenant_config.get("base_url") or settings.llm.base_url,
temperature=settings.llm.temperature,
max_tokens=settings.llm.max_tokens,
)
return _providers[provider_name](tenant_llm)

# No BYOK config — use platform default
return get_provider()


def _discover_providers() -> None:
"""Auto-discover available providers based on installed packages."""
try:
Expand Down
6 changes: 3 additions & 3 deletions packages/agent/src/argus_agent/queue/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ async def on_event(event_type: str, data: dict[str, Any]) -> None:
msg = json.dumps({"event_type": event_type, "data": data})
await redis_pub.publish(f"{STREAM_KEY_PREFIX}{task_id}", msg)

# 4. Get LLM provider
from argus_agent.llm.registry import get_provider
# 4. Get LLM provider (tenant BYOK keys take priority)
from argus_agent.llm.registry import get_provider_for_tenant

provider = get_provider()
provider = await get_provider_for_tenant(payload.tenant_id)

# 5. Build AgentLoop
from argus_agent.agent.loop import AgentLoop
Expand Down
Loading