From 172a9fdabf3971ac0b4976883592a371dd3095b2 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 23 Jun 2026 22:59:29 +0000 Subject: [PATCH] Add low/medium/high quality option for image generation Expose a single `quality` request field (low/medium/high) that selects output resolution for image-generation models. Each model maps the tier to its own native size via `image_quality_sizes` in the registry: - ByteDance Seedream / Seedance and Gemini image models: 1K / 2K / 4K - Z.ai GLM-Image: 1024 / 1280 / 2048 px Models with no resolution control (xAI Grok, Gemini 2.5 Flash Image) ignore the option. The `medium` tier mirrors each model's previous default, so omitting `quality` is unchanged. The field is folded into the signed request hash, validated up front (clean 400 on bad values), and documented in the OpenAPI spec. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_01ENGmJDbJJHAM89HSGQuoeF --- CLAUDE.md | 7 +++ tee_gateway/controllers/chat_controller.py | 25 ++++++++-- tee_gateway/llm_backend.py | 44 +++++++++++++--- tee_gateway/model_registry.py | 50 +++++++++++++++++++ .../models/create_chat_completion_request.py | 3 ++ tee_gateway/openapi/openapi.yaml | 16 ++++++ tee_gateway/test/test_image_generation.py | 49 ++++++++++++++++++ 7 files changed, 184 insertions(+), 10 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 15e4b1c..b7a2dd6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -128,6 +128,13 @@ Gemini's inline-image models (images returned out-of-band under the message `images` key). These models are billed a flat per-image price (see `per_image_price_usd` in `model_registry.py`), not per token. +Image output resolution is selected with a single `quality` request field — +`low` / `medium` / `high` — which each model maps to its own native size via +`image_quality_sizes` in `model_registry.py` (Seedream/Seedance/Gemini image: +`1K`/`2K`/`4K`; Z.ai GLM-Image: `1024`/`1280`/`2048` px). Models with no +resolution control (xAI Grok, Gemini 2.5 Flash Image) ignore it; `medium` +mirrors each model's previous default, so omitting `quality` is unchanged. + ## Verification Examples - `examples/verify_attestation.py` — Validates AWS Nitro attestation documents against the root CA diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index 1d61621..ca8d1e8 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -36,7 +36,7 @@ AttachmentValidationError, canonical_user_content, ) -from tee_gateway.model_registry import get_model_config +from tee_gateway.model_registry import get_model_config, VALID_IMAGE_QUALITIES from tee_gateway.pricing import compute_session_cost logger = logging.getLogger(__name__) @@ -112,6 +112,17 @@ def create_chat_completion(body): except AttachmentValidationError as e: return {"error": "Invalid attachment", "message": str(e)}, 400 + # Validate the optional image-quality tier up front for a clean 400. Models + # without resolution control silently ignore it (handled downstream). + if chat_request.quality is not None and ( + not isinstance(chat_request.quality, str) + or chat_request.quality.strip().lower() not in VALID_IMAGE_QUALITIES + ): + return { + "error": "Invalid quality", + "message": (f"quality must be one of {', '.join(VALID_IMAGE_QUALITIES)}."), + }, 400 + if chat_request.stream: return _create_streaming_response(chat_request) else: @@ -241,7 +252,7 @@ def _create_image_generation_response( langchain_messages = convert_messages(chat_request.messages) prompt = _extract_image_prompt(langchain_messages) images, image_count = generate_images( - chat_request.model, prompt, n=chat_request.n or 1 + chat_request.model, prompt, n=chat_request.n or 1, quality=chat_request.quality ) message_dict: dict[str, Any] = {"role": "assistant", "content": ""} @@ -290,7 +301,10 @@ def generate(): langchain_messages = convert_messages(chat_request.messages) prompt = _extract_image_prompt(langchain_messages) images, image_count = generate_images( - chat_request.model, prompt, n=chat_request.n or 1 + chat_request.model, + prompt, + n=chat_request.n or 1, + quality=chat_request.quality, ) timestamp = int(time.time()) @@ -361,6 +375,7 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): else 0.0, max_tokens=chat_request.max_tokens or 4096, web_search=bool(chat_request.web_search), + image_quality=chat_request.quality, ) # Bind user tools and/or the native web search tool if requested. @@ -520,6 +535,7 @@ def _create_streaming_response(chat_request: CreateChatCompletionRequest): else 0.0, max_tokens=chat_request.max_tokens or 4096, web_search=bool(chat_request.web_search), + image_quality=chat_request.quality, ) # Bind user tools and/or the native web search tool if requested. @@ -1002,6 +1018,8 @@ def _chat_request_to_dict(chat_request: CreateChatCompletionRequest) -> dict: d["response_format"] = _normalize_response_format(chat_request.response_format) if chat_request.web_search: d["web_search"] = True + if chat_request.quality: + d["quality"] = chat_request.quality return d @@ -1025,6 +1043,7 @@ def _parse_chat_request(chat_request_dict: dict) -> CreateChatCompletionRequest: tool_choice=chat_request_dict.get("tool_choice"), user=chat_request_dict.get("user"), web_search=chat_request_dict.get("web_search", False), + quality=chat_request_dict.get("quality"), ) diff --git a/tee_gateway/llm_backend.py b/tee_gateway/llm_backend.py index 3a34e84..d3e0644 100644 --- a/tee_gateway/llm_backend.py +++ b/tee_gateway/llm_backend.py @@ -26,7 +26,7 @@ from langchain_xai import ChatXAI from tee_gateway.config import ProviderConfig -from tee_gateway.model_registry import get_model_config +from tee_gateway.model_registry import get_model_config, resolve_image_size logger = logging.getLogger(__name__) @@ -150,17 +150,26 @@ def get_provider_from_model(model: str) -> str: @lru_cache(maxsize=64) def get_chat_model_cached( - model: str, temperature: float, max_tokens: int, web_search: bool = False + model: str, + temperature: float, + max_tokens: int, + web_search: bool = False, + image_quality: Optional[str] = None, ): """Get cached chat model instance using the injected ProviderConfig. - Models are cached by (model, temperature, max_tokens, web_search) tuple. - Cache is cleared by set_provider_config() after key injection. + Models are cached by (model, temperature, max_tokens, web_search, + image_quality) tuple. Cache is cleared by set_provider_config() after key + injection. When ``web_search`` is True, provider-specific native web search is enabled. Some providers (OpenAI, xAI) require search configuration at construction time; others (Anthropic, Google) enable it by binding a tool — see ``get_web_search_tool``. Providers without native web search ignore the flag. + + ``image_quality`` ("low" | "medium" | "high") selects the output resolution + for inline image-output models (Gemini "nano banana"); it is ignored by + text models and by image models without resolution control. """ config = _provider_config if config is None: @@ -183,12 +192,18 @@ def get_chat_model_cached( # thinking budget; ask for both TEXT and IMAGE modalities so the model # may caption alongside the generated image. if cfg.image_output: + # Map the requested quality tier to the model's resolution. Models + # without resolution control (e.g. Gemini 2.5 Flash Image) return + # None and use the provider default. + image_size = resolve_image_size(model, image_quality) + image_config = {"image_size": image_size} if image_size else None return ChatGoogleGenerativeAI( model=api_name, google_api_key=config.google_api_key, temperature=effective_temp, max_output_tokens=max_tokens, response_modalities=[Modality.TEXT, Modality.IMAGE], + image_config=image_config, ) return ChatGoogleGenerativeAI( @@ -333,7 +348,9 @@ def get_chat_model_cached( _IMAGE_GENERATION_PATH = "/images/generations" -def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int]: +def generate_images( + model: str, prompt: str, n: int = 1, quality: Optional[str] = None +) -> tuple[list[str], int]: """Generate images via a provider's OpenAI-compatible images endpoint. Unlike Gemini's inline-image chat models, xAI (Aurora), ByteDance @@ -342,6 +359,10 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int bytes ride inline inside the OHTTP/TEE envelope for providers that support it. Z.ai returns temporary image URLs only. + ``quality`` ("low" | "medium" | "high") selects the output resolution where + the model supports it (Seedream/Seedance: 1K/2K/4K, Z.ai: pixel dimensions). + xAI Grok exposes no resolution control and ignores the option. + Returns ``(data_uris, image_count)`` where each entry is a ``data:`` URI. The count is used for per-image billing. Falls back to provider-returned URLs if a provider ignores ``b64_json``. @@ -349,6 +370,11 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int cfg = get_model_config(model) provider = cfg.provider + # Map the requested quality tier to a provider-specific size string. None + # means "no quality requested" or "model has no resolution control"; in both + # cases the provider default below is left in place. + size_override = resolve_image_size(model, quality) + if provider == "x-ai": client = xai_http_client elif provider == "bytedance": @@ -372,17 +398,21 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int "prompt": prompt, } if provider == "zai": - payload["size"] = "1280x1280" + payload["size"] = size_override or "1280x1280" elif provider == "bytedance" and cfg.api_name.startswith("ep-"): # Seedance deployment endpoints use URL format and require extra params. payload["response_format"] = "url" payload["sequential_image_generation"] = "disabled" payload["watermark"] = False - payload["size"] = "2K" + payload["size"] = size_override or "2K" payload["stream"] = False else: payload["n"] = count payload["response_format"] = "b64_json" + # Seedream accepts a size preset; xAI Grok has no resolution control, so + # size_override is None there and the provider default is used. + if size_override: + payload["size"] = size_override logger.info( "Generating %d image(s) - Provider: %s, Model: %s", diff --git a/tee_gateway/model_registry.py b/tee_gateway/model_registry.py index e2098d7..f45f246 100644 --- a/tee_gateway/model_registry.py +++ b/tee_gateway/model_registry.py @@ -47,6 +47,15 @@ class ModelConfig: # means "use the provider default" (see WEB_SEARCH_PRICE_USD_BY_PROVIDER); # set an explicit value here to override a single model's web-search price. web_search_price_usd: Optional[Decimal] = None + # For image models (``image_generation`` endpoint or ``image_output`` inline), + # maps the simple quality tier — ``"low"`` | ``"medium"`` | ``"high"`` — to the + # provider-specific size value sent to the API (e.g. "1K"/"2K"/"4K" for + # ByteDance/Gemini, "1024x1024" pixel dims for Z.ai). ``None`` means the model + # exposes no resolution control and the ``quality`` request option is ignored + # (e.g. xAI Grok and Gemini 2.5 Flash Image, which only emit one resolution). + # The ``"medium"`` tier mirrors each model's previous default size, so omitting + # ``quality`` and passing ``"medium"`` produce the same request. + image_quality_sizes: Optional[dict[str, str]] = None # Default per-search USD price charged when a model uses native web search. @@ -266,6 +275,8 @@ class SupportedModel(Enum): output_price_usd=Decimal("0.000003"), image_output=True, image_output_price_usd=Decimal("0.00006"), + # Nano Banana 2 supports 1K/2K/4K via image_config.image_size (default 1K). + image_quality_sizes={"low": "1K", "medium": "2K", "high": "4K"}, ) GEMINI_3_5_FLASH = ModelConfig( provider="google", @@ -375,6 +386,9 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.03"), + # ModelArk accepts 1K/2K/4K size presets (one dimension fixed to + # 1024/2048/4096); pixels stay within the documented 1280x720..4096x4096. + image_quality_sizes={"low": "1K", "medium": "2K", "high": "4K"}, ) # Seedance 4.5 image generation via a ModelArk deployment endpoint. # Uses URL response format and seedance-specific request params @@ -386,6 +400,9 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.05"), + # Seedance accepts the same 1K/2K/4K size presets as Seedream; medium (2K) + # matches the endpoint's previous hardcoded default. + image_quality_sizes={"low": "1K", "medium": "2K", "high": "4K"}, ) # ── Nous Research (Nous Portal, OpenAI-compatible) ────────────────── @@ -425,6 +442,13 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.015"), + # Z.ai takes explicit pixel dimensions (each 512..2048, multiple of 32); + # medium (1280x1280) matches the endpoint's previous hardcoded default. + image_quality_sizes={ + "low": "1024x1024", + "medium": "1280x1280", + "high": "2048x2048", + }, ) # ── Legacy models (not in current SDK — retained for older SDK versions) ── @@ -568,3 +592,29 @@ def get_web_search_price_usd(model: str) -> Decimal: def provider_supports_web_search(provider: str) -> bool: """Whether the given provider has native web search the gateway can enable.""" return provider in WEB_SEARCH_PRICE_USD_BY_PROVIDER + + +# The simple quality tiers a caller may request for image generation. Each model +# maps these to its own provider-specific resolution (see `image_quality_sizes`). +VALID_IMAGE_QUALITIES = ("low", "medium", "high") + + +def resolve_image_size(model: str, quality: Optional[str]) -> Optional[str]: + """Map a requested ``quality`` tier to a model's provider-specific size string. + + Returns ``None`` when no quality is requested or the model exposes no + resolution control, in which case the provider default is left in place. + Raises ValueError on an unknown model or an invalid quality value. + """ + if quality is None: + return None + normalized = quality.strip().lower() + if normalized not in VALID_IMAGE_QUALITIES: + raise ValueError( + f"Unsupported quality: {quality!r}. " + f"Must be one of {', '.join(VALID_IMAGE_QUALITIES)}." + ) + cfg = get_model_config(model) + if not cfg.image_quality_sizes: + return None + return cfg.image_quality_sizes.get(normalized) diff --git a/tee_gateway/models/create_chat_completion_request.py b/tee_gateway/models/create_chat_completion_request.py index 3d3300b..179ee3b 100644 --- a/tee_gateway/models/create_chat_completion_request.py +++ b/tee_gateway/models/create_chat_completion_request.py @@ -34,6 +34,7 @@ def __init__( function_call=None, functions=None, web_search=False, + quality=None, ): self.messages = messages self.model = model @@ -66,6 +67,7 @@ def __init__( self.function_call = function_call self.functions = functions self.web_search = web_search + self.quality = quality @classmethod def from_dict(cls, dikt) -> "CreateChatCompletionRequest": @@ -103,5 +105,6 @@ def from_dict(cls, dikt) -> "CreateChatCompletionRequest": "function_call", "functions", "web_search", + "quality", } return cls(**{k: v for k, v in dikt.items() if k in known}) diff --git a/tee_gateway/openapi/openapi.yaml b/tee_gateway/openapi/openapi.yaml index 80cb352..1a7e05b 100644 --- a/tee_gateway/openapi/openapi.yaml +++ b/tee_gateway/openapi/openapi.yaml @@ -2895,6 +2895,22 @@ components: nullable: true title: web_search type: boolean + quality: + description: | + Output resolution for image-generation models, as a simple tier. + Each model maps the tier to its own native resolution: ByteDance + Seedream/Seedance and Gemini image models use `1K`/`2K`/`4K`; Z.ai + GLM-Image uses pixel dimensions (1024/1280/2048 px). Models with no + resolution control (xAI Grok, Gemini 2.5 Flash Image) ignore this + option. Has no effect on text models. When omitted, each model's + default resolution (equivalent to `medium`) is used. + enum: + - low + - medium + - high + nullable: true + title: quality + type: string store: default: false description: "Whether or not to store the output of this chat completion\ diff --git a/tee_gateway/test/test_image_generation.py b/tee_gateway/test/test_image_generation.py index 2dc5353..518fd28 100644 --- a/tee_gateway/test/test_image_generation.py +++ b/tee_gateway/test/test_image_generation.py @@ -126,6 +126,55 @@ def test_uninitialized_client_raises(self): with self.assertRaises(RuntimeError): generate_images(GROK_IMAGE, "p", n=1) + def test_quality_maps_to_seedream_size_preset(self): + client = MagicMock() + client.post.return_value = _mock_response([{"b64_json": "x"}]) + for quality, expected in (("low", "1K"), ("medium", "2K"), ("high", "4K")): + with self.subTest(quality=quality): + with patch.object(llm_backend, "bytedance_http_client", client): + generate_images(SEEDREAM, "p", n=1, quality=quality) + self.assertEqual(client.post.call_args.kwargs["json"]["size"], expected) + + def test_quality_maps_to_zai_pixel_dimensions(self): + client = MagicMock() + client.post.return_value = _mock_response([{"url": "https://z.ai/i.png"}]) + with patch.object(llm_backend, "zai_http_client", client): + generate_images(GLM_IMAGE, "p", n=1, quality="high") + self.assertEqual(client.post.call_args.kwargs["json"]["size"], "2048x2048") + + def test_quality_overrides_seedance_default_size(self): + client = MagicMock() + client.post.return_value = _mock_response([{"url": "https://cdn/i.jpg"}]) + with patch.object(llm_backend, "bytedance_http_client", client): + generate_images(SEEDANCE, "p", n=1, quality="high") + self.assertEqual(client.post.call_args.kwargs["json"]["size"], "4K") + + def test_no_quality_keeps_provider_defaults(self): + client = MagicMock() + client.post.return_value = _mock_response([{"b64_json": "x"}]) + # Seedream omits size entirely when no quality is requested. + with patch.object(llm_backend, "bytedance_http_client", client): + generate_images(SEEDREAM, "p", n=1) + self.assertNotIn("size", client.post.call_args.kwargs["json"]) + # Z.ai falls back to its documented default. + with patch.object(llm_backend, "zai_http_client", client): + generate_images(GLM_IMAGE, "p", n=1) + self.assertEqual(client.post.call_args.kwargs["json"]["size"], "1280x1280") + + def test_quality_ignored_for_grok_without_resolution_control(self): + client = MagicMock() + client.post.return_value = _mock_response([{"b64_json": "x"}]) + with patch.object(llm_backend, "xai_http_client", client): + generate_images(GROK_IMAGE, "p", n=1, quality="high") + self.assertNotIn("size", client.post.call_args.kwargs["json"]) + + def test_invalid_quality_raises(self): + client = MagicMock() + client.post.return_value = _mock_response([{"b64_json": "x"}]) + with patch.object(llm_backend, "bytedance_http_client", client): + with self.assertRaises(ValueError): + generate_images(SEEDREAM, "p", n=1, quality="ultra") + class TestPerImageBilling(unittest.TestCase): """Flat per-image pricing, independent of token usage."""