From 0e7b2da68b644fdc6f2b4fff50455acca267cbf6 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 23 Jun 2026 21:37:34 +0000 Subject: [PATCH 1/3] Forward reference images to Seedream/Seedance image editing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Endpoint-based image-generation models (Seedream, Seedance) are served via the dedicated /images/generations endpoint, which previously only received a text prompt. Two bugs meant follow-up edits like "add a hat" silently ignored the previously generated image: 1. generate_images() never forwarded any input image, so image-to-image editing was impossible — the client attaches the prior image to the latest user turn, but the gateway dropped it. 2. _extract_image_prompt() stringified multimodal list content with str(), splicing the base64 image blob into the prompt text instead of extracting the actual prompt. Now the controller pulls reference images out of the user turns and forwards them to ByteDance's `image` field (URL or base64 data URI, array up to 10), and the prompt extractor reads only text parts. Chat-path image-output models (Gemini "nano banana") already worked because they receive the full multimodal history natively; this brings the images-endpoint models to parity. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_01AwcCXVgsXGqGQGzKWXhMfj --- tee_gateway/controllers/chat_controller.py | 74 ++++++++++++++++++++-- tee_gateway/llm_backend.py | 23 ++++++- tee_gateway/test/test_image_generation.py | 47 ++++++++++++++ 3 files changed, 137 insertions(+), 7 deletions(-) diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index 1d61621..7494abf 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -83,15 +83,69 @@ def _extract_image_prompt(langchain_messages: list) -> str: Image-generation models (xAI Grok, ByteDance Seedream) take a single text prompt rather than a chat transcript, so we join the text of all human messages. System/assistant/tool turns are ignored. + + A user turn carrying an attached reference image has list (multimodal) + content; we pull out only its ``text`` parts. (Naively stringifying the list + would splice the base64 image blob into the prompt — the reference image is + forwarded separately via ``_extract_reference_images``.) """ from langchain_core.messages import HumanMessage parts: list[str] = [] for m in langchain_messages: - if isinstance(m, HumanMessage): - content = m.content - parts.append(content if isinstance(content, str) else str(content)) - return "\n".join(p for p in parts if p) + if not isinstance(m, HumanMessage): + continue + content = m.content + if isinstance(content, str): + if content: + parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict): + if part.get("type") == "text" and part.get("text"): + parts.append(part["text"]) + elif isinstance(part, str) and part: + parts.append(part) + return "\n".join(parts) + + +def _extract_reference_images(langchain_messages: list) -> list[str]: + """Collect reference-image URLs/data-URIs from the user turns. + + Endpoint-based image models (Seedream/Seedance) support image-to-image + editing: the client attaches the prior generated image (or an uploaded one) + to the latest user turn as an ``image_url`` content part. We pull those out + so they can be forwarded to the provider as reference images — without them a + follow-up like "add a hat" ignores the previous image and generates from the + prompt text alone. In practice only the latest user turn carries images, so + collecting across turns just yields the active references. + """ + from langchain_core.messages import HumanMessage + + images: list[str] = [] + for m in langchain_messages: + if not isinstance(m, HumanMessage): + continue + content = m.content + if not isinstance(content, list): + continue + for part in content: + if not isinstance(part, dict): + continue + ptype = part.get("type") + if ptype == "image_url": + image_url = part.get("image_url") + url = image_url.get("url") if isinstance(image_url, dict) else image_url + if url: + images.append(url) + elif ptype == "image": + # Standard LangChain image block: inline base64 + mime, or a url. + if part.get("base64"): + mime = part.get("mime_type") or "image/png" + images.append(f"data:{mime};base64,{part['base64']}") + elif part.get("url"): + images.append(part["url"]) + return images def create_chat_completion(body): @@ -240,8 +294,12 @@ def _create_image_generation_response( """ langchain_messages = convert_messages(chat_request.messages) prompt = _extract_image_prompt(langchain_messages) + reference_images = _extract_reference_images(langchain_messages) images, image_count = generate_images( - chat_request.model, prompt, n=chat_request.n or 1 + chat_request.model, + prompt, + n=chat_request.n or 1, + reference_images=reference_images, ) message_dict: dict[str, Any] = {"role": "assistant", "content": ""} @@ -289,8 +347,12 @@ def generate(): try: langchain_messages = convert_messages(chat_request.messages) prompt = _extract_image_prompt(langchain_messages) + reference_images = _extract_reference_images(langchain_messages) images, image_count = generate_images( - chat_request.model, prompt, n=chat_request.n or 1 + chat_request.model, + prompt, + n=chat_request.n or 1, + reference_images=reference_images, ) timestamp = int(time.time()) diff --git a/tee_gateway/llm_backend.py b/tee_gateway/llm_backend.py index 3a34e84..ccbce83 100644 --- a/tee_gateway/llm_backend.py +++ b/tee_gateway/llm_backend.py @@ -333,7 +333,12 @@ def get_chat_model_cached( _IMAGE_GENERATION_PATH = "/images/generations" -def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int]: +def generate_images( + model: str, + prompt: str, + n: int = 1, + reference_images: Optional[List[str]] = None, +) -> tuple[list[str], int]: """Generate images via a provider's OpenAI-compatible images endpoint. Unlike Gemini's inline-image chat models, xAI (Aurora), ByteDance @@ -342,6 +347,12 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int bytes ride inline inside the OHTTP/TEE envelope for providers that support it. Z.ai returns temporary image URLs only. + ``reference_images`` carries input images for image-to-image editing (e.g. a + follow-up "add a hat" that builds on the previously generated image). On + ByteDance Seedream/Seedance these ride on the same endpoint via the ``image`` + field, which accepts a URL or a base64 ``data:`` URI (or an array of them, up + to 10). Providers without image-edit support ignore the references. + Returns ``(data_uris, image_count)`` where each entry is a ``data:`` URI. The count is used for per-image billing. Falls back to provider-returned URLs if a provider ignores ``b64_json``. @@ -384,6 +395,16 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int payload["n"] = count payload["response_format"] = "b64_json" + # Image-to-image editing: ByteDance Seedream/Seedance accept reference images + # on the same endpoint via the ``image`` field (URL or base64 data URI; an + # array for multi-reference edits, up to 10). Without this, a follow-up edit + # like "add a hat" would silently ignore the prior image and generate a fresh + # one from the prompt text alone. Only ByteDance is known to support this; + # other providers' text-to-image endpoints reject unknown fields, so gate it. + if reference_images and provider == "bytedance": + refs = reference_images[:10] + payload["image"] = refs[0] if len(refs) == 1 else refs + logger.info( "Generating %d image(s) - Provider: %s, Model: %s", count, diff --git a/tee_gateway/test/test_image_generation.py b/tee_gateway/test/test_image_generation.py index 2dc5353..48c7bfe 100644 --- a/tee_gateway/test/test_image_generation.py +++ b/tee_gateway/test/test_image_generation.py @@ -110,6 +110,53 @@ def test_seedance_uses_url_format_and_extra_params(self): self.assertFalse(payload["stream"]) self.assertNotIn("n", payload) + def test_seedance_forwards_single_reference_image(self): + client = MagicMock() + client.post.return_value = _mock_response([{"url": "https://cdn/edited.jpg"}]) + with patch.object(llm_backend, "bytedance_http_client", client): + generate_images( + SEEDANCE, + "add a hat", + n=1, + reference_images=["https://cdn/original.jpg"], + ) + + payload = client.post.call_args.kwargs["json"] + # A single reference is sent as a bare string (Seedream/Seedance accept + # either a string or an array for the `image` field). + self.assertEqual(payload["image"], "https://cdn/original.jpg") + + def test_seedream_forwards_multiple_reference_images_as_array(self): + client = MagicMock() + client.post.return_value = _mock_response([{"b64_json": "x"}]) + refs = ["data:image/png;base64,AAA", "https://cdn/b.jpg"] + with patch.object(llm_backend, "bytedance_http_client", client): + generate_images(SEEDREAM, "fuse these", n=1, reference_images=refs) + + payload = client.post.call_args.kwargs["json"] + self.assertEqual(payload["image"], refs) + + def test_reference_images_clamped_to_ten(self): + client = MagicMock() + client.post.return_value = _mock_response([{"b64_json": "x"}]) + refs = [f"https://cdn/{i}.jpg" for i in range(15)] + with patch.object(llm_backend, "bytedance_http_client", client): + generate_images(SEEDREAM, "p", n=1, reference_images=refs) + + payload = client.post.call_args.kwargs["json"] + self.assertEqual(len(payload["image"]), 10) + + def test_reference_images_ignored_for_non_bytedance(self): + # xAI/Z.ai text-to-image endpoints don't support image edit; the `image` + # field must not leak into their payloads. + client = MagicMock() + client.post.return_value = _mock_response([{"b64_json": "x"}]) + with patch.object(llm_backend, "xai_http_client", client): + generate_images( + GROK_IMAGE, "p", n=1, reference_images=["https://cdn/x.jpg"] + ) + self.assertNotIn("image", client.post.call_args.kwargs["json"]) + def test_n_is_clamped_to_provider_range(self): client = MagicMock() client.post.return_value = _mock_response([{"b64_json": "x"}]) From 56565d00829aeff83d87a454fdda87b58d53724c Mon Sep 17 00:00:00 2001 From: kukac Date: Wed, 24 Jun 2026 10:44:35 -0400 Subject: [PATCH 2/3] Refactor image generation into its own module; always return inline bytes (#110) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The endpoint-based image-generation code (xAI Grok, ByteDance Seedream/Seedance, Z.ai GLM-Image) had grown convoluted and inconsistent, spread across llm_backend.py and chat_controller.py with two near-duplicate streaming/non-streaming responders, two separate message-walk helpers, and a return type that was sometimes inline bytes and sometimes a raw provider URL. Changes: - New tee_gateway/image_generation.py owns the whole flow: request shaping, the provider call, URL→inline-bytes fetching, and the signed chat-completion responders. llm_backend.py and chat_controller.py just route to it. - generate_images() now ALWAYS returns data: URIs. Providers that hand back a hosted URL (Z.ai, Seedance) are fetched inside the enclave and inlined, so the client always receives bytes — never a raw URL. This also matches what the chat-app already expects (it caches data URIs; raw URLs weren't cached). Image-to-image editing already rides inline as a data: URI on the user turn. - Per-provider request quirks (response_format, n support, reference-image editing, extra params like size/watermark) move out of branchy if/elif code into declarative fields on ModelConfig in model_registry.py. - The two streaming/non-streaming responders collapse onto one shared core (_run_image_generation), and the two message-walk helpers collapse into a single _extract_image_inputs pass returning (prompt, reference_images). Net: ~300 lines removed from the two original files, image logic isolated. Tests updated to patch the new URL fetch and assert inline-bytes output; added direct coverage for _fetch_url_as_data_uri. Claude-Session: https://claude.ai/code/session_01BHGhsd68znPWoooHvtLFD5 Co-authored-by: Claude --- CLAUDE.md | 15 +- tee_gateway/controllers/chat_controller.py | 193 +------------ tee_gateway/image_generation.py | 312 +++++++++++++++++++++ tee_gateway/llm_backend.py | 101 ------- tee_gateway/model_registry.py | 37 ++- tee_gateway/test/test_image_generation.py | 91 +++++- 6 files changed, 440 insertions(+), 309 deletions(-) create mode 100644 tee_gateway/image_generation.py diff --git a/CLAUDE.md b/CLAUDE.md index 15e4b1c..a2662a7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -10,6 +10,7 @@ The repo must provide a stable AWS Nitro PCR when the code doesn't change in ord ├── tee_gateway/ # Main application package (Flask/connexion) │ ├── __main__.py # Entry point: app factory, x402 middleware setup, key injection │ ├── llm_backend.py # LLM provider routing via LangChain, HTTP client management +│ ├── image_generation.py # Endpoint-based image gen (/images/generations): request shaping, URL→inline-bytes, signed responses │ ├── tee_manager.py # TEE key generation, nitriding registration, response signing │ ├── model_registry.py # Model config and per-token pricing │ ├── definitions.py # On-chain addresses, network IDs, payment amounts @@ -123,10 +124,16 @@ Model name prefixes determine routing: Image generation via xAI (grok-2-image), ByteDance (seedream-4.0, seedance-4.5), and Z.ai (glm-image) is served through a provider `/images/generations` endpoint rather -than the chat path, but is surfaced on `/v1/chat/completions` exactly like -Gemini's inline-image models (images returned out-of-band under the message -`images` key). These models are billed a flat per-image price (see -`per_image_price_usd` in `model_registry.py`), not per token. +than the chat path (see `image_generation.py`), but is surfaced on +`/v1/chat/completions` exactly like Gemini's inline-image models (images returned +out-of-band under the message `images` key). The client always receives inline +bytes: providers that hand back a hosted URL (Z.ai, Seedance) are fetched inside +the enclave and inlined as `data:` URIs. Image-to-image editing sends the prior +image back inline (a `data:` URI / `image_url` content part on the user turn), +forwarded to providers that support it via the endpoint's `image` field. +Per-provider request quirks (response format, `n`, size/watermark, reference +support) live in `model_registry.py`. These models are billed a flat per-image +price (see `per_image_price_usd`), not per token. ## Verification Examples diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index 7494abf..c48d7a6 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -31,11 +31,14 @@ extract_web_search_count, convert_messages, extract_usage, - generate_images, validate_attachments, AttachmentValidationError, canonical_user_content, ) +from tee_gateway.image_generation import ( + create_image_generation_response, + create_image_generation_streaming_response, +) from tee_gateway.model_registry import get_model_config from tee_gateway.pricing import compute_session_cost @@ -77,77 +80,6 @@ def _split_text_and_images(content: Any) -> tuple[str, list[str]]: return ("".join(text_parts), images) -def _extract_image_prompt(langchain_messages: list) -> str: - """Collapse the user-turn text into a single image-generation prompt. - - Image-generation models (xAI Grok, ByteDance Seedream) take a single text - prompt rather than a chat transcript, so we join the text of all human - messages. System/assistant/tool turns are ignored. - - A user turn carrying an attached reference image has list (multimodal) - content; we pull out only its ``text`` parts. (Naively stringifying the list - would splice the base64 image blob into the prompt — the reference image is - forwarded separately via ``_extract_reference_images``.) - """ - from langchain_core.messages import HumanMessage - - parts: list[str] = [] - for m in langchain_messages: - if not isinstance(m, HumanMessage): - continue - content = m.content - if isinstance(content, str): - if content: - parts.append(content) - elif isinstance(content, list): - for part in content: - if isinstance(part, dict): - if part.get("type") == "text" and part.get("text"): - parts.append(part["text"]) - elif isinstance(part, str) and part: - parts.append(part) - return "\n".join(parts) - - -def _extract_reference_images(langchain_messages: list) -> list[str]: - """Collect reference-image URLs/data-URIs from the user turns. - - Endpoint-based image models (Seedream/Seedance) support image-to-image - editing: the client attaches the prior generated image (or an uploaded one) - to the latest user turn as an ``image_url`` content part. We pull those out - so they can be forwarded to the provider as reference images — without them a - follow-up like "add a hat" ignores the previous image and generates from the - prompt text alone. In practice only the latest user turn carries images, so - collecting across turns just yields the active references. - """ - from langchain_core.messages import HumanMessage - - images: list[str] = [] - for m in langchain_messages: - if not isinstance(m, HumanMessage): - continue - content = m.content - if not isinstance(content, list): - continue - for part in content: - if not isinstance(part, dict): - continue - ptype = part.get("type") - if ptype == "image_url": - image_url = part.get("image_url") - url = image_url.get("url") if isinstance(image_url, dict) else image_url - if url: - images.append(url) - elif ptype == "image": - # Standard LangChain image block: inline base64 + mime, or a url. - if part.get("base64"): - mime = part.get("mime_type") or "image/png" - images.append(f"data:{mime};base64,{part['base64']}") - elif part.get("url"): - images.append(part["url"]) - return images - - def create_chat_completion(body): """Create a chat completion (streaming or non-streaming).""" if not connexion.request.is_json: @@ -282,119 +214,6 @@ def _messages_contain_json_word(messages: list) -> bool: return False -def _create_image_generation_response( - chat_request: CreateChatCompletionRequest, request_bytes: bytes -): - """Non-streaming image generation via a provider's images endpoint. - - Surfaces generated images on the message under the ``images`` key exactly - like Gemini's inline-image models. There is no text to sign, so the signature - covers the request hash and an empty output; the image bytes ride out-of-band - inside the OHTTP envelope. Billing is flat per generated image. - """ - langchain_messages = convert_messages(chat_request.messages) - prompt = _extract_image_prompt(langchain_messages) - reference_images = _extract_reference_images(langchain_messages) - images, image_count = generate_images( - chat_request.model, - prompt, - n=chat_request.n or 1, - reference_images=reference_images, - ) - - message_dict: dict[str, Any] = {"role": "assistant", "content": ""} - if images: - message_dict["images"] = images - - timestamp = int(time.time()) - msg_hash, input_hash_hex, output_hash_hex = compute_tee_msg_hash( - request_bytes, "", timestamp - ) - tee_keys = get_tee_keys() - signature = tee_keys.sign_data(msg_hash) - - openai_response: dict[str, Any] = { - "id": f"chatcmpl-{uuid.uuid4()}", - "object": "chat.completion", - "created": timestamp, - "model": chat_request.model, - "choices": [{"index": 0, "message": message_dict, "finish_reason": "stop"}], - "tee_signature": signature, - "tee_request_hash": input_hash_hex, - "tee_output_hash": output_hash_hex, - "tee_timestamp": timestamp, - "tee_id": f"0x{tee_keys.get_tee_id()}", - } - - usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} - openai_response["usage"] = usage - cost = compute_session_cost(chat_request.model, usage, image_count=image_count) - if cost is not None: - openai_response["opengradient"] = cost.model_dump(mode="json") - - CreateChatCompletionResponse.from_dict(openai_response) - return openai_response - - -def _create_image_generation_streaming_response( - chat_request: CreateChatCompletionRequest, request_bytes: bytes -): - """Streaming image generation: image gen is not a token stream, so we invoke - once and emit the result on the final SSE frame (mirrors the Gemini path).""" - tee_keys = get_tee_keys() - - def generate(): - try: - langchain_messages = convert_messages(chat_request.messages) - prompt = _extract_image_prompt(langchain_messages) - reference_images = _extract_reference_images(langchain_messages) - images, image_count = generate_images( - chat_request.model, - prompt, - n=chat_request.n or 1, - reference_images=reference_images, - ) - - timestamp = int(time.time()) - msg_hash, input_hash_hex, output_hash_hex = compute_tee_msg_hash( - request_bytes, "", timestamp - ) - tee_signature = tee_keys.sign_data(msg_hash) - - final_data: dict[str, Any] = { - "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], - "model": chat_request.model, - "tee_signature": tee_signature, - "tee_timestamp": timestamp, - "tee_request_hash": input_hash_hex, - "tee_output_hash": output_hash_hex, - "tee_id": f"0x{tee_keys.get_tee_id()}", - } - # Images travel out-of-band on the final frame; not part of the hash. - if images: - final_data["images"] = images - - usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} - final_data["usage"] = usage - cost = compute_session_cost( - chat_request.model, usage, image_count=image_count - ) - if cost is not None: - final_data["opengradient"] = cost.model_dump(mode="json") - - yield f"data: {json.dumps(final_data)}\n\n" - yield "data: [DONE]\n\n" - except Exception as e: - logger.error(f"Image generation streaming error: {str(e)}", exc_info=True) - yield f"data: {json.dumps({'error': 'Stream processing failed', 'exception_type': type(e).__name__})}\n\n" - - return Response( - generate(), - mimetype="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, - ) - - def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): """Handle non-streaming chat completion via direct LangChain call.""" try: @@ -414,7 +233,7 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): # surfaced through this same endpoint, returning images out-of-band just # like Gemini's inline-image models. if cfg.image_generation: - return _create_image_generation_response(chat_request, request_bytes) + return create_image_generation_response(chat_request, request_bytes) model = get_chat_model_cached( model=chat_request.model, @@ -571,7 +390,7 @@ def _create_streaming_response(chat_request: CreateChatCompletionRequest): # Image-generation models (xAI Grok, ByteDance Seedream) use a dedicated # images endpoint; handle them without building a chat model. if get_model_config(chat_request.model).image_generation: - return _create_image_generation_streaming_response( + return create_image_generation_streaming_response( chat_request, request_bytes ) diff --git a/tee_gateway/image_generation.py b/tee_gateway/image_generation.py new file mode 100644 index 0000000..1e8b1a3 --- /dev/null +++ b/tee_gateway/image_generation.py @@ -0,0 +1,312 @@ +"""Endpoint-based image generation. + +xAI (Aurora), ByteDance (Seedream/Seedance) and Z.ai (GLM-Image) expose image +generation through a dedicated OpenAI-compatible ``POST /images/generations`` +endpoint rather than the chat path. This module owns everything specific to that +flow: + + * ``generate_images`` — shape and send the provider request, always returning + inline ``data:`` URIs (a provider-hosted URL is fetched into the enclave so + the client never sees a raw URL). + * ``create_image_generation_response`` / + ``create_image_generation_streaming_response`` — surface the result on + ``/v1/chat/completions`` exactly like Gemini's inline-image models (images + ride out-of-band under the message ``images`` key), signed and billed flat + per generated image. + +Per-provider request quirks (response format, ``n`` support, reference-image +editing, extra params) live in the model registry, keeping this code flat. +""" + +import base64 +import json +import logging +import time +import uuid +from typing import Any, List, Optional + +import httpx +from flask import Response + +from tee_gateway import llm_backend +from tee_gateway.models.create_chat_completion_request import ( + CreateChatCompletionRequest, +) +from tee_gateway.models.create_chat_completion_response import ( + CreateChatCompletionResponse, +) +from tee_gateway.model_registry import get_model_config +from tee_gateway.pricing import compute_session_cost +from tee_gateway.tee_manager import get_tee_keys, compute_tee_msg_hash + +logger = logging.getLogger(__name__) + +# Endpoint path appended to each provider's OpenAI-compatible base URL. +_IMAGE_GENERATION_PATH = "/images/generations" + +# Provider -> the shared HTTP client attribute on ``llm_backend`` (built after +# key injection). Looked up by attribute so a patched/rebuilt client is picked up. +_IMAGE_CLIENT_ATTRS = { + "x-ai": "xai_http_client", + "bytedance": "bytedance_http_client", + "zai": "zai_http_client", +} + +# Shared keyless client for fetching provider-hosted image URLs into the enclave. +_image_fetch_client: Optional[httpx.Client] = None + + +def _fetch_url_as_data_uri(url: str) -> str: + """Fetch a hosted image URL and return it as a ``data:`` URI. + + Lets URL-returning image providers be surfaced identically to inline-byte + ones: the bytes are pulled into the enclave and ride out inside the OHTTP/TEE + envelope, so the client never sees a raw provider URL. + """ + global _image_fetch_client + if _image_fetch_client is None: + _image_fetch_client = httpx.Client( + timeout=httpx.Timeout(timeout=120.0, connect=15.0), + follow_redirects=True, + ) + resp = _image_fetch_client.get(url) + resp.raise_for_status() + mime = (resp.headers.get("content-type") or "image/jpeg").split(";", 1)[0].strip() + b64 = base64.b64encode(resp.content).decode("ascii") + return f"data:{mime or 'image/jpeg'};base64,{b64}" + + +def generate_images( + model: str, + prompt: str, + n: int = 1, + reference_images: Optional[List[str]] = None, +) -> tuple[list[str], int]: + """Generate images via a provider's OpenAI-compatible images endpoint. + + ``reference_images`` carries input images for image-to-image editing (e.g. a + follow-up "add a hat" that builds on the previously generated image), sent on + the same endpoint via the ``image`` field (a URL or base64 ``data:`` URI, or + an array of up to 10). Models whose endpoint doesn't support it ignore them. + + Returns ``(data_uris, image_count)``. Every entry is a ``data:`` URI — when a + provider returns a hosted URL instead of inline bytes, the gateway fetches it + so the client always receives bytes. The count drives per-image billing. + """ + cfg = get_model_config(model) + provider = cfg.provider + + client_attr = _IMAGE_CLIENT_ATTRS.get(provider) + if client_attr is None: + raise ValueError( + f"Provider {provider!r} does not support the image-generation endpoint" + ) + client: Optional[httpx.Client] = getattr(llm_backend, client_attr, None) + if client is None: + raise RuntimeError(f"{provider} HTTP client has not been initialized") + + # n is clamped to the OpenAI-compatible providers' documented 1..10 range. + count = max(1, min(int(n), 10)) + payload: dict[str, Any] = {"model": cfg.api_name, "prompt": prompt} + if cfg.image_response_format is not None: + payload["response_format"] = cfg.image_response_format + if cfg.image_send_n: + payload["n"] = count + if cfg.image_extra_params: + payload.update(cfg.image_extra_params) + # Image-to-image editing: forward reference images via the ``image`` field (a + # single string, or an array for multi-reference edits, up to 10). Without + # this a follow-up edit like "add a hat" would ignore the prior image and + # generate a fresh one from the prompt text alone. + if reference_images and cfg.image_supports_reference: + refs = reference_images[:10] + payload["image"] = refs[0] if len(refs) == 1 else refs + + logger.info( + "Generating %d image(s) - Provider: %s, Model: %s", + count, + provider, + cfg.api_name, + ) + resp = client.post(_IMAGE_GENERATION_PATH, json=payload) + resp.raise_for_status() + data = resp.json().get("data", []) or [] + + images: list[str] = [] + for item in data: + if not isinstance(item, dict): + continue + b64 = item.get("b64_json") + if b64: + images.append(f"data:image/jpeg;base64,{b64}") + elif item.get("url"): + images.append(_fetch_url_as_data_uri(item["url"])) + + return images, len(images) + + +def _extract_image_inputs(langchain_messages: list) -> tuple[str, list[str]]: + """Pull the text prompt and any reference images out of the user turns. + + Image-generation models take a single text prompt rather than a chat + transcript, so we join the text of all human messages (system/assistant/tool + turns are ignored). A user turn doing an image-to-image edit carries the prior + image as an ``image_url``/``image`` content part alongside its text; we collect + those reference images separately so they can be forwarded to the provider — + without them a follow-up like "add a hat" would ignore the previous image and + generate from the prompt text alone. + + Returns ``(prompt, reference_images)``. + """ + from langchain_core.messages import HumanMessage + + text_parts: list[str] = [] + images: list[str] = [] + for m in langchain_messages: + if not isinstance(m, HumanMessage): + continue + content = m.content + if isinstance(content, str): + if content: + text_parts.append(content) + continue + if not isinstance(content, list): + continue + for part in content: + if isinstance(part, str): + if part: + text_parts.append(part) + continue + if not isinstance(part, dict): + continue + ptype = part.get("type") + if ptype == "text": + if part.get("text"): + text_parts.append(part["text"]) + elif ptype == "image_url": + image_url = part.get("image_url") + url = image_url.get("url") if isinstance(image_url, dict) else image_url + if url: + images.append(url) + elif ptype == "image": + # Standard LangChain image block: inline base64 + mime, or a url. + if part.get("base64"): + mime = part.get("mime_type") or "image/png" + images.append(f"data:{mime};base64,{part['base64']}") + elif part.get("url"): + images.append(part["url"]) + return "\n".join(text_parts), images + + +def _run_image_generation( + chat_request: CreateChatCompletionRequest, request_bytes: bytes +) -> dict[str, Any]: + """Shared core for endpoint-based image generation. + + Generates the images (always inline bytes — the gateway fetches any + provider-hosted URL), signs the response, and computes billing. Image + generation has no text to sign, so the signature covers the request hash and + an empty output; the image bytes ride out-of-band inside the OHTTP envelope. + Billing is flat per generated image. Returns the pieces the streaming and + non-streaming responders assemble into their respective envelopes. + """ + langchain_messages = llm_backend.convert_messages(chat_request.messages) + prompt, reference_images = _extract_image_inputs(langchain_messages) + images, image_count = generate_images( + chat_request.model, + prompt, + n=chat_request.n or 1, + reference_images=reference_images, + ) + + timestamp = int(time.time()) + msg_hash, input_hash_hex, output_hash_hex = compute_tee_msg_hash( + request_bytes, "", timestamp + ) + tee_keys = get_tee_keys() + usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + cost = compute_session_cost(chat_request.model, usage, image_count=image_count) + + return { + "images": images, + "usage": usage, + "opengradient": cost.model_dump(mode="json") if cost is not None else None, + "tee_signature": tee_keys.sign_data(msg_hash), + "tee_request_hash": input_hash_hex, + "tee_output_hash": output_hash_hex, + "tee_timestamp": timestamp, + "tee_id": f"0x{tee_keys.get_tee_id()}", + } + + +def create_image_generation_response( + chat_request: CreateChatCompletionRequest, request_bytes: bytes +): + """Non-streaming image generation via a provider's images endpoint. + + Surfaces generated images on the message under the ``images`` key exactly + like Gemini's inline-image models. + """ + result = _run_image_generation(chat_request, request_bytes) + + message_dict: dict[str, Any] = {"role": "assistant", "content": ""} + if result["images"]: + message_dict["images"] = result["images"] + + openai_response: dict[str, Any] = { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": result["tee_timestamp"], + "model": chat_request.model, + "choices": [{"index": 0, "message": message_dict, "finish_reason": "stop"}], + "tee_signature": result["tee_signature"], + "tee_request_hash": result["tee_request_hash"], + "tee_output_hash": result["tee_output_hash"], + "tee_timestamp": result["tee_timestamp"], + "tee_id": result["tee_id"], + "usage": result["usage"], + } + if result["opengradient"] is not None: + openai_response["opengradient"] = result["opengradient"] + + CreateChatCompletionResponse.from_dict(openai_response) + return openai_response + + +def create_image_generation_streaming_response( + chat_request: CreateChatCompletionRequest, request_bytes: bytes +): + """Streaming image generation: image gen is not a token stream, so we invoke + once and emit the result on the final SSE frame (mirrors the Gemini path).""" + + def generate(): + try: + result = _run_image_generation(chat_request, request_bytes) + + final_data: dict[str, Any] = { + "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], + "model": chat_request.model, + "tee_signature": result["tee_signature"], + "tee_timestamp": result["tee_timestamp"], + "tee_request_hash": result["tee_request_hash"], + "tee_output_hash": result["tee_output_hash"], + "tee_id": result["tee_id"], + "usage": result["usage"], + } + # Images travel out-of-band on the final frame; not part of the hash. + if result["images"]: + final_data["images"] = result["images"] + if result["opengradient"] is not None: + final_data["opengradient"] = result["opengradient"] + + yield f"data: {json.dumps(final_data)}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + logger.error(f"Image generation streaming error: {str(e)}", exc_info=True) + yield f"data: {json.dumps({'error': 'Stream processing failed', 'exception_type': type(e).__name__})}\n\n" + + return Response( + generate(), + mimetype="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) diff --git a/tee_gateway/llm_backend.py b/tee_gateway/llm_backend.py index ccbce83..8092a21 100644 --- a/tee_gateway/llm_backend.py +++ b/tee_gateway/llm_backend.py @@ -329,107 +329,6 @@ def get_chat_model_cached( raise ValueError(f"Unsupported provider: {provider}") -# Endpoint path appended to each provider's OpenAI-compatible base URL. -_IMAGE_GENERATION_PATH = "/images/generations" - - -def generate_images( - model: str, - prompt: str, - n: int = 1, - reference_images: Optional[List[str]] = None, -) -> tuple[list[str], int]: - """Generate images via a provider's OpenAI-compatible images endpoint. - - Unlike Gemini's inline-image chat models, xAI (Aurora), ByteDance - (Seedream), and Z.ai (GLM-Image) expose image generation through a dedicated - ``POST /images/generations`` endpoint. We request ``b64_json`` so the image - bytes ride inline inside the OHTTP/TEE envelope for providers that support - it. Z.ai returns temporary image URLs only. - - ``reference_images`` carries input images for image-to-image editing (e.g. a - follow-up "add a hat" that builds on the previously generated image). On - ByteDance Seedream/Seedance these ride on the same endpoint via the ``image`` - field, which accepts a URL or a base64 ``data:`` URI (or an array of them, up - to 10). Providers without image-edit support ignore the references. - - Returns ``(data_uris, image_count)`` where each entry is a ``data:`` URI. The - count is used for per-image billing. Falls back to provider-returned URLs if - a provider ignores ``b64_json``. - """ - cfg = get_model_config(model) - provider = cfg.provider - - if provider == "x-ai": - client = xai_http_client - elif provider == "bytedance": - client = bytedance_http_client - elif provider == "zai": - client = zai_http_client - else: - raise ValueError( - f"Provider {provider!r} does not support the image-generation endpoint" - ) - - if client is None: - raise RuntimeError(f"{provider} HTTP client has not been initialized") - - # n is clamped to the OpenAI-compatible providers' documented 1..10 range. - # Z.ai's GLM-Image and ByteDance Seedance endpoints don't document n/ - # response_format support, so keep their payloads to documented fields. - count = max(1, min(int(n), 10)) - payload: dict[str, Any] = { - "model": cfg.api_name, - "prompt": prompt, - } - if provider == "zai": - payload["size"] = "1280x1280" - elif provider == "bytedance" and cfg.api_name.startswith("ep-"): - # Seedance deployment endpoints use URL format and require extra params. - payload["response_format"] = "url" - payload["sequential_image_generation"] = "disabled" - payload["watermark"] = False - payload["size"] = "2K" - payload["stream"] = False - else: - payload["n"] = count - payload["response_format"] = "b64_json" - - # Image-to-image editing: ByteDance Seedream/Seedance accept reference images - # on the same endpoint via the ``image`` field (URL or base64 data URI; an - # array for multi-reference edits, up to 10). Without this, a follow-up edit - # like "add a hat" would silently ignore the prior image and generate a fresh - # one from the prompt text alone. Only ByteDance is known to support this; - # other providers' text-to-image endpoints reject unknown fields, so gate it. - if reference_images and provider == "bytedance": - refs = reference_images[:10] - payload["image"] = refs[0] if len(refs) == 1 else refs - - logger.info( - "Generating %d image(s) - Provider: %s, Model: %s", - count, - provider, - cfg.api_name, - ) - resp = client.post(_IMAGE_GENERATION_PATH, json=payload) - resp.raise_for_status() - data = resp.json().get("data", []) or [] - - images: list[str] = [] - for item in data: - if not isinstance(item, dict): - continue - b64 = item.get("b64_json") - if b64: - images.append(f"data:image/jpeg;base64,{b64}") - continue - url = item.get("url") - if url: - images.append(url) - - return images, len(images) - - def _parse_data_uri(uri: str) -> Optional[tuple[str, str]]: """Parse a ``data:;base64,`` URI into ``(mime_type, base64_data)``. diff --git a/tee_gateway/model_registry.py b/tee_gateway/model_registry.py index e2098d7..132db19 100644 --- a/tee_gateway/model_registry.py +++ b/tee_gateway/model_registry.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from decimal import Decimal from enum import Enum, unique -from typing import Optional +from typing import Any, Mapping, Optional @dataclass(frozen=True) @@ -35,6 +35,21 @@ class ModelConfig: # Flat USD price per generated image, for ``image_generation`` models. Token # prices are ignored for these models (set to 0 in the registry). per_image_price_usd: Optional[Decimal] = None + # ── /images/generations request shaping (image_generation models only) ── + # The ``response_format`` to request. ``"b64_json"`` returns inline bytes; + # ``"url"`` returns a hosted link the gateway fetches and inlines (so the + # client always receives bytes). ``None`` omits the field for endpoints that + # don't document it (Z.ai GLM-Image). + image_response_format: Optional[str] = "b64_json" + # Whether to send the OpenAI-style ``n`` count. Some endpoints (Z.ai + # GLM-Image, ByteDance Seedance) don't document it and reject/ignore it. + image_send_n: bool = True + # Whether the endpoint accepts reference images for image-to-image editing, + # sent via the ``image`` field. Text-to-image-only endpoints reject it. + image_supports_reference: bool = False + # Static extra params merged verbatim into the request payload (e.g. size, + # watermark). Keyed by field name; values must be JSON-serializable. + image_extra_params: Optional[Mapping[str, Any]] = None # USD per image-modality output token, for ``image_output`` models (Gemini # "nano banana"). These providers bill image output at a higher rate than # text/thinking output: image tokens at this rate, text + thinking tokens at @@ -375,10 +390,12 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.03"), + image_supports_reference=True, ) # Seedance 4.5 image generation via a ModelArk deployment endpoint. - # Uses URL response format and seedance-specific request params - # (sequential_image_generation, watermark, size). Billed per image. + # Returns hosted URLs (fetched and inlined by the gateway) and needs + # seedance-specific request params (sequential_image_generation, watermark, + # size). Billed per image. SEEDANCE_4_5 = ModelConfig( provider="bytedance", api_name="ep-20260624042612-7dxcv", @@ -386,6 +403,15 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.05"), + image_response_format="url", + image_send_n=False, + image_supports_reference=True, + image_extra_params={ + "sequential_image_generation": "disabled", + "watermark": False, + "size": "2K", + "stream": False, + }, ) # ── Nous Research (Nous Portal, OpenAI-compatible) ────────────────── @@ -418,6 +444,8 @@ class SupportedModel(Enum): output_price_usd=Decimal("0.0000044"), ) # GLM-Image uses Z.ai's image endpoint and is billed per generated image. + # Z.ai returns hosted URLs only (fetched and inlined by the gateway) and + # documents neither ``n`` nor ``response_format``, so both are omitted. GLM_IMAGE = ModelConfig( provider="zai", api_name="glm-image", @@ -425,6 +453,9 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.015"), + image_response_format=None, + image_send_n=False, + image_extra_params={"size": "1280x1280"}, ) # ── Legacy models (not in current SDK — retained for older SDK versions) ── diff --git a/tee_gateway/test/test_image_generation.py b/tee_gateway/test/test_image_generation.py index 48c7bfe..140a909 100644 --- a/tee_gateway/test/test_image_generation.py +++ b/tee_gateway/test/test_image_generation.py @@ -6,19 +6,19 @@ endpoint and billed a flat price per generated image. These tests pin: 1. The request/response handling in ``generate_images`` (b64_json -> data URI, - n clamping, url fallback, provider-specific payloads). + n clamping, hosted-URL fetch -> data URI, provider-specific payloads). 2. The flat per-image billing in ``compute_session_cost``. -No network or API key required — the provider HTTP client is mocked and a stub -price feed is injected. +No network or API key required — the provider HTTP client is mocked, the URL +fetch is patched, and a stub price feed is injected. """ import unittest from decimal import Decimal from unittest.mock import MagicMock, patch -from tee_gateway import llm_backend -from tee_gateway.llm_backend import generate_images +from tee_gateway import image_generation, llm_backend +from tee_gateway.image_generation import generate_images from tee_gateway.model_registry import get_model_config from tee_gateway.price_feed import get_price_feed, set_price_feed from tee_gateway.pricing import compute_session_cost @@ -64,23 +64,40 @@ def test_b64_json_becomes_data_uri(self): self.assertEqual(payload["n"], 2) self.assertEqual(payload["response_format"], "b64_json") - def test_url_fallback_when_no_b64(self): + def test_hosted_url_is_fetched_and_inlined(self): + # Providers that return a hosted URL instead of inline bytes are fetched + # into the enclave so the client always receives a data: URI. client = MagicMock() client.post.return_value = _mock_response([{"url": "https://img/1.jpg"}]) - with patch.object(llm_backend, "bytedance_http_client", client): + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ) as fetch, + ): images, count = generate_images(SEEDREAM, "a blue sphere", n=1) self.assertEqual(count, 1) - self.assertEqual(images, ["https://img/1.jpg"]) + self.assertEqual(images, ["data:image/jpeg;base64,RkVUQ0hFRA=="]) + fetch.assert_called_once_with("https://img/1.jpg") - def test_zai_glm_image_uses_documented_payload_and_url_response(self): + def test_zai_glm_image_uses_documented_payload_and_fetches_url(self): client = MagicMock() client.post.return_value = _mock_response([{"url": "https://z.ai/img.png"}]) - with patch.object(llm_backend, "zai_http_client", client): + with ( + patch.object(llm_backend, "zai_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/png;base64,RkVUQ0hFRA==", + ), + ): images, count = generate_images(GLM_IMAGE, "a poster", n=3) self.assertEqual(count, 1) - self.assertEqual(images, ["https://z.ai/img.png"]) + self.assertEqual(images, ["data:image/png;base64,RkVUQ0hFRA=="]) _, kwargs = client.post.call_args payload = kwargs["json"] @@ -93,11 +110,18 @@ def test_zai_glm_image_uses_documented_payload_and_url_response(self): def test_seedance_uses_url_format_and_extra_params(self): client = MagicMock() client.post.return_value = _mock_response([{"url": "https://cdn/img.jpg"}]) - with patch.object(llm_backend, "bytedance_http_client", client): + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ), + ): images, count = generate_images(SEEDANCE, "a black hole", n=1) self.assertEqual(count, 1) - self.assertEqual(images, ["https://cdn/img.jpg"]) + self.assertEqual(images, ["data:image/jpeg;base64,RkVUQ0hFRA=="]) _, kwargs = client.post.call_args payload = kwargs["json"] @@ -113,7 +137,14 @@ def test_seedance_uses_url_format_and_extra_params(self): def test_seedance_forwards_single_reference_image(self): client = MagicMock() client.post.return_value = _mock_response([{"url": "https://cdn/edited.jpg"}]) - with patch.object(llm_backend, "bytedance_http_client", client): + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ), + ): generate_images( SEEDANCE, "add a hat", @@ -174,6 +205,38 @@ def test_uninitialized_client_raises(self): generate_images(GROK_IMAGE, "p", n=1) +class TestFetchUrlAsDataUri(unittest.TestCase): + """The hosted-URL fetch that inlines provider images into the enclave.""" + + def test_encodes_bytes_with_content_type(self): + fetch_client = MagicMock() + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.headers = {"content-type": "image/png; charset=binary"} + resp.content = b"hello" + fetch_client.get.return_value = resp + + with patch.object(image_generation, "_image_fetch_client", fetch_client): + uri = image_generation._fetch_url_as_data_uri("https://cdn/x.png") + + # base64("hello") == "aGVsbG8=", mime taken from content-type (params dropped) + self.assertEqual(uri, "data:image/png;base64,aGVsbG8=") + fetch_client.get.assert_called_once_with("https://cdn/x.png") + + def test_defaults_mime_when_header_missing(self): + fetch_client = MagicMock() + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.headers = {} + resp.content = b"hello" + fetch_client.get.return_value = resp + + with patch.object(image_generation, "_image_fetch_client", fetch_client): + uri = image_generation._fetch_url_as_data_uri("https://cdn/x") + + self.assertEqual(uri, "data:image/jpeg;base64,aGVsbG8=") + + class TestPerImageBilling(unittest.TestCase): """Flat per-image pricing, independent of token usage.""" From db557e76f8713a68cdeba733755e46d103bde2c9 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 24 Jun 2026 15:11:43 +0000 Subject: [PATCH 3/3] Address PR review: harden image URL fetch, latest-turn refs, ep- model fix Responds to the Copilot review on #107 plus a regression the main merge surfaced. image_generation.py: - _fetch_url_as_data_uri is now hardened against SSRF/egress abuse: http(s) schemes only, IP-literal hosts in private/loopback/link-local/non-global ranges rejected, redirects capped, and the body streamed with a 25 MiB cap (declared Content-Length and actual bytes both checked) instead of buffering an unbounded response. It is only ever called with provider-response URLs, never client input; the docstring now states that invariant. - _extract_image_inputs returns only the latest user turn's reference images (each turn replaces the set; a text-only turn clears it) so stale images from earlier edits don't pile up toward the provider's 10-image cap. Text and reference values are coerced/filtered to strings so malformed input can't break downstream JSON serialization. - generate_images filters reference_images to non-empty strings before clamping/forwarding. model_registry.py: - Fix a regression from merging main: Seedream 5.0 Lite (added in #109) is an "ep-" ModelArk deployment endpoint that relied on the old startswith("ep-") auto-detection for the URL/no-n/watermark payload, which the #110 refactor replaced with explicit fields. It was defaulting to b64_json+n. Both ep- models now share _BYTEDANCE_EP_IMAGE_PARAMS and the url/no-n/reference config. Tests: streaming-aware fetch tests + SSRF/size-cap coverage, _extract_image_inputs parsing/latest-turn/robustness tests, and a Seedream 5.0 Lite payload regression guard. 26 tests pass; ruff + mypy clean. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_01BHGhsd68znPWoooHvtLFD5 --- CLAUDE.md | 18 +- tee_gateway/image_generation.py | 163 +++++++++++++----- tee_gateway/model_registry.py | 30 +++- tee_gateway/test/test_image_generation.py | 201 ++++++++++++++++++++-- 4 files changed, 339 insertions(+), 73 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 24507c3..a841445 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -124,11 +124,19 @@ Model name prefixes determine routing: Image generation via xAI (grok-2-image), ByteDance (seedream-4.0, seedream-5.0-lite, seedance-4.5), and Z.ai (glm-image) is served through a -provider `/images/generations` endpoint rather than the chat path, but is -surfaced on `/v1/chat/completions` exactly like Gemini's inline-image models -(images returned out-of-band under the message `images` key). These models are -billed a flat per-image price (see `per_image_price_usd` in -`model_registry.py`), not per token. +provider `/images/generations` endpoint rather than the chat path (see +`image_generation.py`), but is surfaced on `/v1/chat/completions` exactly like +Gemini's inline-image models (images returned out-of-band under the message +`images` key). The client always receives inline bytes: providers that hand back +a hosted URL (Z.ai, Seedance, Seedream 5.0 Lite) are fetched inside the enclave +and inlined as `data:` URIs (the fetch is guarded: http(s) only, non-public IP +hosts rejected, redirects + size capped, and only ever called on provider- +response URLs, never client input). Image-to-image editing sends the prior image +back inline (a `data:` URI / `image_url` content part on the latest user turn), +forwarded to providers that support it via the endpoint's `image` field. +Per-provider request quirks (response format, `n`, size/watermark, reference +support) live in `model_registry.py`. These models are billed a flat per-image +price (see `per_image_price_usd`), not per token. ## Verification Examples diff --git a/tee_gateway/image_generation.py b/tee_gateway/image_generation.py index 1e8b1a3..f2841c5 100644 --- a/tee_gateway/image_generation.py +++ b/tee_gateway/image_generation.py @@ -19,11 +19,13 @@ """ import base64 +import ipaddress import json import logging import time import uuid from typing import Any, List, Optional +from urllib.parse import urlparse import httpx from flask import Response @@ -52,27 +54,85 @@ "zai": "zai_http_client", } +# Bounds on the URL fetch (egress hardening). Provider images are well under the +# size cap; the redirect cap stops a redirect chain from being chased off-host. +_MAX_IMAGE_BYTES = 25 * 1024 * 1024 # 25 MiB +_MAX_REDIRECTS = 3 +_ALLOWED_FETCH_SCHEMES = {"http", "https"} + # Shared keyless client for fetching provider-hosted image URLs into the enclave. _image_fetch_client: Optional[httpx.Client] = None +def _validate_fetch_url(url: str) -> None: + """Guard the image fetch against SSRF / egress abuse. + + Only http(s) URLs are allowed, and IP-literal hosts in private, loopback, + link-local, or otherwise non-global ranges are rejected. Hostnames are not + resolved here — this helper is only ever called with a URL from a provider's + own ``/images/generations`` response (never client input, which is forwarded + to the provider rather than dereferenced) — so the scheme/IP checks plus the + size and redirect caps in ``_fetch_url_as_data_uri`` bound the blast radius if + that trust is ever misplaced. + """ + parsed = urlparse(url) + if parsed.scheme not in _ALLOWED_FETCH_SCHEMES: + raise ValueError(f"Refusing to fetch image from non-http(s) URL: {url!r}") + host = parsed.hostname + if not host: + raise ValueError(f"Refusing to fetch image from URL without a host: {url!r}") + try: + ip = ipaddress.ip_address(host) + except ValueError: + return # hostname, not an IP literal — left to the provider's CDN + if not ip.is_global: + raise ValueError(f"Refusing to fetch image from non-public address: {host}") + + def _fetch_url_as_data_uri(url: str) -> str: - """Fetch a hosted image URL and return it as a ``data:`` URI. + """Fetch a provider-hosted image URL and return it as a ``data:`` URI. + + Only ever called with URLs from a provider's ``/images/generations`` response + — never with client-supplied input, which is forwarded to the provider rather + than dereferenced inside the enclave. Lets URL-returning providers be surfaced + identically to inline-byte ones: the bytes are pulled into the enclave and + ride out inside the OHTTP/TEE envelope, so the client never sees a raw URL. - Lets URL-returning image providers be surfaced identically to inline-byte - ones: the bytes are pulled into the enclave and ride out inside the OHTTP/TEE - envelope, so the client never sees a raw provider URL. + Hardened against egress abuse: http(s) only, non-public IP hosts rejected, + redirects capped, and the body capped at ``_MAX_IMAGE_BYTES`` (streamed so an + oversized response is aborted instead of buffered whole). """ + _validate_fetch_url(url) + global _image_fetch_client if _image_fetch_client is None: _image_fetch_client = httpx.Client( timeout=httpx.Timeout(timeout=120.0, connect=15.0), follow_redirects=True, + max_redirects=_MAX_REDIRECTS, ) - resp = _image_fetch_client.get(url) - resp.raise_for_status() - mime = (resp.headers.get("content-type") or "image/jpeg").split(";", 1)[0].strip() - b64 = base64.b64encode(resp.content).decode("ascii") + + with _image_fetch_client.stream("GET", url) as resp: + resp.raise_for_status() + declared = resp.headers.get("content-length") + if declared and declared.isdigit() and int(declared) > _MAX_IMAGE_BYTES: + raise ValueError( + f"Refusing to fetch image larger than {_MAX_IMAGE_BYTES} bytes" + ) + mime = ( + (resp.headers.get("content-type") or "image/jpeg").split(";", 1)[0].strip() + ) + chunks: list[bytes] = [] + total = 0 + for chunk in resp.iter_bytes(): + total += len(chunk) + if total > _MAX_IMAGE_BYTES: + raise ValueError( + f"Refusing to fetch image larger than {_MAX_IMAGE_BYTES} bytes" + ) + chunks.append(chunk) + + b64 = base64.b64encode(b"".join(chunks)).decode("ascii") return f"data:{mime or 'image/jpeg'};base64,{b64}" @@ -117,10 +177,12 @@ def generate_images( # Image-to-image editing: forward reference images via the ``image`` field (a # single string, or an array for multi-reference edits, up to 10). Without # this a follow-up edit like "add a hat" would ignore the prior image and - # generate a fresh one from the prompt text alone. + # generate a fresh one from the prompt text alone. Filter to non-empty strings + # so a malformed entry can't break JSON serialization of the request payload. if reference_images and cfg.image_supports_reference: - refs = reference_images[:10] - payload["image"] = refs[0] if len(refs) == 1 else refs + refs = [r for r in reference_images if isinstance(r, str) and r][:10] + if refs: + payload["image"] = refs[0] if len(refs) == 1 else refs logger.info( "Generating %d image(s) - Provider: %s, Model: %s", @@ -152,50 +214,67 @@ def _extract_image_inputs(langchain_messages: list) -> tuple[str, list[str]]: transcript, so we join the text of all human messages (system/assistant/tool turns are ignored). A user turn doing an image-to-image edit carries the prior image as an ``image_url``/``image`` content part alongside its text; we collect - those reference images separately so they can be forwarded to the provider — - without them a follow-up like "add a hat" would ignore the previous image and - generate from the prompt text alone. + those reference images so they can be forwarded to the provider — without them + a follow-up like "add a hat" would ignore the previous image and generate from + the prompt text alone. + + Only the *latest* user turn's references are returned: each user turn replaces + the running set (and a text-only turn clears it). Carrying references forward + from earlier edits would forward stale images and burn through the provider's + 10-image cap. Values are coerced/filtered to strings so malformed input can't + break downstream JSON serialization. Returns ``(prompt, reference_images)``. """ from langchain_core.messages import HumanMessage text_parts: list[str] = [] - images: list[str] = [] + reference_images: list[str] = [] for m in langchain_messages: if not isinstance(m, HumanMessage): continue content = m.content + # References reset per user turn so the latest turn is authoritative. + turn_images: list[str] = [] if isinstance(content, str): if content: text_parts.append(content) - continue - if not isinstance(content, list): - continue - for part in content: - if isinstance(part, str): - if part: - text_parts.append(part) - continue - if not isinstance(part, dict): - continue - ptype = part.get("type") - if ptype == "text": - if part.get("text"): - text_parts.append(part["text"]) - elif ptype == "image_url": - image_url = part.get("image_url") - url = image_url.get("url") if isinstance(image_url, dict) else image_url - if url: - images.append(url) - elif ptype == "image": - # Standard LangChain image block: inline base64 + mime, or a url. - if part.get("base64"): - mime = part.get("mime_type") or "image/png" - images.append(f"data:{mime};base64,{part['base64']}") - elif part.get("url"): - images.append(part["url"]) - return "\n".join(text_parts), images + elif isinstance(content, list): + for part in content: + if isinstance(part, str): + if part: + text_parts.append(part) + continue + if not isinstance(part, dict): + continue + ptype = part.get("type") + if ptype == "text": + text = part.get("text") + if text: + text_parts.append(str(text)) + elif ptype == "image_url": + image_url = part.get("image_url") + url = ( + image_url.get("url") + if isinstance(image_url, dict) + else image_url + ) + if isinstance(url, str) and url: + turn_images.append(url) + elif ptype == "image": + # Standard LangChain image block: inline base64 + mime, or url. + b64 = part.get("base64") + url = part.get("url") + if isinstance(b64, str) and b64: + mime = part.get("mime_type") + mime = mime if isinstance(mime, str) and mime else "image/png" + turn_images.append(f"data:{mime};base64,{b64}") + elif isinstance(url, str) and url: + turn_images.append(url) + else: + continue # unrecognized content shape: no text, leave references as-is + reference_images = turn_images + return "\n".join(text_parts), reference_images def _run_image_generation( diff --git a/tee_gateway/model_registry.py b/tee_gateway/model_registry.py index 6dd5164..700d40c 100644 --- a/tee_gateway/model_registry.py +++ b/tee_gateway/model_registry.py @@ -79,6 +79,17 @@ class ModelConfig: "google": Decimal("0.035"), } +# ByteDance ModelArk image *deployment* endpoints (api_name "ep-…", e.g. Seedance +# 4.5, Seedream 5.0 Lite) return the URL response format and require these extra +# params. The gateway fetches the returned URL and inlines the bytes, so the +# client still receives inline bytes. Shared so the two ep- models stay in sync. +_BYTEDANCE_EP_IMAGE_PARAMS: dict[str, Any] = { + "sequential_image_generation": "disabled", + "watermark": False, + "size": "2K", + "stream": False, +} + @unique class SupportedModel(Enum): @@ -401,6 +412,9 @@ class SupportedModel(Enum): image_supports_reference=True, ) # Seedream 5.0 Lite image generation via a ModelArk deployment endpoint. + # Seedream 5.0 Lite image generation via a ModelArk deployment endpoint + # (api_name "ep-…"). Like Seedance it returns hosted URLs (fetched and inlined + # by the gateway) and takes the shared ep- deployment params. Billed per image. SEEDREAM_5_0_LITE = ModelConfig( provider="bytedance", api_name="ep-20260624213657-7zc5n", @@ -408,11 +422,14 @@ class SupportedModel(Enum): output_price_usd=Decimal("0"), image_generation=True, per_image_price_usd=Decimal("0.035"), + image_response_format="url", + image_send_n=False, + image_supports_reference=True, + image_extra_params=_BYTEDANCE_EP_IMAGE_PARAMS, ) # Seedance 4.5 image generation via a ModelArk deployment endpoint. - # Returns hosted URLs (fetched and inlined by the gateway) and needs - # seedance-specific request params (sequential_image_generation, watermark, - # size). Billed per image. + # Returns hosted URLs (fetched and inlined by the gateway) and takes the + # shared ep- deployment params. Billed per image. SEEDANCE_4_5 = ModelConfig( provider="bytedance", api_name="ep-20260624042612-7dxcv", @@ -423,12 +440,7 @@ class SupportedModel(Enum): image_response_format="url", image_send_n=False, image_supports_reference=True, - image_extra_params={ - "sequential_image_generation": "disabled", - "watermark": False, - "size": "2K", - "stream": False, - }, + image_extra_params=_BYTEDANCE_EP_IMAGE_PARAMS, ) # ── Nous Research (Nous Portal, OpenAI-compatible) ────────────────── diff --git a/tee_gateway/test/test_image_generation.py b/tee_gateway/test/test_image_generation.py index 140a909..c6fb573 100644 --- a/tee_gateway/test/test_image_generation.py +++ b/tee_gateway/test/test_image_generation.py @@ -25,6 +25,7 @@ GROK_IMAGE = "grok-2-image" SEEDREAM = "seedream-4.0" +SEEDREAM_5_LITE = "seedream-5.0-lite" SEEDANCE = "seedance-4.5" GLM_IMAGE = "glm-image" @@ -134,6 +135,33 @@ def test_seedance_uses_url_format_and_extra_params(self): self.assertFalse(payload["stream"]) self.assertNotIn("n", payload) + def test_seedream_5_lite_uses_ep_deployment_params(self): + # Seedream 5.0 Lite is an ep- deployment endpoint and must use the same + # URL/no-n/seedance-style payload as Seedance — a regression guard, since + # this used to be auto-detected from the "ep-" api_name prefix and is now + # driven by explicit registry fields. + client = MagicMock() + client.post.return_value = _mock_response([{"url": "https://cdn/img.jpg"}]) + with ( + patch.object(llm_backend, "bytedance_http_client", client), + patch.object( + image_generation, + "_fetch_url_as_data_uri", + return_value="data:image/jpeg;base64,RkVUQ0hFRA==", + ), + ): + images, count = generate_images(SEEDREAM_5_LITE, "a koi pond", n=1) + + self.assertEqual(images, ["data:image/jpeg;base64,RkVUQ0hFRA=="]) + payload = client.post.call_args.kwargs["json"] + self.assertEqual(payload["model"], get_model_config(SEEDREAM_5_LITE).api_name) + self.assertEqual(payload["response_format"], "url") + self.assertEqual(payload["sequential_image_generation"], "disabled") + self.assertFalse(payload["watermark"]) + self.assertEqual(payload["size"], "2K") + self.assertFalse(payload["stream"]) + self.assertNotIn("n", payload) + def test_seedance_forwards_single_reference_image(self): client = MagicMock() client.post.return_value = _mock_response([{"url": "https://cdn/edited.jpg"}]) @@ -205,37 +233,176 @@ def test_uninitialized_client_raises(self): generate_images(GROK_IMAGE, "p", n=1) +def _mock_stream_client(headers: dict, chunks: list[bytes]) -> MagicMock: + """A fake httpx client whose .stream(...) yields a response with these bytes.""" + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.headers = headers + resp.iter_bytes.return_value = chunks + ctx = MagicMock() + ctx.__enter__.return_value = resp + ctx.__exit__.return_value = False + client = MagicMock() + client.stream.return_value = ctx + return client + + class TestFetchUrlAsDataUri(unittest.TestCase): """The hosted-URL fetch that inlines provider images into the enclave.""" def test_encodes_bytes_with_content_type(self): - fetch_client = MagicMock() - resp = MagicMock() - resp.raise_for_status.return_value = None - resp.headers = {"content-type": "image/png; charset=binary"} - resp.content = b"hello" - fetch_client.get.return_value = resp - - with patch.object(image_generation, "_image_fetch_client", fetch_client): + client = _mock_stream_client( + {"content-type": "image/png; charset=binary"}, [b"hel", b"lo"] + ) + with patch.object(image_generation, "_image_fetch_client", client): uri = image_generation._fetch_url_as_data_uri("https://cdn/x.png") # base64("hello") == "aGVsbG8=", mime taken from content-type (params dropped) self.assertEqual(uri, "data:image/png;base64,aGVsbG8=") - fetch_client.get.assert_called_once_with("https://cdn/x.png") + client.stream.assert_called_once_with("GET", "https://cdn/x.png") def test_defaults_mime_when_header_missing(self): - fetch_client = MagicMock() - resp = MagicMock() - resp.raise_for_status.return_value = None - resp.headers = {} - resp.content = b"hello" - fetch_client.get.return_value = resp - - with patch.object(image_generation, "_image_fetch_client", fetch_client): + client = _mock_stream_client({}, [b"hello"]) + with patch.object(image_generation, "_image_fetch_client", client): uri = image_generation._fetch_url_as_data_uri("https://cdn/x") self.assertEqual(uri, "data:image/jpeg;base64,aGVsbG8=") + def test_rejects_non_http_scheme(self): + # Validation happens before any client use, so no client is needed. + with self.assertRaises(ValueError): + image_generation._fetch_url_as_data_uri("ftp://cdn/x.png") + with self.assertRaises(ValueError): + image_generation._fetch_url_as_data_uri("file:///etc/passwd") + + def test_rejects_private_and_loopback_ip_hosts(self): + for url in ( + "http://127.0.0.1/x.png", + "http://169.254.169.254/latest/meta-data", # cloud metadata + "http://10.0.0.5/x.png", + "http://[::1]/x.png", + ): + with self.subTest(url=url): + with self.assertRaises(ValueError): + image_generation._fetch_url_as_data_uri(url) + + def test_rejects_body_exceeding_size_cap(self): + client = _mock_stream_client( + {"content-type": "image/png"}, [b"x" * 4, b"x" * 4] + ) + with ( + patch.object(image_generation, "_image_fetch_client", client), + patch.object(image_generation, "_MAX_IMAGE_BYTES", 5), + ): + with self.assertRaises(ValueError): + image_generation._fetch_url_as_data_uri("https://cdn/big.png") + + def test_rejects_declared_content_length_over_cap(self): + client = _mock_stream_client( + {"content-type": "image/png", "content-length": "999"}, [b"x"] + ) + with ( + patch.object(image_generation, "_image_fetch_client", client), + patch.object(image_generation, "_MAX_IMAGE_BYTES", 5), + ): + with self.assertRaises(ValueError): + image_generation._fetch_url_as_data_uri("https://cdn/big.png") + + +class TestExtractImageInputs(unittest.TestCase): + """Prompt + reference-image extraction from the user turns.""" + + @staticmethod + def _human(content): + from langchain_core.messages import HumanMessage + + return HumanMessage(content=content) + + def test_joins_text_across_turns_no_references(self): + msgs = [self._human("a red cube"), self._human("make it blue")] + prompt, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(prompt, "a red cube\nmake it blue") + self.assertEqual(refs, []) + + def test_mixed_text_and_image_does_not_splice_base64_into_prompt(self): + # An image-to-image edit turn: text + an attached reference image. The + # base64 blob must never leak into the prompt text. + data_uri = "data:image/png;base64,QUJD" + msgs = [ + self._human( + [ + {"type": "text", "text": "add a hat"}, + {"type": "image_url", "image_url": {"url": data_uri}}, + ] + ) + ] + prompt, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(prompt, "add a hat") + self.assertEqual(refs, [data_uri]) + + def test_extracts_langchain_image_block_with_base64_and_url(self): + msgs = [ + self._human( + [ + {"type": "image", "base64": "QUJD", "mime_type": "image/webp"}, + {"type": "image", "url": "https://cdn/ref.jpg"}, + ] + ) + ] + _, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(refs, ["data:image/webp;base64,QUJD", "https://cdn/ref.jpg"]) + + def test_only_latest_turn_references_are_returned(self): + # An earlier edit turn carried an image; the latest turn carries a new + # one. Only the latest turn's reference should be forwarded. + msgs = [ + self._human( + [ + {"type": "text", "text": "first"}, + {"type": "image_url", "image_url": {"url": "https://cdn/old.jpg"}}, + ] + ), + self._human( + [ + {"type": "text", "text": "second"}, + {"type": "image_url", "image_url": {"url": "https://cdn/new.jpg"}}, + ] + ), + ] + prompt, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(prompt, "first\nsecond") + self.assertEqual(refs, ["https://cdn/new.jpg"]) + + def test_text_only_latest_turn_clears_stale_references(self): + # Edit turn with an image, then a plain text follow-up: the text-only + # latest turn means a fresh generation, so no stale reference rides along. + msgs = [ + self._human( + [ + {"type": "text", "text": "first"}, + {"type": "image_url", "image_url": {"url": "https://cdn/old.jpg"}}, + ] + ), + self._human("just text now"), + ] + _, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(refs, []) + + def test_malformed_image_parts_are_ignored(self): + msgs = [ + self._human( + [ + {"type": "text", "text": "p"}, + {"type": "image_url", "image_url": {"url": None}}, + {"type": "image_url", "image_url": 123}, + {"type": "image", "base64": None}, + ] + ) + ] + prompt, refs = image_generation._extract_image_inputs(msgs) + self.assertEqual(prompt, "p") + self.assertEqual(refs, []) + class TestPerImageBilling(unittest.TestCase): """Flat per-image pricing, independent of token usage."""