diff --git a/CLAUDE.md b/CLAUDE.md index 15e4b1c..a2662a7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -10,6 +10,7 @@ The repo must provide a stable AWS Nitro PCR when the code doesn't change in ord ├── tee_gateway/ # Main application package (Flask/connexion) │ ├── __main__.py # Entry point: app factory, x402 middleware setup, key injection │ ├── llm_backend.py # LLM provider routing via LangChain, HTTP client management +│ ├── image_generation.py # Endpoint-based image gen (/images/generations): request shaping, URL→inline-bytes, signed responses │ ├── tee_manager.py # TEE key generation, nitriding registration, response signing │ ├── model_registry.py # Model config and per-token pricing │ ├── definitions.py # On-chain addresses, network IDs, payment amounts @@ -123,10 +124,16 @@ Model name prefixes determine routing: Image generation via xAI (grok-2-image), ByteDance (seedream-4.0, seedance-4.5), and Z.ai (glm-image) is served through a provider `/images/generations` endpoint rather -than the chat path, but is surfaced on `/v1/chat/completions` exactly like -Gemini's inline-image models (images returned out-of-band under the message -`images` key). These models are billed a flat per-image price (see -`per_image_price_usd` in `model_registry.py`), not per token. +than the chat path (see `image_generation.py`), but is surfaced on +`/v1/chat/completions` exactly like Gemini's inline-image models (images returned +out-of-band under the message `images` key). The client always receives inline +bytes: providers that hand back a hosted URL (Z.ai, Seedance) are fetched inside +the enclave and inlined as `data:` URIs. Image-to-image editing sends the prior +image back inline (a `data:` URI / `image_url` content part on the user turn), +forwarded to providers that support it via the endpoint's `image` field. +Per-provider request quirks (response format, `n`, size/watermark, reference +support) live in `model_registry.py`. These models are billed a flat per-image +price (see `per_image_price_usd`), not per token. ## Verification Examples diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index 7494abf..c48d7a6 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -31,11 +31,14 @@ extract_web_search_count, convert_messages, extract_usage, - generate_images, validate_attachments, AttachmentValidationError, canonical_user_content, ) +from tee_gateway.image_generation import ( + create_image_generation_response, + create_image_generation_streaming_response, +) from tee_gateway.model_registry import get_model_config from tee_gateway.pricing import compute_session_cost @@ -77,77 +80,6 @@ def _split_text_and_images(content: Any) -> tuple[str, list[str]]: return ("".join(text_parts), images) -def _extract_image_prompt(langchain_messages: list) -> str: - """Collapse the user-turn text into a single image-generation prompt. - - Image-generation models (xAI Grok, ByteDance Seedream) take a single text - prompt rather than a chat transcript, so we join the text of all human - messages. System/assistant/tool turns are ignored. - - A user turn carrying an attached reference image has list (multimodal) - content; we pull out only its ``text`` parts. (Naively stringifying the list - would splice the base64 image blob into the prompt — the reference image is - forwarded separately via ``_extract_reference_images``.) - """ - from langchain_core.messages import HumanMessage - - parts: list[str] = [] - for m in langchain_messages: - if not isinstance(m, HumanMessage): - continue - content = m.content - if isinstance(content, str): - if content: - parts.append(content) - elif isinstance(content, list): - for part in content: - if isinstance(part, dict): - if part.get("type") == "text" and part.get("text"): - parts.append(part["text"]) - elif isinstance(part, str) and part: - parts.append(part) - return "\n".join(parts) - - -def _extract_reference_images(langchain_messages: list) -> list[str]: - """Collect reference-image URLs/data-URIs from the user turns. - - Endpoint-based image models (Seedream/Seedance) support image-to-image - editing: the client attaches the prior generated image (or an uploaded one) - to the latest user turn as an ``image_url`` content part. We pull those out - so they can be forwarded to the provider as reference images — without them a - follow-up like "add a hat" ignores the previous image and generates from the - prompt text alone. In practice only the latest user turn carries images, so - collecting across turns just yields the active references. - """ - from langchain_core.messages import HumanMessage - - images: list[str] = [] - for m in langchain_messages: - if not isinstance(m, HumanMessage): - continue - content = m.content - if not isinstance(content, list): - continue - for part in content: - if not isinstance(part, dict): - continue - ptype = part.get("type") - if ptype == "image_url": - image_url = part.get("image_url") - url = image_url.get("url") if isinstance(image_url, dict) else image_url - if url: - images.append(url) - elif ptype == "image": - # Standard LangChain image block: inline base64 + mime, or a url. - if part.get("base64"): - mime = part.get("mime_type") or "image/png" - images.append(f"data:{mime};base64,{part['base64']}") - elif part.get("url"): - images.append(part["url"]) - return images - - def create_chat_completion(body): """Create a chat completion (streaming or non-streaming).""" if not connexion.request.is_json: @@ -282,119 +214,6 @@ def _messages_contain_json_word(messages: list) -> bool: return False -def _create_image_generation_response( - chat_request: CreateChatCompletionRequest, request_bytes: bytes -): - """Non-streaming image generation via a provider's images endpoint. - - Surfaces generated images on the message under the ``images`` key exactly - like Gemini's inline-image models. There is no text to sign, so the signature - covers the request hash and an empty output; the image bytes ride out-of-band - inside the OHTTP envelope. Billing is flat per generated image. - """ - langchain_messages = convert_messages(chat_request.messages) - prompt = _extract_image_prompt(langchain_messages) - reference_images = _extract_reference_images(langchain_messages) - images, image_count = generate_images( - chat_request.model, - prompt, - n=chat_request.n or 1, - reference_images=reference_images, - ) - - message_dict: dict[str, Any] = {"role": "assistant", "content": ""} - if images: - message_dict["images"] = images - - timestamp = int(time.time()) - msg_hash, input_hash_hex, output_hash_hex = compute_tee_msg_hash( - request_bytes, "", timestamp - ) - tee_keys = get_tee_keys() - signature = tee_keys.sign_data(msg_hash) - - openai_response: dict[str, Any] = { - "id": f"chatcmpl-{uuid.uuid4()}", - "object": "chat.completion", - "created": timestamp, - "model": chat_request.model, - "choices": [{"index": 0, "message": message_dict, "finish_reason": "stop"}], - "tee_signature": signature, - "tee_request_hash": input_hash_hex, - "tee_output_hash": output_hash_hex, - "tee_timestamp": timestamp, - "tee_id": f"0x{tee_keys.get_tee_id()}", - } - - usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} - openai_response["usage"] = usage - cost = compute_session_cost(chat_request.model, usage, image_count=image_count) - if cost is not None: - openai_response["opengradient"] = cost.model_dump(mode="json") - - CreateChatCompletionResponse.from_dict(openai_response) - return openai_response - - -def _create_image_generation_streaming_response( - chat_request: CreateChatCompletionRequest, request_bytes: bytes -): - """Streaming image generation: image gen is not a token stream, so we invoke - once and emit the result on the final SSE frame (mirrors the Gemini path).""" - tee_keys = get_tee_keys() - - def generate(): - try: - langchain_messages = convert_messages(chat_request.messages) - prompt = _extract_image_prompt(langchain_messages) - reference_images = _extract_reference_images(langchain_messages) - images, image_count = generate_images( - chat_request.model, - prompt, - n=chat_request.n or 1, - reference_images=reference_images, - ) - - timestamp = int(time.time()) - msg_hash, input_hash_hex, output_hash_hex = compute_tee_msg_hash( - request_bytes, "", timestamp - ) - tee_signature = tee_keys.sign_data(msg_hash) - - final_data: dict[str, Any] = { - "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], - "model": chat_request.model, - "tee_signature": tee_signature, - "tee_timestamp": timestamp, - "tee_request_hash": input_hash_hex, - "tee_output_hash": output_hash_hex, - "tee_id": f"0x{tee_keys.get_tee_id()}", - } - # Images travel out-of-band on the final frame; not part of the hash. - if images: - final_data["images"] = images - - usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} - final_data["usage"] = usage - cost = compute_session_cost( - chat_request.model, usage, image_count=image_count - ) - if cost is not None: - final_data["opengradient"] = cost.model_dump(mode="json") - - yield f"data: {json.dumps(final_data)}\n\n" - yield "data: [DONE]\n\n" - except Exception as e: - logger.error(f"Image generation streaming error: {str(e)}", exc_info=True) - yield f"data: {json.dumps({'error': 'Stream processing failed', 'exception_type': type(e).__name__})}\n\n" - - return Response( - generate(), - mimetype="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, - ) - - def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): """Handle non-streaming chat completion via direct LangChain call.""" try: @@ -414,7 +233,7 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): # surfaced through this same endpoint, returning images out-of-band just # like Gemini's inline-image models. if cfg.image_generation: - return _create_image_generation_response(chat_request, request_bytes) + return create_image_generation_response(chat_request, request_bytes) model = get_chat_model_cached( model=chat_request.model, @@ -571,7 +390,7 @@ def _create_streaming_response(chat_request: CreateChatCompletionRequest): # Image-generation models (xAI Grok, ByteDance Seedream) use a dedicated # images endpoint; handle them without building a chat model. if get_model_config(chat_request.model).image_generation: - return _create_image_generation_streaming_response( + return create_image_generation_streaming_response( chat_request, request_bytes ) diff --git a/tee_gateway/image_generation.py b/tee_gateway/image_generation.py new file mode 100644 index 0000000..1e8b1a3 --- /dev/null +++ b/tee_gateway/image_generation.py @@ -0,0 +1,312 @@ +"""Endpoint-based image generation. + +xAI (Aurora), ByteDance (Seedream/Seedance) and Z.ai (GLM-Image) expose image +generation through a dedicated OpenAI-compatible ``POST /images/generations`` +endpoint rather than the chat path. This module owns everything specific to that +flow: + + * ``generate_images`` — shape and send the provider request, always returning + inline ``data:`` URIs (a provider-hosted URL is fetched into the enclave so + the client never sees a raw URL). + * ``create_image_generation_response`` / + ``create_image_generation_streaming_response`` — surface the result on + ``/v1/chat/completions`` exactly like Gemini's inline-image models (images + ride out-of-band under the message ``images`` key), signed and billed flat + per generated image. + +Per-provider request quirks (response format, ``n`` support, reference-image +editing, extra params) live in the model registry, keeping this code flat. +""" + +import base64 +import json +import logging +import time +import uuid +from typing import Any, List, Optional + +import httpx +from flask import Response + +from tee_gateway import llm_backend +from tee_gateway.models.create_chat_completion_request import ( + CreateChatCompletionRequest, +) +from tee_gateway.models.create_chat_completion_response import ( + CreateChatCompletionResponse, +) +from tee_gateway.model_registry import get_model_config +from tee_gateway.pricing import compute_session_cost +from tee_gateway.tee_manager import get_tee_keys, compute_tee_msg_hash + +logger = logging.getLogger(__name__) + +# Endpoint path appended to each provider's OpenAI-compatible base URL. +_IMAGE_GENERATION_PATH = "/images/generations" + +# Provider -> the shared HTTP client attribute on ``llm_backend`` (built after +# key injection). Looked up by attribute so a patched/rebuilt client is picked up. +_IMAGE_CLIENT_ATTRS = { + "x-ai": "xai_http_client", + "bytedance": "bytedance_http_client", + "zai": "zai_http_client", +} + +# Shared keyless client for fetching provider-hosted image URLs into the enclave. +_image_fetch_client: Optional[httpx.Client] = None + + +def _fetch_url_as_data_uri(url: str) -> str: + """Fetch a hosted image URL and return it as a ``data:`` URI. + + Lets URL-returning image providers be surfaced identically to inline-byte + ones: the bytes are pulled into the enclave and ride out inside the OHTTP/TEE + envelope, so the client never sees a raw provider URL. + """ + global _image_fetch_client + if _image_fetch_client is None: + _image_fetch_client = httpx.Client( + timeout=httpx.Timeout(timeout=120.0, connect=15.0), + follow_redirects=True, + ) + resp = _image_fetch_client.get(url) + resp.raise_for_status() + mime = (resp.headers.get("content-type") or "image/jpeg").split(";", 1)[0].strip() + b64 = base64.b64encode(resp.content).decode("ascii") + return f"data:{mime or 'image/jpeg'};base64,{b64}" + + +def generate_images( + model: str, + prompt: str, + n: int = 1, + reference_images: Optional[List[str]] = None, +) -> tuple[list[str], int]: + """Generate images via a provider's OpenAI-compatible images endpoint. + + ``reference_images`` carries input images for image-to-image editing (e.g. a + follow-up "add a hat" that builds on the previously generated image), sent on + the same endpoint via the ``image`` field (a URL or base64 ``data:`` URI, or + an array of up to 10). Models whose endpoint doesn't support it ignore them. + + Returns ``(data_uris, image_count)``. Every entry is a ``data:`` URI — when a + provider returns a hosted URL instead of inline bytes, the gateway fetches it + so the client always receives bytes. The count drives per-image billing. + """ + cfg = get_model_config(model) + provider = cfg.provider + + client_attr = _IMAGE_CLIENT_ATTRS.get(provider) + if client_attr is None: + raise ValueError( + f"Provider {provider!r} does not support the image-generation endpoint" + ) + client: Optional[httpx.Client] = getattr(llm_backend, client_attr, None) + if client is None: + raise RuntimeError(f"{provider} HTTP client has not been initialized") + + # n is clamped to the OpenAI-compatible providers' documented 1..10 range. + count = max(1, min(int(n), 10)) + payload: dict[str, Any] = {"model": cfg.api_name, "prompt": prompt} + if cfg.image_response_format is not None: + payload["response_format"] = cfg.image_response_format + if cfg.image_send_n: + payload["n"] = count + if cfg.image_extra_params: + payload.update(cfg.image_extra_params) + # Image-to-image editing: forward reference images via the ``image`` field (a + # single string, or an array for multi-reference edits, up to 10). Without + # this a follow-up edit like "add a hat" would ignore the prior image and + # generate a fresh one from the prompt text alone. + if reference_images and cfg.image_supports_reference: + refs = reference_images[:10] + payload["image"] = refs[0] if len(refs) == 1 else refs + + logger.info( + "Generating %d image(s) - Provider: %s, Model: %s", + count, + provider, + cfg.api_name, + ) + resp = client.post(_IMAGE_GENERATION_PATH, json=payload) + resp.raise_for_status() + data = resp.json().get("data", []) or [] + + images: list[str] = [] + for item in data: + if not isinstance(item, dict): + continue + b64 = item.get("b64_json") + if b64: + images.append(f"data:image/jpeg;base64,{b64}") + elif item.get("url"): + images.append(_fetch_url_as_data_uri(item["url"])) + + return images, len(images) + + +def _extract_image_inputs(langchain_messages: list) -> tuple[str, list[str]]: + """Pull the text prompt and any reference images out of the user turns. + + Image-generation models take a single text prompt rather than a chat + transcript, so we join the text of all human messages (system/assistant/tool + turns are ignored). A user turn doing an image-to-image edit carries the prior + image as an ``image_url``/``image`` content part alongside its text; we collect + those reference images separately so they can be forwarded to the provider — + without them a follow-up like "add a hat" would ignore the previous image and + generate from the prompt text alone. + + Returns ``(prompt, reference_images)``. + """ + from langchain_core.messages import HumanMessage + + text_parts: list[str] = [] + images: list[str] = [] + for m in langchain_messages: + if not isinstance(m, HumanMessage): + continue + content = m.content + if isinstance(content, str): + if content: + text_parts.append(content) + continue + if not isinstance(content, list): + continue + for part in content: + if isinstance(part, str): + if part: + text_parts.append(part) + continue + if not isinstance(part, dict): + continue + ptype = part.get("type") + if ptype == "text": + if part.get("text"): + text_parts.append(part["text"]) + elif ptype == "image_url": + image_url = part.get("image_url") + url = image_url.get("url") if isinstance(image_url, dict) else image_url + if url: + images.append(url) + elif ptype == "image": + # Standard LangChain image block: inline base64 + mime, or a url. + if part.get("base64"): + mime = part.get("mime_type") or "image/png" + images.append(f"data:{mime};base64,{part['base64']}") + elif part.get("url"): + images.append(part["url"]) + return "\n".join(text_parts), images + + +def _run_image_generation( + chat_request: CreateChatCompletionRequest, request_bytes: bytes +) -> dict[str, Any]: + """Shared core for endpoint-based image generation. + + Generates the images (always inline bytes — the gateway fetches any + provider-hosted URL), signs the response, and computes billing. Image + generation has no text to sign, so the signature covers the request hash and + an empty output; the image bytes ride out-of-band inside the OHTTP envelope. + Billing is flat per generated image. Returns the pieces the streaming and + non-streaming responders assemble into their respective envelopes. + """ + langchain_messages = llm_backend.convert_messages(chat_request.messages) + prompt, reference_images = _extract_image_inputs(langchain_messages) + images, image_count = generate_images( + chat_request.model, + prompt, + n=chat_request.n or 1, + reference_images=reference_images, + ) + + timestamp = int(time.time()) + msg_hash, input_hash_hex, output_hash_hex = compute_tee_msg_hash( + request_bytes, "", timestamp + ) + tee_keys = get_tee_keys() + usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + cost = compute_session_cost(chat_request.model, usage, image_count=image_count) + + return { + "images": images, + "usage": usage, + "opengradient": cost.model_dump(mode="json") if cost is not None else None, + "tee_signature": tee_keys.sign_data(msg_hash), + "tee_request_hash": input_hash_hex, + "tee_output_hash": output_hash_hex, + "tee_timestamp": timestamp, + "tee_id": f"0x{tee_keys.get_tee_id()}", + } + + +def create_image_generation_response( + chat_request: CreateChatCompletionRequest, request_bytes: bytes +): + """Non-streaming image generation via a provider's images endpoint. + + Surfaces generated images on the message under the ``images`` key exactly + like Gemini's inline-image models. + """ + result = _run_image_generation(chat_request, request_bytes) + + message_dict: dict[str, Any] = {"role": "assistant", "content": ""} + if result["images"]: + message_dict["images"] = result["images"] + + openai_response: dict[str, Any] = { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": result["tee_timestamp"], + "model": chat_request.model, + "choices": [{"index": 0, "message": message_dict, "finish_reason": "stop"}], + "tee_signature": result["tee_signature"], + "tee_request_hash": result["tee_request_hash"], + "tee_output_hash": result["tee_output_hash"], + "tee_timestamp": result["tee_timestamp"], + "tee_id": result["tee_id"], + "usage": result["usage"], + } + if result["opengradient"] is not None: + openai_response["opengradient"] = result["opengradient"] + + CreateChatCompletionResponse.from_dict(openai_response) + return openai_response + + +def create_image_generation_streaming_response( + chat_request: CreateChatCompletionRequest, request_bytes: bytes +): + """Streaming image generation: image gen is not a token stream, so we invoke + once and emit the result on the final SSE frame (mirrors the Gemini path).""" + + def generate(): + try: + result = _run_image_generation(chat_request, request_bytes) + + final_data: dict[str, Any] = { + "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], + "model": chat_request.model, + "tee_signature": result["tee_signature"], + "tee_timestamp": result["tee_timestamp"], + "tee_request_hash": result["tee_request_hash"], + "tee_output_hash": result["tee_output_hash"], + "tee_id": result["tee_id"], + "usage": result["usage"], + } + # Images travel out-of-band on the final frame; not part of the hash. + if result["images"]: + final_data["images"] = result["images"] + if result["opengradient"] is not None: + final_data["opengradient"] = result["opengradient"] + + yield f"data: {json.dumps(final_data)}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + logger.error(f"Image generation streaming error: {str(e)}", exc_info=True) + yield f"data: {json.dumps({'error': 'Stream processing failed', 'exception_type': type(e).__name__})}\n\n" + + return Response( + generate(), + mimetype="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) diff --git a/tee_gateway/llm_backend.py b/tee_gateway/llm_backend.py index ccbce83..8092a21 100644 --- a/tee_gateway/llm_backend.py +++ b/tee_gateway/llm_backend.py @@ -329,107 +329,6 @@ def get_chat_model_cached( raise ValueError(f"Unsupported provider: {provider}") -# Endpoint path appended to each provider's OpenAI-compatible base URL. -_IMAGE_GENERATION_PATH = "/images/generations" - - -def generate_images( - model: str, - prompt: str, - n: int = 1, - reference_images: Optional[List[str]] = None, -) -> tuple[list[str], int]: - """Generate images via a provider's OpenAI-compatible images endpoint. - - Unlike Gemini's inline-image chat models, xAI (Aurora), ByteDance - (Seedream), and Z.ai (GLM-Image) expose image generation through a dedicated - ``POST /images/generations`` endpoint. We request ``b64_json`` so the image - bytes ride inline inside the OHTTP/TEE envelope for providers that support - it. Z.ai returns temporary image URLs only. - - ``reference_images`` carries input images for image-to-image editing (e.g. a - follow-up "add a hat" that builds on the previously generated image). On - ByteDance Seedream/Seedance these ride on the same endpoint via the ``image`` - field, which accepts a URL or a base64 ``data:`` URI (or an array of them, up - to 10). Providers without image-edit support ignore the references. - - Returns ``(data_uris, image_count)`` where each entry is a ``data:`` URI. The - count is used for per-image billing. Falls back to provider-returned URLs if - a provider ignores ``b64_json``. - """ - cfg = get_model_config(model) - provider = cfg.provider - - if provider == "x-ai": - client = xai_http_client - elif provider == "bytedance": - client = bytedance_http_client - elif provider == "zai": - client = zai_http_client - else: - raise ValueError( - f"Provider {provider!r} does not support the image-generation endpoint" - ) - - if client is None: - raise RuntimeError(f"{provider} HTTP client has not been initialized") - - # n is clamped to the OpenAI-compatible providers' documented 1..10 range. - # Z.ai's GLM-Image and ByteDance Seedance endpoints don't document n/ - # response_format support, so keep their payloads to documented fields. - count = max(1, min(int(n), 10)) - payload: dict[str, Any] = { - "model": cfg.api_name, - "prompt": prompt, - } - if provider == "zai": - payload["size"] = "1280x1280" - elif provider == "bytedance" and cfg.api_name.startswith("ep-"): - # Seedance deployment endpoints use URL format and require extra params. - payload["response_format"] = "url" - payload["sequential_image_generation"] = "disabled" - payload["watermark"] = False - payload["size"] = "2K" - payload["stream"] = False - else: - payload["n"] = count - payload["response_format"] = "b64_json" - - # Image-to-image editing: ByteDance Seedream/Seedance accept reference images - # on the same endpoint via the ``image`` field (URL or base64 data URI; an - # array for multi-reference edits, up to 10). Without this, a follow-up edit - # like "add a hat" would silently ignore the prior image and generate a fresh - # one from the prompt text alone. Only ByteDance is known to support this; - # other providers' text-to-image endpoints reject unknown fields, so gate it. - if reference_images and provider == "bytedance": - refs = reference_images[:10] - payload["image"] = refs[0] if len(refs) == 1 else refs - - logger.info( - "Generating %d image(s) - Provider: %s, Model: %s", - count, - provider, - cfg.api_name, - ) - resp = client.post(_IMAGE_GENERATION_PATH, json=payload) - resp.raise_for_status() - data = resp.json().get("data", []) or [] - - images: list[str] = [] - for item in data: - if not isinstance(item, dict): - continue - b64 = item.get("b64_json") - if b64: - images.append(f"data:image/jpeg;base64,{b64}") - continue - url = item.get("url") - if url: - images.append(url) - - return images, len(images) - - def _parse_data_uri(uri: str) -> Optional[tuple[str, str]]: """Parse a ``data:;base64,`` URI into ``(mime_type, base64_data)``. diff --git a/tee_gateway/model_registry.py b/tee_gateway/model_registry.py index e2098d7..132db19 100644 --- a/tee_gateway/model_registry.py +++ b/tee_gateway/model_registry.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from decimal import Decimal from enum import Enum, unique -from typing import Optional +from typing import Any, Mapping, Optional @dataclass(frozen=True) @@ -35,6 +35,21 @@ class ModelConfig: # Flat USD price per generated image, for ``image_generation`` models. Token # prices are ignored for these models (set to 0 in the registry). per_image_price_usd: Optional[Decimal] = None + # ── /images/generations request shaping (image_generation models only) ── + # The ``response_format`` to request. ``"b64_json"`` returns inline bytes; + # ``"url"`` returns a hosted link the gateway fetches and inlines (so the + # client always receives bytes). ``None`` omits the field for endpoints that + # don't document it (Z.ai GLM-Image). + image_response_format: Optional[str] = "b64_json" + # Whether to send the OpenAI-style ``n`` count. Some endpoints (Z.ai + # GLM-Image, ByteDance Seedance) don't document it and reject/ignore it. + image_send_n: bool = True + # Whether the endpoint accepts reference images for image-to-image editing, + # sent via the ``image`` field. Text-to-image-only endpoints reject it. + image_supports_reference: bool = False + # Static extra params merged verbatim into the request payload (e.g. size, + # watermark). Keyed by field name; values must be JSON-serializable. + image_extra_params: Optional[Mapping[str, Any]] = None # USD per image-modality output token, for ``image_output`` models (Gemini # "nano banana"). These providers bill image output at a higher rate than # text/thinking output: image tokens at this rate, text + thinking tokens at @@ -375,10 +390,12 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.03"), + image_supports_reference=True, ) # Seedance 4.5 image generation via a ModelArk deployment endpoint. - # Uses URL response format and seedance-specific request params - # (sequential_image_generation, watermark, size). Billed per image. + # Returns hosted URLs (fetched and inlined by the gateway) and needs + # seedance-specific request params (sequential_image_generation, watermark, + # size). Billed per image. SEEDANCE_4_5 = ModelConfig( provider="bytedance", api_name="ep-20260624042612-7dxcv", @@ -386,6 +403,15 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.05"), + image_response_format="url", + image_send_n=False, + image_supports_reference=True, + image_extra_params={ + "sequential_image_generation": "disabled", + "watermark": False, + "size": "2K", + "stream": False, + }, ) # ── Nous Research (Nous Portal, OpenAI-compatible) ────────────────── @@ -418,6 +444,8 @@ class SupportedModel(Enum): output_price_usd=Decimal("0.0000044"), ) # GLM-Image uses Z.ai's image endpoint and is billed per generated image. + # Z.ai returns hosted URLs only (fetched and inlined by the gateway) and + # documents neither ``n`` nor ``response_format``, so both are omitted. GLM_IMAGE = ModelConfig( provider="zai", api_name="glm-image", @@ -425,6 +453,9 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.015"), + image_response_format=None, + image_send_n=False, + image_extra_params={"size": "1280x1280"}, ) # ── Legacy models (not in current SDK — retained for older SDK versions) ── diff --git a/tee_gateway/test/test_image_generation.py b/tee_gateway/test/test_image_generation.py index 48c7bfe..140a909 100644 --- a/tee_gateway/test/test_image_generation.py +++ b/tee_gateway/test/test_image_generation.py @@ -6,19 +6,19 @@ endpoint and billed a flat price per generated image. These tests pin: 1. The request/response handling in ``generate_images`` (b64_json -> data URI, - n clamping, url fallback, provider-specific payloads). + n clamping, hosted-URL fetch -> data URI, provider-specific payloads). 2. The flat per-image billing in ``compute_session_cost``. -No network or API key required — the provider HTTP client is mocked and a stub -price feed is injected. +No network or API key required — the provider HTTP client is mocked, the URL +fetch is patched, and a stub price feed is injected. """ import unittest from decimal import Decimal from unittest.mock import MagicMock, patch -from tee_gateway import llm_backend -from tee_gateway.llm_backend import generate_images +from tee_gateway import image_generation, llm_backend +from tee_gateway.image_generation import generate_images from tee_gateway.model_registry import get_model_config from tee_gateway.price_feed import get_price_feed, set_price_feed from tee_gateway.pricing import compute_session_cost @@ -64,23 +64,40 @@ def test_b64_json_becomes_data_uri(self): self.assertEqual(payload["n"], 2) self.assertEqual(payload["response_format"], "b64_json") - def test_url_fallback_when_no_b64(self): + def test_hosted_url_is_fetched_and_inlined(self): + # Providers that return a hosted URL instead of inline bytes are fetched + # into the enclave so the client always receives a data: URI. client = MagicMock() client.post.return_value = _mock_response([{"url": "https://img/1.jpg"}]) - with patch.object(llm_backend, "bytedance_http_client", client): + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ) as fetch, + ): images, count = generate_images(SEEDREAM, "a blue sphere", n=1) self.assertEqual(count, 1) - self.assertEqual(images, ["https://img/1.jpg"]) + self.assertEqual(images, ["data:image/jpeg;base64,RkVUQ0hFRA=="]) + fetch.assert_called_once_with("https://img/1.jpg") - def test_zai_glm_image_uses_documented_payload_and_url_response(self): + def test_zai_glm_image_uses_documented_payload_and_fetches_url(self): client = MagicMock() client.post.return_value = _mock_response([{"url": "https://z.ai/img.png"}]) - with patch.object(llm_backend, "zai_http_client", client): + with ( + patch.object(llm_backend, "zai_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/png;base64,RkVUQ0hFRA==", + ), + ): images, count = generate_images(GLM_IMAGE, "a poster", n=3) self.assertEqual(count, 1) - self.assertEqual(images, ["https://z.ai/img.png"]) + self.assertEqual(images, ["data:image/png;base64,RkVUQ0hFRA=="]) _, kwargs = client.post.call_args payload = kwargs["json"] @@ -93,11 +110,18 @@ def test_zai_glm_image_uses_documented_payload_and_url_response(self): def test_seedance_uses_url_format_and_extra_params(self): client = MagicMock() client.post.return_value = _mock_response([{"url": "https://cdn/img.jpg"}]) - with patch.object(llm_backend, "bytedance_http_client", client): + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ), + ): images, count = generate_images(SEEDANCE, "a black hole", n=1) self.assertEqual(count, 1) - self.assertEqual(images, ["https://cdn/img.jpg"]) + self.assertEqual(images, ["data:image/jpeg;base64,RkVUQ0hFRA=="]) _, kwargs = client.post.call_args payload = kwargs["json"] @@ -113,7 +137,14 @@ def test_seedance_uses_url_format_and_extra_params(self): def test_seedance_forwards_single_reference_image(self): client = MagicMock() client.post.return_value = _mock_response([{"url": "https://cdn/edited.jpg"}]) - with patch.object(llm_backend, "bytedance_http_client", client): + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ), + ): generate_images( SEEDANCE, "add a hat", @@ -174,6 +205,38 @@ def test_uninitialized_client_raises(self): generate_images(GROK_IMAGE, "p", n=1) +class TestFetchUrlAsDataUri(unittest.TestCase): + """The hosted-URL fetch that inlines provider images into the enclave.""" + + def test_encodes_bytes_with_content_type(self): + fetch_client = MagicMock() + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.headers = {"content-type": "image/png; charset=binary"} + resp.content = b"hello" + fetch_client.get.return_value = resp + + with patch.object(image_generation, "_image_fetch_client", fetch_client): + uri = image_generation._fetch_url_as_data_uri("https://cdn/x.png") + + # base64("hello") == "aGVsbG8=", mime taken from content-type (params dropped) + self.assertEqual(uri, "data:image/png;base64,aGVsbG8=") + fetch_client.get.assert_called_once_with("https://cdn/x.png") + + def test_defaults_mime_when_header_missing(self): + fetch_client = MagicMock() + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.headers = {} + resp.content = b"hello" + fetch_client.get.return_value = resp + + with patch.object(image_generation, "_image_fetch_client", fetch_client): + uri = image_generation._fetch_url_as_data_uri("https://cdn/x") + + self.assertEqual(uri, "data:image/jpeg;base64,aGVsbG8=") + + class TestPerImageBilling(unittest.TestCase): """Flat per-image pricing, independent of token usage."""