diff --git a/CLAUDE.md b/CLAUDE.md index e6d0622..a841445 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -10,6 +10,7 @@ The repo must provide a stable AWS Nitro PCR when the code doesn't change in ord ├── tee_gateway/ # Main application package (Flask/connexion) │ ├── __main__.py # Entry point: app factory, x402 middleware setup, key injection │ ├── llm_backend.py # LLM provider routing via LangChain, HTTP client management +│ ├── image_generation.py # Endpoint-based image gen (/images/generations): request shaping, URL→inline-bytes, signed responses │ ├── tee_manager.py # TEE key generation, nitriding registration, response signing │ ├── model_registry.py # Model config and per-token pricing │ ├── definitions.py # On-chain addresses, network IDs, payment amounts @@ -123,11 +124,19 @@ Model name prefixes determine routing: Image generation via xAI (grok-2-image), ByteDance (seedream-4.0, seedream-5.0-lite, seedance-4.5), and Z.ai (glm-image) is served through a -provider `/images/generations` endpoint rather than the chat path, but is -surfaced on `/v1/chat/completions` exactly like Gemini's inline-image models -(images returned out-of-band under the message `images` key). These models are -billed a flat per-image price (see `per_image_price_usd` in -`model_registry.py`), not per token. +provider `/images/generations` endpoint rather than the chat path (see +`image_generation.py`), but is surfaced on `/v1/chat/completions` exactly like +Gemini's inline-image models (images returned out-of-band under the message +`images` key). The client always receives inline bytes: providers that hand back +a hosted URL (Z.ai, Seedance, Seedream 5.0 Lite) are fetched inside the enclave +and inlined as `data:` URIs (the fetch is guarded: http(s) only, non-public IP +hosts rejected, redirects + size capped, and only ever called on provider- +response URLs, never client input). Image-to-image editing sends the prior image +back inline (a `data:` URI / `image_url` content part on the latest user turn), +forwarded to providers that support it via the endpoint's `image` field. +Per-provider request quirks (response format, `n`, size/watermark, reference +support) live in `model_registry.py`. These models are billed a flat per-image +price (see `per_image_price_usd`), not per token. ## Verification Examples diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index 1d61621..c48d7a6 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -31,11 +31,14 @@ extract_web_search_count, convert_messages, extract_usage, - generate_images, validate_attachments, AttachmentValidationError, canonical_user_content, ) +from tee_gateway.image_generation import ( + create_image_generation_response, + create_image_generation_streaming_response, +) from tee_gateway.model_registry import get_model_config from tee_gateway.pricing import compute_session_cost @@ -77,23 +80,6 @@ def _split_text_and_images(content: Any) -> tuple[str, list[str]]: return ("".join(text_parts), images) -def _extract_image_prompt(langchain_messages: list) -> str: - """Collapse the user-turn text into a single image-generation prompt. - - Image-generation models (xAI Grok, ByteDance Seedream) take a single text - prompt rather than a chat transcript, so we join the text of all human - messages. System/assistant/tool turns are ignored. - """ - from langchain_core.messages import HumanMessage - - parts: list[str] = [] - for m in langchain_messages: - if isinstance(m, HumanMessage): - content = m.content - parts.append(content if isinstance(content, str) else str(content)) - return "\n".join(p for p in parts if p) - - def create_chat_completion(body): """Create a chat completion (streaming or non-streaming).""" if not connexion.request.is_json: @@ -228,111 +214,6 @@ def _messages_contain_json_word(messages: list) -> bool: return False -def _create_image_generation_response( - chat_request: CreateChatCompletionRequest, request_bytes: bytes -): - """Non-streaming image generation via a provider's images endpoint. - - Surfaces generated images on the message under the ``images`` key exactly - like Gemini's inline-image models. There is no text to sign, so the signature - covers the request hash and an empty output; the image bytes ride out-of-band - inside the OHTTP envelope. Billing is flat per generated image. - """ - langchain_messages = convert_messages(chat_request.messages) - prompt = _extract_image_prompt(langchain_messages) - images, image_count = generate_images( - chat_request.model, prompt, n=chat_request.n or 1 - ) - - message_dict: dict[str, Any] = {"role": "assistant", "content": ""} - if images: - message_dict["images"] = images - - timestamp = int(time.time()) - msg_hash, input_hash_hex, output_hash_hex = compute_tee_msg_hash( - request_bytes, "", timestamp - ) - tee_keys = get_tee_keys() - signature = tee_keys.sign_data(msg_hash) - - openai_response: dict[str, Any] = { - "id": f"chatcmpl-{uuid.uuid4()}", - "object": "chat.completion", - "created": timestamp, - "model": chat_request.model, - "choices": [{"index": 0, "message": message_dict, "finish_reason": "stop"}], - "tee_signature": signature, - "tee_request_hash": input_hash_hex, - "tee_output_hash": output_hash_hex, - "tee_timestamp": timestamp, - "tee_id": f"0x{tee_keys.get_tee_id()}", - } - - usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} - openai_response["usage"] = usage - cost = compute_session_cost(chat_request.model, usage, image_count=image_count) - if cost is not None: - openai_response["opengradient"] = cost.model_dump(mode="json") - - CreateChatCompletionResponse.from_dict(openai_response) - return openai_response - - -def _create_image_generation_streaming_response( - chat_request: CreateChatCompletionRequest, request_bytes: bytes -): - """Streaming image generation: image gen is not a token stream, so we invoke - once and emit the result on the final SSE frame (mirrors the Gemini path).""" - tee_keys = get_tee_keys() - - def generate(): - try: - langchain_messages = convert_messages(chat_request.messages) - prompt = _extract_image_prompt(langchain_messages) - images, image_count = generate_images( - chat_request.model, prompt, n=chat_request.n or 1 - ) - - timestamp = int(time.time()) - msg_hash, input_hash_hex, output_hash_hex = compute_tee_msg_hash( - request_bytes, "", timestamp - ) - tee_signature = tee_keys.sign_data(msg_hash) - - final_data: dict[str, Any] = { - "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], - "model": chat_request.model, - "tee_signature": tee_signature, - "tee_timestamp": timestamp, - "tee_request_hash": input_hash_hex, - "tee_output_hash": output_hash_hex, - "tee_id": f"0x{tee_keys.get_tee_id()}", - } - # Images travel out-of-band on the final frame; not part of the hash. - if images: - final_data["images"] = images - - usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} - final_data["usage"] = usage - cost = compute_session_cost( - chat_request.model, usage, image_count=image_count - ) - if cost is not None: - final_data["opengradient"] = cost.model_dump(mode="json") - - yield f"data: {json.dumps(final_data)}\n\n" - yield "data: [DONE]\n\n" - except Exception as e: - logger.error(f"Image generation streaming error: {str(e)}", exc_info=True) - yield f"data: {json.dumps({'error': 'Stream processing failed', 'exception_type': type(e).__name__})}\n\n" - - return Response( - generate(), - mimetype="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, - ) - - def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): """Handle non-streaming chat completion via direct LangChain call.""" try: @@ -352,7 +233,7 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): # surfaced through this same endpoint, returning images out-of-band just # like Gemini's inline-image models. if cfg.image_generation: - return _create_image_generation_response(chat_request, request_bytes) + return create_image_generation_response(chat_request, request_bytes) model = get_chat_model_cached( model=chat_request.model, @@ -509,7 +390,7 @@ def _create_streaming_response(chat_request: CreateChatCompletionRequest): # Image-generation models (xAI Grok, ByteDance Seedream) use a dedicated # images endpoint; handle them without building a chat model. if get_model_config(chat_request.model).image_generation: - return _create_image_generation_streaming_response( + return create_image_generation_streaming_response( chat_request, request_bytes ) diff --git a/tee_gateway/image_generation.py b/tee_gateway/image_generation.py new file mode 100644 index 0000000..f2841c5 --- /dev/null +++ b/tee_gateway/image_generation.py @@ -0,0 +1,391 @@ +"""Endpoint-based image generation. + +xAI (Aurora), ByteDance (Seedream/Seedance) and Z.ai (GLM-Image) expose image +generation through a dedicated OpenAI-compatible ``POST /images/generations`` +endpoint rather than the chat path. This module owns everything specific to that +flow: + + * ``generate_images`` — shape and send the provider request, always returning + inline ``data:`` URIs (a provider-hosted URL is fetched into the enclave so + the client never sees a raw URL). + * ``create_image_generation_response`` / + ``create_image_generation_streaming_response`` — surface the result on + ``/v1/chat/completions`` exactly like Gemini's inline-image models (images + ride out-of-band under the message ``images`` key), signed and billed flat + per generated image. + +Per-provider request quirks (response format, ``n`` support, reference-image +editing, extra params) live in the model registry, keeping this code flat. +""" + +import base64 +import ipaddress +import json +import logging +import time +import uuid +from typing import Any, List, Optional +from urllib.parse import urlparse + +import httpx +from flask import Response + +from tee_gateway import llm_backend +from tee_gateway.models.create_chat_completion_request import ( + CreateChatCompletionRequest, +) +from tee_gateway.models.create_chat_completion_response import ( + CreateChatCompletionResponse, +) +from tee_gateway.model_registry import get_model_config +from tee_gateway.pricing import compute_session_cost +from tee_gateway.tee_manager import get_tee_keys, compute_tee_msg_hash + +logger = logging.getLogger(__name__) + +# Endpoint path appended to each provider's OpenAI-compatible base URL. +_IMAGE_GENERATION_PATH = "/images/generations" + +# Provider -> the shared HTTP client attribute on ``llm_backend`` (built after +# key injection). Looked up by attribute so a patched/rebuilt client is picked up. +_IMAGE_CLIENT_ATTRS = { + "x-ai": "xai_http_client", + "bytedance": "bytedance_http_client", + "zai": "zai_http_client", +} + +# Bounds on the URL fetch (egress hardening). Provider images are well under the +# size cap; the redirect cap stops a redirect chain from being chased off-host. +_MAX_IMAGE_BYTES = 25 * 1024 * 1024 # 25 MiB +_MAX_REDIRECTS = 3 +_ALLOWED_FETCH_SCHEMES = {"http", "https"} + +# Shared keyless client for fetching provider-hosted image URLs into the enclave. +_image_fetch_client: Optional[httpx.Client] = None + + +def _validate_fetch_url(url: str) -> None: + """Guard the image fetch against SSRF / egress abuse. + + Only http(s) URLs are allowed, and IP-literal hosts in private, loopback, + link-local, or otherwise non-global ranges are rejected. Hostnames are not + resolved here — this helper is only ever called with a URL from a provider's + own ``/images/generations`` response (never client input, which is forwarded + to the provider rather than dereferenced) — so the scheme/IP checks plus the + size and redirect caps in ``_fetch_url_as_data_uri`` bound the blast radius if + that trust is ever misplaced. + """ + parsed = urlparse(url) + if parsed.scheme not in _ALLOWED_FETCH_SCHEMES: + raise ValueError(f"Refusing to fetch image from non-http(s) URL: {url!r}") + host = parsed.hostname + if not host: + raise ValueError(f"Refusing to fetch image from URL without a host: {url!r}") + try: + ip = ipaddress.ip_address(host) + except ValueError: + return # hostname, not an IP literal — left to the provider's CDN + if not ip.is_global: + raise ValueError(f"Refusing to fetch image from non-public address: {host}") + + +def _fetch_url_as_data_uri(url: str) -> str: + """Fetch a provider-hosted image URL and return it as a ``data:`` URI. + + Only ever called with URLs from a provider's ``/images/generations`` response + — never with client-supplied input, which is forwarded to the provider rather + than dereferenced inside the enclave. Lets URL-returning providers be surfaced + identically to inline-byte ones: the bytes are pulled into the enclave and + ride out inside the OHTTP/TEE envelope, so the client never sees a raw URL. + + Hardened against egress abuse: http(s) only, non-public IP hosts rejected, + redirects capped, and the body capped at ``_MAX_IMAGE_BYTES`` (streamed so an + oversized response is aborted instead of buffered whole). + """ + _validate_fetch_url(url) + + global _image_fetch_client + if _image_fetch_client is None: + _image_fetch_client = httpx.Client( + timeout=httpx.Timeout(timeout=120.0, connect=15.0), + follow_redirects=True, + max_redirects=_MAX_REDIRECTS, + ) + + with _image_fetch_client.stream("GET", url) as resp: + resp.raise_for_status() + declared = resp.headers.get("content-length") + if declared and declared.isdigit() and int(declared) > _MAX_IMAGE_BYTES: + raise ValueError( + f"Refusing to fetch image larger than {_MAX_IMAGE_BYTES} bytes" + ) + mime = ( + (resp.headers.get("content-type") or "image/jpeg").split(";", 1)[0].strip() + ) + chunks: list[bytes] = [] + total = 0 + for chunk in resp.iter_bytes(): + total += len(chunk) + if total > _MAX_IMAGE_BYTES: + raise ValueError( + f"Refusing to fetch image larger than {_MAX_IMAGE_BYTES} bytes" + ) + chunks.append(chunk) + + b64 = base64.b64encode(b"".join(chunks)).decode("ascii") + return f"data:{mime or 'image/jpeg'};base64,{b64}" + + +def generate_images( + model: str, + prompt: str, + n: int = 1, + reference_images: Optional[List[str]] = None, +) -> tuple[list[str], int]: + """Generate images via a provider's OpenAI-compatible images endpoint. + + ``reference_images`` carries input images for image-to-image editing (e.g. a + follow-up "add a hat" that builds on the previously generated image), sent on + the same endpoint via the ``image`` field (a URL or base64 ``data:`` URI, or + an array of up to 10). Models whose endpoint doesn't support it ignore them. + + Returns ``(data_uris, image_count)``. Every entry is a ``data:`` URI — when a + provider returns a hosted URL instead of inline bytes, the gateway fetches it + so the client always receives bytes. The count drives per-image billing. + """ + cfg = get_model_config(model) + provider = cfg.provider + + client_attr = _IMAGE_CLIENT_ATTRS.get(provider) + if client_attr is None: + raise ValueError( + f"Provider {provider!r} does not support the image-generation endpoint" + ) + client: Optional[httpx.Client] = getattr(llm_backend, client_attr, None) + if client is None: + raise RuntimeError(f"{provider} HTTP client has not been initialized") + + # n is clamped to the OpenAI-compatible providers' documented 1..10 range. + count = max(1, min(int(n), 10)) + payload: dict[str, Any] = {"model": cfg.api_name, "prompt": prompt} + if cfg.image_response_format is not None: + payload["response_format"] = cfg.image_response_format + if cfg.image_send_n: + payload["n"] = count + if cfg.image_extra_params: + payload.update(cfg.image_extra_params) + # Image-to-image editing: forward reference images via the ``image`` field (a + # single string, or an array for multi-reference edits, up to 10). Without + # this a follow-up edit like "add a hat" would ignore the prior image and + # generate a fresh one from the prompt text alone. Filter to non-empty strings + # so a malformed entry can't break JSON serialization of the request payload. + if reference_images and cfg.image_supports_reference: + refs = [r for r in reference_images if isinstance(r, str) and r][:10] + if refs: + payload["image"] = refs[0] if len(refs) == 1 else refs + + logger.info( + "Generating %d image(s) - Provider: %s, Model: %s", + count, + provider, + cfg.api_name, + ) + resp = client.post(_IMAGE_GENERATION_PATH, json=payload) + resp.raise_for_status() + data = resp.json().get("data", []) or [] + + images: list[str] = [] + for item in data: + if not isinstance(item, dict): + continue + b64 = item.get("b64_json") + if b64: + images.append(f"data:image/jpeg;base64,{b64}") + elif item.get("url"): + images.append(_fetch_url_as_data_uri(item["url"])) + + return images, len(images) + + +def _extract_image_inputs(langchain_messages: list) -> tuple[str, list[str]]: + """Pull the text prompt and any reference images out of the user turns. + + Image-generation models take a single text prompt rather than a chat + transcript, so we join the text of all human messages (system/assistant/tool + turns are ignored). A user turn doing an image-to-image edit carries the prior + image as an ``image_url``/``image`` content part alongside its text; we collect + those reference images so they can be forwarded to the provider — without them + a follow-up like "add a hat" would ignore the previous image and generate from + the prompt text alone. + + Only the *latest* user turn's references are returned: each user turn replaces + the running set (and a text-only turn clears it). Carrying references forward + from earlier edits would forward stale images and burn through the provider's + 10-image cap. Values are coerced/filtered to strings so malformed input can't + break downstream JSON serialization. + + Returns ``(prompt, reference_images)``. + """ + from langchain_core.messages import HumanMessage + + text_parts: list[str] = [] + reference_images: list[str] = [] + for m in langchain_messages: + if not isinstance(m, HumanMessage): + continue + content = m.content + # References reset per user turn so the latest turn is authoritative. + turn_images: list[str] = [] + if isinstance(content, str): + if content: + text_parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, str): + if part: + text_parts.append(part) + continue + if not isinstance(part, dict): + continue + ptype = part.get("type") + if ptype == "text": + text = part.get("text") + if text: + text_parts.append(str(text)) + elif ptype == "image_url": + image_url = part.get("image_url") + url = ( + image_url.get("url") + if isinstance(image_url, dict) + else image_url + ) + if isinstance(url, str) and url: + turn_images.append(url) + elif ptype == "image": + # Standard LangChain image block: inline base64 + mime, or url. + b64 = part.get("base64") + url = part.get("url") + if isinstance(b64, str) and b64: + mime = part.get("mime_type") + mime = mime if isinstance(mime, str) and mime else "image/png" + turn_images.append(f"data:{mime};base64,{b64}") + elif isinstance(url, str) and url: + turn_images.append(url) + else: + continue # unrecognized content shape: no text, leave references as-is + reference_images = turn_images + return "\n".join(text_parts), reference_images + + +def _run_image_generation( + chat_request: CreateChatCompletionRequest, request_bytes: bytes +) -> dict[str, Any]: + """Shared core for endpoint-based image generation. + + Generates the images (always inline bytes — the gateway fetches any + provider-hosted URL), signs the response, and computes billing. Image + generation has no text to sign, so the signature covers the request hash and + an empty output; the image bytes ride out-of-band inside the OHTTP envelope. + Billing is flat per generated image. Returns the pieces the streaming and + non-streaming responders assemble into their respective envelopes. + """ + langchain_messages = llm_backend.convert_messages(chat_request.messages) + prompt, reference_images = _extract_image_inputs(langchain_messages) + images, image_count = generate_images( + chat_request.model, + prompt, + n=chat_request.n or 1, + reference_images=reference_images, + ) + + timestamp = int(time.time()) + msg_hash, input_hash_hex, output_hash_hex = compute_tee_msg_hash( + request_bytes, "", timestamp + ) + tee_keys = get_tee_keys() + usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + cost = compute_session_cost(chat_request.model, usage, image_count=image_count) + + return { + "images": images, + "usage": usage, + "opengradient": cost.model_dump(mode="json") if cost is not None else None, + "tee_signature": tee_keys.sign_data(msg_hash), + "tee_request_hash": input_hash_hex, + "tee_output_hash": output_hash_hex, + "tee_timestamp": timestamp, + "tee_id": f"0x{tee_keys.get_tee_id()}", + } + + +def create_image_generation_response( + chat_request: CreateChatCompletionRequest, request_bytes: bytes +): + """Non-streaming image generation via a provider's images endpoint. + + Surfaces generated images on the message under the ``images`` key exactly + like Gemini's inline-image models. + """ + result = _run_image_generation(chat_request, request_bytes) + + message_dict: dict[str, Any] = {"role": "assistant", "content": ""} + if result["images"]: + message_dict["images"] = result["images"] + + openai_response: dict[str, Any] = { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": result["tee_timestamp"], + "model": chat_request.model, + "choices": [{"index": 0, "message": message_dict, "finish_reason": "stop"}], + "tee_signature": result["tee_signature"], + "tee_request_hash": result["tee_request_hash"], + "tee_output_hash": result["tee_output_hash"], + "tee_timestamp": result["tee_timestamp"], + "tee_id": result["tee_id"], + "usage": result["usage"], + } + if result["opengradient"] is not None: + openai_response["opengradient"] = result["opengradient"] + + CreateChatCompletionResponse.from_dict(openai_response) + return openai_response + + +def create_image_generation_streaming_response( + chat_request: CreateChatCompletionRequest, request_bytes: bytes +): + """Streaming image generation: image gen is not a token stream, so we invoke + once and emit the result on the final SSE frame (mirrors the Gemini path).""" + + def generate(): + try: + result = _run_image_generation(chat_request, request_bytes) + + final_data: dict[str, Any] = { + "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], + "model": chat_request.model, + "tee_signature": result["tee_signature"], + "tee_timestamp": result["tee_timestamp"], + "tee_request_hash": result["tee_request_hash"], + "tee_output_hash": result["tee_output_hash"], + "tee_id": result["tee_id"], + "usage": result["usage"], + } + # Images travel out-of-band on the final frame; not part of the hash. + if result["images"]: + final_data["images"] = result["images"] + if result["opengradient"] is not None: + final_data["opengradient"] = result["opengradient"] + + yield f"data: {json.dumps(final_data)}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + logger.error(f"Image generation streaming error: {str(e)}", exc_info=True) + yield f"data: {json.dumps({'error': 'Stream processing failed', 'exception_type': type(e).__name__})}\n\n" + + return Response( + generate(), + mimetype="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) diff --git a/tee_gateway/llm_backend.py b/tee_gateway/llm_backend.py index 3a34e84..8092a21 100644 --- a/tee_gateway/llm_backend.py +++ b/tee_gateway/llm_backend.py @@ -329,86 +329,6 @@ def get_chat_model_cached( raise ValueError(f"Unsupported provider: {provider}") -# Endpoint path appended to each provider's OpenAI-compatible base URL. -_IMAGE_GENERATION_PATH = "/images/generations" - - -def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int]: - """Generate images via a provider's OpenAI-compatible images endpoint. - - Unlike Gemini's inline-image chat models, xAI (Aurora), ByteDance - (Seedream), and Z.ai (GLM-Image) expose image generation through a dedicated - ``POST /images/generations`` endpoint. We request ``b64_json`` so the image - bytes ride inline inside the OHTTP/TEE envelope for providers that support - it. Z.ai returns temporary image URLs only. - - Returns ``(data_uris, image_count)`` where each entry is a ``data:`` URI. The - count is used for per-image billing. Falls back to provider-returned URLs if - a provider ignores ``b64_json``. - """ - cfg = get_model_config(model) - provider = cfg.provider - - if provider == "x-ai": - client = xai_http_client - elif provider == "bytedance": - client = bytedance_http_client - elif provider == "zai": - client = zai_http_client - else: - raise ValueError( - f"Provider {provider!r} does not support the image-generation endpoint" - ) - - if client is None: - raise RuntimeError(f"{provider} HTTP client has not been initialized") - - # n is clamped to the OpenAI-compatible providers' documented 1..10 range. - # Z.ai's GLM-Image and ByteDance Seedance endpoints don't document n/ - # response_format support, so keep their payloads to documented fields. - count = max(1, min(int(n), 10)) - payload: dict[str, Any] = { - "model": cfg.api_name, - "prompt": prompt, - } - if provider == "zai": - payload["size"] = "1280x1280" - elif provider == "bytedance" and cfg.api_name.startswith("ep-"): - # Seedance deployment endpoints use URL format and require extra params. - payload["response_format"] = "url" - payload["sequential_image_generation"] = "disabled" - payload["watermark"] = False - payload["size"] = "2K" - payload["stream"] = False - else: - payload["n"] = count - payload["response_format"] = "b64_json" - - logger.info( - "Generating %d image(s) - Provider: %s, Model: %s", - count, - provider, - cfg.api_name, - ) - resp = client.post(_IMAGE_GENERATION_PATH, json=payload) - resp.raise_for_status() - data = resp.json().get("data", []) or [] - - images: list[str] = [] - for item in data: - if not isinstance(item, dict): - continue - b64 = item.get("b64_json") - if b64: - images.append(f"data:image/jpeg;base64,{b64}") - continue - url = item.get("url") - if url: - images.append(url) - - return images, len(images) - - def _parse_data_uri(uri: str) -> Optional[tuple[str, str]]: """Parse a ``data:;base64,`` URI into ``(mime_type, base64_data)``. diff --git a/tee_gateway/model_registry.py b/tee_gateway/model_registry.py index 4f18b2e..700d40c 100644 --- a/tee_gateway/model_registry.py +++ b/tee_gateway/model_registry.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from decimal import Decimal from enum import Enum, unique -from typing import Optional +from typing import Any, Mapping, Optional @dataclass(frozen=True) @@ -35,6 +35,21 @@ class ModelConfig: # Flat USD price per generated image, for ``image_generation`` models. Token # prices are ignored for these models (set to 0 in the registry). per_image_price_usd: Optional[Decimal] = None + # ── /images/generations request shaping (image_generation models only) ── + # The ``response_format`` to request. ``"b64_json"`` returns inline bytes; + # ``"url"`` returns a hosted link the gateway fetches and inlines (so the + # client always receives bytes). ``None`` omits the field for endpoints that + # don't document it (Z.ai GLM-Image). + image_response_format: Optional[str] = "b64_json" + # Whether to send the OpenAI-style ``n`` count. Some endpoints (Z.ai + # GLM-Image, ByteDance Seedance) don't document it and reject/ignore it. + image_send_n: bool = True + # Whether the endpoint accepts reference images for image-to-image editing, + # sent via the ``image`` field. Text-to-image-only endpoints reject it. + image_supports_reference: bool = False + # Static extra params merged verbatim into the request payload (e.g. size, + # watermark). Keyed by field name; values must be JSON-serializable. + image_extra_params: Optional[Mapping[str, Any]] = None # USD per image-modality output token, for ``image_output`` models (Gemini # "nano banana"). These providers bill image output at a higher rate than # text/thinking output: image tokens at this rate, text + thinking tokens at @@ -64,6 +79,17 @@ class ModelConfig: "google": Decimal("0.035"), } +# ByteDance ModelArk image *deployment* endpoints (api_name "ep-…", e.g. Seedance +# 4.5, Seedream 5.0 Lite) return the URL response format and require these extra +# params. The gateway fetches the returned URL and inlines the bytes, so the +# client still receives inline bytes. Shared so the two ep- models stay in sync. +_BYTEDANCE_EP_IMAGE_PARAMS: dict[str, Any] = { + "sequential_image_generation": "disabled", + "watermark": False, + "size": "2K", + "stream": False, +} + @unique class SupportedModel(Enum): @@ -383,8 +409,12 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.03"), + image_supports_reference=True, ) # Seedream 5.0 Lite image generation via a ModelArk deployment endpoint. + # Seedream 5.0 Lite image generation via a ModelArk deployment endpoint + # (api_name "ep-…"). Like Seedance it returns hosted URLs (fetched and inlined + # by the gateway) and takes the shared ep- deployment params. Billed per image. SEEDREAM_5_0_LITE = ModelConfig( provider="bytedance", api_name="ep-20260624213657-7zc5n", @@ -392,10 +422,14 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.035"), + image_response_format="url", + image_send_n=False, + image_supports_reference=True, + image_extra_params=_BYTEDANCE_EP_IMAGE_PARAMS, ) # Seedance 4.5 image generation via a ModelArk deployment endpoint. - # Uses URL response format and seedance-specific request params - # (sequential_image_generation, watermark, size). Billed per image. + # Returns hosted URLs (fetched and inlined by the gateway) and takes the + # shared ep- deployment params. Billed per image. SEEDANCE_4_5 = ModelConfig( provider="bytedance", api_name="ep-20260624042612-7dxcv", @@ -403,6 +437,10 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.05"), + image_response_format="url", + image_send_n=False, + image_supports_reference=True, + image_extra_params=_BYTEDANCE_EP_IMAGE_PARAMS, ) # ── Nous Research (Nous Portal, OpenAI-compatible) ────────────────── @@ -435,6 +473,8 @@ class SupportedModel(Enum): output_price_usd=Decimal("0.0000044"), ) # GLM-Image uses Z.ai's image endpoint and is billed per generated image. + # Z.ai returns hosted URLs only (fetched and inlined by the gateway) and + # documents neither ``n`` nor ``response_format``, so both are omitted. GLM_IMAGE = ModelConfig( provider="zai", api_name="glm-image", @@ -442,6 +482,9 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.015"), + image_response_format=None, + image_send_n=False, + image_extra_params={"size": "1280x1280"}, ) # ── Legacy models (not in current SDK — retained for older SDK versions) ── diff --git a/tee_gateway/test/test_image_generation.py b/tee_gateway/test/test_image_generation.py index 2dc5353..c6fb573 100644 --- a/tee_gateway/test/test_image_generation.py +++ b/tee_gateway/test/test_image_generation.py @@ -6,25 +6,26 @@ endpoint and billed a flat price per generated image. These tests pin: 1. The request/response handling in ``generate_images`` (b64_json -> data URI, - n clamping, url fallback, provider-specific payloads). + n clamping, hosted-URL fetch -> data URI, provider-specific payloads). 2. The flat per-image billing in ``compute_session_cost``. -No network or API key required — the provider HTTP client is mocked and a stub -price feed is injected. +No network or API key required — the provider HTTP client is mocked, the URL +fetch is patched, and a stub price feed is injected. """ import unittest from decimal import Decimal from unittest.mock import MagicMock, patch -from tee_gateway import llm_backend -from tee_gateway.llm_backend import generate_images +from tee_gateway import image_generation, llm_backend +from tee_gateway.image_generation import generate_images from tee_gateway.model_registry import get_model_config from tee_gateway.price_feed import get_price_feed, set_price_feed from tee_gateway.pricing import compute_session_cost GROK_IMAGE = "grok-2-image" SEEDREAM = "seedream-4.0" +SEEDREAM_5_LITE = "seedream-5.0-lite" SEEDANCE = "seedance-4.5" GLM_IMAGE = "glm-image" @@ -64,23 +65,40 @@ def test_b64_json_becomes_data_uri(self): self.assertEqual(payload["n"], 2) self.assertEqual(payload["response_format"], "b64_json") - def test_url_fallback_when_no_b64(self): + def test_hosted_url_is_fetched_and_inlined(self): + # Providers that return a hosted URL instead of inline bytes are fetched + # into the enclave so the client always receives a data: URI. client = MagicMock() client.post.return_value = _mock_response([{"url": "https://img/1.jpg"}]) - with patch.object(llm_backend, "bytedance_http_client", client): + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ) as fetch, + ): images, count = generate_images(SEEDREAM, "a blue sphere", n=1) self.assertEqual(count, 1) - self.assertEqual(images, ["https://img/1.jpg"]) + self.assertEqual(images, ["data:image/jpeg;base64,RkVUQ0hFRA=="]) + fetch.assert_called_once_with("https://img/1.jpg") - def test_zai_glm_image_uses_documented_payload_and_url_response(self): + def test_zai_glm_image_uses_documented_payload_and_fetches_url(self): client = MagicMock() client.post.return_value = _mock_response([{"url": "https://z.ai/img.png"}]) - with patch.object(llm_backend, "zai_http_client", client): + with ( + patch.object(llm_backend, "zai_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/png;base64,RkVUQ0hFRA==", + ), + ): images, count = generate_images(GLM_IMAGE, "a poster", n=3) self.assertEqual(count, 1) - self.assertEqual(images, ["https://z.ai/img.png"]) + self.assertEqual(images, ["data:image/png;base64,RkVUQ0hFRA=="]) _, kwargs = client.post.call_args payload = kwargs["json"] @@ -93,11 +111,18 @@ def test_zai_glm_image_uses_documented_payload_and_url_response(self): def test_seedance_uses_url_format_and_extra_params(self): client = MagicMock() client.post.return_value = _mock_response([{"url": "https://cdn/img.jpg"}]) - with patch.object(llm_backend, "bytedance_http_client", client): + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ), + ): images, count = generate_images(SEEDANCE, "a black hole", n=1) self.assertEqual(count, 1) - self.assertEqual(images, ["https://cdn/img.jpg"]) + self.assertEqual(images, ["data:image/jpeg;base64,RkVUQ0hFRA=="]) _, kwargs = client.post.call_args payload = kwargs["json"] @@ -110,6 +135,87 @@ def test_seedance_uses_url_format_and_extra_params(self): self.assertFalse(payload["stream"]) self.assertNotIn("n", payload) + def test_seedream_5_lite_uses_ep_deployment_params(self): + # Seedream 5.0 Lite is an ep- deployment endpoint and must use the same + # URL/no-n/seedance-style payload as Seedance — a regression guard, since + # this used to be auto-detected from the "ep-" api_name prefix and is now + # driven by explicit registry fields. + client = MagicMock() + client.post.return_value = _mock_response([{"url": "https://cdn/img.jpg"}]) + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ), + ): + images, count = generate_images(SEEDREAM_5_LITE, "a koi pond", n=1) + + self.assertEqual(images, ["data:image/jpeg;base64,RkVUQ0hFRA=="]) + payload = client.post.call_args.kwargs["json"] + self.assertEqual(payload["model"], get_model_config(SEEDREAM_5_LITE).api_name) + self.assertEqual(payload["response_format"], "url") + self.assertEqual(payload["sequential_image_generation"], "disabled") + self.assertFalse(payload["watermark"]) + self.assertEqual(payload["size"], "2K") + self.assertFalse(payload["stream"]) + self.assertNotIn("n", payload) + + def test_seedance_forwards_single_reference_image(self): + client = MagicMock() + client.post.return_value = _mock_response([{"url": "https://cdn/edited.jpg"}]) + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ), + ): + generate_images( + SEEDANCE, + "add a hat", + n=1, + reference_images=["https://cdn/original.jpg"], + ) + + payload = client.post.call_args.kwargs["json"] + # A single reference is sent as a bare string (Seedream/Seedance accept + # either a string or an array for the `image` field). + self.assertEqual(payload["image"], "https://cdn/original.jpg") + + def test_seedream_forwards_multiple_reference_images_as_array(self): + client = MagicMock() + client.post.return_value = _mock_response([{"b64_json": "x"}]) + refs = ["data:image/png;base64,AAA", "https://cdn/b.jpg"] + with patch.object(llm_backend, "bytedance_http_client", client): + generate_images(SEEDREAM, "fuse these", n=1, reference_images=refs) + + payload = client.post.call_args.kwargs["json"] + self.assertEqual(payload["image"], refs) + + def test_reference_images_clamped_to_ten(self): + client = MagicMock() + client.post.return_value = _mock_response([{"b64_json": "x"}]) + refs = [f"https://cdn/{i}.jpg" for i in range(15)] + with patch.object(llm_backend, "bytedance_http_client", client): + generate_images(SEEDREAM, "p", n=1, reference_images=refs) + + payload = client.post.call_args.kwargs["json"] + self.assertEqual(len(payload["image"]), 10) + + def test_reference_images_ignored_for_non_bytedance(self): + # xAI/Z.ai text-to-image endpoints don't support image edit; the `image` + # field must not leak into their payloads. + client = MagicMock() + client.post.return_value = _mock_response([{"b64_json": "x"}]) + with patch.object(llm_backend, "xai_http_client", client): + generate_images( + GROK_IMAGE, "p", n=1, reference_images=["https://cdn/x.jpg"] + ) + self.assertNotIn("image", client.post.call_args.kwargs["json"]) + def test_n_is_clamped_to_provider_range(self): client = MagicMock() client.post.return_value = _mock_response([{"b64_json": "x"}]) @@ -127,6 +233,177 @@ def test_uninitialized_client_raises(self): generate_images(GROK_IMAGE, "p", n=1) +def _mock_stream_client(headers: dict, chunks: list[bytes]) -> MagicMock: + """A fake httpx client whose .stream(...) yields a response with these bytes.""" + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.headers = headers + resp.iter_bytes.return_value = chunks + ctx = MagicMock() + ctx.__enter__.return_value = resp + ctx.__exit__.return_value = False + client = MagicMock() + client.stream.return_value = ctx + return client + + +class TestFetchUrlAsDataUri(unittest.TestCase): + """The hosted-URL fetch that inlines provider images into the enclave.""" + + def test_encodes_bytes_with_content_type(self): + client = _mock_stream_client( + {"content-type": "image/png; charset=binary"}, [b"hel", b"lo"] + ) + with patch.object(image_generation, "_image_fetch_client", client): + uri = image_generation._fetch_url_as_data_uri("https://cdn/x.png") + + # base64("hello") == "aGVsbG8=", mime taken from content-type (params dropped) + self.assertEqual(uri, "data:image/png;base64,aGVsbG8=") + client.stream.assert_called_once_with("GET", "https://cdn/x.png") + + def test_defaults_mime_when_header_missing(self): + client = _mock_stream_client({}, [b"hello"]) + with patch.object(image_generation, "_image_fetch_client", client): + uri = image_generation._fetch_url_as_data_uri("https://cdn/x") + + self.assertEqual(uri, "data:image/jpeg;base64,aGVsbG8=") + + def test_rejects_non_http_scheme(self): + # Validation happens before any client use, so no client is needed. + with self.assertRaises(ValueError): + image_generation._fetch_url_as_data_uri("ftp://cdn/x.png") + with self.assertRaises(ValueError): + image_generation._fetch_url_as_data_uri("file:///etc/passwd") + + def test_rejects_private_and_loopback_ip_hosts(self): + for url in ( + "http://127.0.0.1/x.png", + "http://169.254.169.254/latest/meta-data", # cloud metadata + "http://10.0.0.5/x.png", + "http://[::1]/x.png", + ): + with self.subTest(url=url): + with self.assertRaises(ValueError): + image_generation._fetch_url_as_data_uri(url) + + def test_rejects_body_exceeding_size_cap(self): + client = _mock_stream_client( + {"content-type": "image/png"}, [b"x" * 4, b"x" * 4] + ) + with ( + patch.object(image_generation, "_image_fetch_client", client), + patch.object(image_generation, "_MAX_IMAGE_BYTES", 5), + ): + with self.assertRaises(ValueError): + image_generation._fetch_url_as_data_uri("https://cdn/big.png") + + def test_rejects_declared_content_length_over_cap(self): + client = _mock_stream_client( + {"content-type": "image/png", "content-length": "999"}, [b"x"] + ) + with ( + patch.object(image_generation, "_image_fetch_client", client), + patch.object(image_generation, "_MAX_IMAGE_BYTES", 5), + ): + with self.assertRaises(ValueError): + image_generation._fetch_url_as_data_uri("https://cdn/big.png") + + +class TestExtractImageInputs(unittest.TestCase): + """Prompt + reference-image extraction from the user turns.""" + + @staticmethod + def _human(content): + from langchain_core.messages import HumanMessage + + return HumanMessage(content=content) + + def test_joins_text_across_turns_no_references(self): + msgs = [self._human("a red cube"), self._human("make it blue")] + prompt, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(prompt, "a red cube\nmake it blue") + self.assertEqual(refs, []) + + def test_mixed_text_and_image_does_not_splice_base64_into_prompt(self): + # An image-to-image edit turn: text + an attached reference image. The + # base64 blob must never leak into the prompt text. + data_uri = "data:image/png;base64,QUJD" + msgs = [ + self._human( + [ + {"type": "text", "text": "add a hat"}, + {"type": "image_url", "image_url": {"url": data_uri}}, + ] + ) + ] + prompt, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(prompt, "add a hat") + self.assertEqual(refs, [data_uri]) + + def test_extracts_langchain_image_block_with_base64_and_url(self): + msgs = [ + self._human( + [ + {"type": "image", "base64": "QUJD", "mime_type": "image/webp"}, + {"type": "image", "url": "https://cdn/ref.jpg"}, + ] + ) + ] + _, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(refs, ["data:image/webp;base64,QUJD", "https://cdn/ref.jpg"]) + + def test_only_latest_turn_references_are_returned(self): + # An earlier edit turn carried an image; the latest turn carries a new + # one. Only the latest turn's reference should be forwarded. + msgs = [ + self._human( + [ + {"type": "text", "text": "first"}, + {"type": "image_url", "image_url": {"url": "https://cdn/old.jpg"}}, + ] + ), + self._human( + [ + {"type": "text", "text": "second"}, + {"type": "image_url", "image_url": {"url": "https://cdn/new.jpg"}}, + ] + ), + ] + prompt, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(prompt, "first\nsecond") + self.assertEqual(refs, ["https://cdn/new.jpg"]) + + def test_text_only_latest_turn_clears_stale_references(self): + # Edit turn with an image, then a plain text follow-up: the text-only + # latest turn means a fresh generation, so no stale reference rides along. + msgs = [ + self._human( + [ + {"type": "text", "text": "first"}, + {"type": "image_url", "image_url": {"url": "https://cdn/old.jpg"}}, + ] + ), + self._human("just text now"), + ] + _, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(refs, []) + + def test_malformed_image_parts_are_ignored(self): + msgs = [ + self._human( + [ + {"type": "text", "text": "p"}, + {"type": "image_url", "image_url": {"url": None}}, + {"type": "image_url", "image_url": 123}, + {"type": "image", "base64": None}, + ] + ) + ] + prompt, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(prompt, "p") + self.assertEqual(refs, []) + + class TestPerImageBilling(unittest.TestCase): """Flat per-image pricing, independent of token usage."""