diff --git a/lmms_eval/llm_judge/factory.py b/lmms_eval/llm_judge/factory.py index e884bf3b9..5c1804b76 100644 --- a/lmms_eval/llm_judge/factory.py +++ b/lmms_eval/llm_judge/factory.py @@ -1,13 +1,27 @@ import os +from pathlib import Path from typing import Optional from .base import ServerInterface from .protocol import ServerConfig + +# Load .env for LLM judge configuration +try: + from dotenv import load_dotenv + + for candidate in [Path.cwd() / ".env", Path(__file__).resolve().parents[4] / ".env"]: + if candidate.is_file(): + load_dotenv(candidate, override=False) + break +except ImportError: + pass from .providers import ( AsyncAzureOpenAIProvider, AsyncOpenAIProvider, AzureOpenAIProvider, + BedrockProvider, DummyProvider, + LocalProvider, OpenAIProvider, ) @@ -15,7 +29,15 @@ class ProviderFactory: """Factory for creating judge instances based on configuration""" - _provider_classes = {"openai": OpenAIProvider, "azure": AzureOpenAIProvider, "async_openai": AsyncOpenAIProvider, "async_azure": AsyncAzureOpenAIProvider, "dummy": DummyProvider} + _provider_classes = { + "openai": OpenAIProvider, + "azure": AzureOpenAIProvider, + "async_openai": AsyncOpenAIProvider, + "async_azure": AsyncAzureOpenAIProvider, + "bedrock": BedrockProvider, + "local": LocalProvider, + "dummy": DummyProvider, + } # TODO # This should actually be a decorator that registers the class diff --git a/lmms_eval/llm_judge/providers/__init__.py b/lmms_eval/llm_judge/providers/__init__.py index 9fbdb284d..bc8687f29 100644 --- a/lmms_eval/llm_judge/providers/__init__.py +++ b/lmms_eval/llm_judge/providers/__init__.py @@ -1,7 +1,9 @@ from .async_azure_openai import AsyncAzureOpenAIProvider from .async_openai import AsyncOpenAIProvider from .azure_openai import AzureOpenAIProvider +from .bedrock import BedrockProvider from .dummy import DummyProvider +from .local import LocalProvider from .openai import OpenAIProvider __all__ = [ @@ -9,5 +11,7 @@ "AzureOpenAIProvider", "AsyncOpenAIProvider", "AsyncAzureOpenAIProvider", + "BedrockProvider", + "LocalProvider", "DummyProvider", ] diff --git a/lmms_eval/llm_judge/providers/bedrock.py b/lmms_eval/llm_judge/providers/bedrock.py new file mode 100644 index 000000000..786fbb70b --- /dev/null +++ b/lmms_eval/llm_judge/providers/bedrock.py @@ -0,0 +1,119 @@ +"""AWS Bedrock provider for the llm_judge framework. + +Supports both standard IAM credentials and bearer token auth. + +Environment variables: + AWS_REGION - AWS region (default: us-west-2) + AWS_BEARER_TOKEN_BEDROCK - Bearer token for Bedrock auth (optional) +""" + +import os +import time +from typing import Dict, List, Optional, Union + +from loguru import logger as eval_logger + +from ..base import ServerInterface +from ..protocol import Request, Response, ServerConfig + + +class BedrockProvider(ServerInterface): + """AWS Bedrock implementation of the Judge interface.""" + + def __init__(self, config: Optional[ServerConfig] = None): + super().__init__(config) + self._client = None + + @property + def client(self): + if self._client is None: + import boto3 + from botocore.config import Config + + region = os.getenv("AWS_REGION", "us-west-2") + bearer_token = os.getenv("AWS_BEARER_TOKEN_BEDROCK") + + if bearer_token: + session = boto3.Session() + self._client = session.client( + "bedrock-runtime", + region_name=region, + config=Config(signature_version="bearer"), + aws_access_key_id="unused", + aws_secret_access_key="unused", + aws_session_token=bearer_token, + ) + else: + self._client = boto3.client("bedrock-runtime", region_name=region) + return self._client + + def is_available(self) -> bool: + try: + self.client + return True + except Exception: + return False + + def evaluate(self, request: Request) -> Response: + config = request.config or self.config + messages = self.prepare_messages(request) + + bedrock_messages = [] + for m in messages: + content_blocks = [] + if isinstance(m["content"], str): + content_blocks.append({"text": m["content"]}) + elif isinstance(m["content"], list): + for part in m["content"]: + if part.get("type") == "text": + content_blocks.append({"text": part["text"]}) + elif part.get("type") == "image_url": + # Bedrock expects base64 image in a different format + url = part["image_url"]["url"] + if url.startswith("data:"): + media_type, b64_data = url.split(";base64,", 1) + media_type = media_type.replace("data:", "") + import base64 + content_blocks.append({ + "image": { + "format": media_type.split("/")[-1], + "source": {"bytes": base64.b64decode(b64_data)}, + } + }) + bedrock_messages.append({"role": m["role"], "content": content_blocks}) + + inference_config = { + "maxTokens": config.max_tokens, + "temperature": config.temperature, + } + if config.top_p is not None: + inference_config["topP"] = config.top_p + + for attempt in range(config.num_retries): + try: + response = self.client.converse( + modelId=config.model_name, + messages=bedrock_messages, + inferenceConfig=inference_config, + ) + + content = response["output"]["message"]["content"][0]["text"] + usage = response.get("usage", {}) + + return Response( + content=content.strip(), + model_used=config.model_name, + usage={ + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + }, + raw_response=response, + ) + + except Exception as e: + eval_logger.warning(f"Bedrock attempt {attempt + 1}/{config.num_retries} failed: {e}") + if attempt < config.num_retries - 1: + time.sleep(config.retry_delay) + else: + eval_logger.error(f"All {config.num_retries} Bedrock attempts failed") + raise diff --git a/lmms_eval/llm_judge/providers/local.py b/lmms_eval/llm_judge/providers/local.py new file mode 100644 index 000000000..bfc317059 --- /dev/null +++ b/lmms_eval/llm_judge/providers/local.py @@ -0,0 +1,69 @@ +"""Local vLLM / SGLang provider for the llm_judge framework. + +Connects to any OpenAI-compatible local server without requiring an API key. + +Environment variables: + LLM_JUDGE_URL - Server URL (default: http://localhost:8000/v1/chat/completions) +""" + +import os +import time +from typing import Dict, Optional + +import requests +from loguru import logger as eval_logger + +from ..base import ServerInterface +from ..protocol import Request, Response, ServerConfig + + +class LocalProvider(ServerInterface): + """Local vLLM/SGLang OpenAI-compatible server implementation.""" + + def __init__(self, config: Optional[ServerConfig] = None): + super().__init__(config) + self.api_url = os.getenv("LLM_JUDGE_URL", "http://localhost:8000/v1/chat/completions") + + def is_available(self) -> bool: + try: + resp = requests.get(self.api_url.replace("/chat/completions", "/models"), timeout=5) + return resp.status_code == 200 + except Exception: + return False + + def evaluate(self, request: Request) -> Response: + config = request.config or self.config + messages = self.prepare_messages(request) + + payload = { + "model": config.model_name, + "messages": messages, + "temperature": config.temperature, + "max_tokens": config.max_tokens, + } + if config.top_p is not None: + payload["top_p"] = config.top_p + + for attempt in range(config.num_retries): + try: + resp = requests.post(self.api_url, json=payload, timeout=config.timeout) + resp.raise_for_status() + data = resp.json() + + content = data["choices"][0]["message"]["content"] + usage = data.get("usage") + + return Response( + content=content.strip(), + model_used=data.get("model", config.model_name), + usage=usage, + raw_response=data, + ) + + except Exception as e: + eval_logger.warning(f"Local server attempt {attempt + 1}/{config.num_retries} failed: {e}") + if attempt < config.num_retries - 1: + time.sleep(config.retry_delay) + else: + eval_logger.error(f"All {config.num_retries} local server attempts failed") + raise