Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The repo must provide a stable AWS Nitro PCR when the code doesn't change in ord
├── tee_gateway/ # Main application package (Flask/connexion)
│ ├── __main__.py # Entry point: app factory, x402 middleware setup, key injection
│ ├── llm_backend.py # LLM provider routing via LangChain, HTTP client management
│ ├── image_generation.py # Endpoint-based image gen (/images/generations): request shaping, URL→inline-bytes, signed responses
│ ├── tee_manager.py # TEE key generation, nitriding registration, response signing
│ ├── model_registry.py # Model config and per-token pricing
│ ├── definitions.py # On-chain addresses, network IDs, payment amounts
Expand Down Expand Up @@ -123,10 +124,16 @@ Model name prefixes determine routing:

Image generation via xAI (grok-2-image), ByteDance (seedream-4.0, seedance-4.5), and Z.ai
(glm-image) is served through a provider `/images/generations` endpoint rather
than the chat path, but is surfaced on `/v1/chat/completions` exactly like
Gemini's inline-image models (images returned out-of-band under the message
`images` key). These models are billed a flat per-image price (see
`per_image_price_usd` in `model_registry.py`), not per token.
than the chat path (see `image_generation.py`), but is surfaced on
`/v1/chat/completions` exactly like Gemini's inline-image models (images returned
out-of-band under the message `images` key). The client always receives inline
bytes: hosted URLs returned by some providers (Z.ai, Seedance) are fetched inside
the enclave and inlined as `data:` URIs. Image-to-image editing sends the prior
image back inline (a `data:` URI / `image_url` content part on the user turn),
forwarded to providers that support it via the endpoint's `image` field.
Per-provider request quirks (response format, `n`, size/watermark, reference
support) live in `model_registry.py`. These models are billed a flat per-image
price (see `per_image_price_usd`), not per token.

## Verification Examples

Expand Down
193 changes: 6 additions & 187 deletions tee_gateway/controllers/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@
extract_web_search_count,
convert_messages,
extract_usage,
generate_images,
validate_attachments,
AttachmentValidationError,
canonical_user_content,
)
from tee_gateway.image_generation import (
create_image_generation_response,
create_image_generation_streaming_response,
)
from tee_gateway.model_registry import get_model_config
from tee_gateway.pricing import compute_session_cost

Expand Down Expand Up @@ -77,77 +80,6 @@ def _split_text_and_images(content: Any) -> tuple[str, list[str]]:
return ("".join(text_parts), images)


def _extract_image_prompt(langchain_messages: list) -> str:
    """Build one newline-joined text prompt from the human turns.

    Only ``HumanMessage`` instances contribute; every other message type is
    skipped. For each human message:

    * string content is used as-is when non-empty;
    * list (multimodal) content contributes the ``text`` value of each dict
      part whose ``type`` is ``"text"``, plus any non-empty bare-string parts.

    Non-text parts (for example image parts) are never stringified into the
    prompt, which keeps inline image payloads out of the text sent to the
    image endpoint.

    Args:
        langchain_messages: Converted chat messages to scan.

    Returns:
        The collected text fragments joined with ``"\\n"``; an empty string
        when nothing qualifies.
    """
    from langchain_core.messages import HumanMessage

    def _text_fragments(content):
        """Yield the non-empty text fragments held by one message's content."""
        if isinstance(content, str):
            if content:
                yield content
            return
        if not isinstance(content, list):
            return
        for piece in content:
            if isinstance(piece, dict):
                if piece.get("type") == "text" and piece.get("text"):
                    yield piece["text"]
            elif isinstance(piece, str) and piece:
                yield piece

    return "\n".join(
        fragment
        for message in langchain_messages
        if isinstance(message, HumanMessage)
        for fragment in _text_fragments(message.content)
    )


def _extract_reference_images(langchain_messages: list) -> list[str]:
    """Gather image references attached to the human turns.

    Only ``HumanMessage`` instances whose content is a list are inspected, and
    within them only dict parts. Two part shapes are recognised:

    * ``type == "image_url"``: the reference is ``image_url["url"]`` when
      ``image_url`` is a dict, otherwise the ``image_url`` value itself.
    * ``type == "image"``: a ``base64`` payload is wrapped into a ``data:``
      URI (mime type from ``mime_type``, falling back to ``image/png``);
      failing that, a ``url`` value is used.

    Empty or missing references are dropped. Results keep message order and,
    within a message, part order.

    Args:
        langchain_messages: Converted chat messages to scan.

    Returns:
        The collected URLs / data URIs, possibly empty.
    """
    from langchain_core.messages import HumanMessage

    def _reference_of(block):
        """Return the reference carried by one content part, or None."""
        if not isinstance(block, dict):
            return None
        kind = block.get("type")
        if kind == "image_url":
            target = block.get("image_url")
            return target.get("url") if isinstance(target, dict) else target
        if kind == "image":
            # Inline base64 wins over a url when both are present.
            if block.get("base64"):
                mime = block.get("mime_type") or "image/png"
                return f"data:{mime};base64,{block['base64']}"
            if block.get("url"):
                return block["url"]
        return None

    references: list[str] = []
    for message in langchain_messages:
        if not isinstance(message, HumanMessage):
            continue
        content = message.content
        if isinstance(content, list):
            references.extend(ref for ref in map(_reference_of, content) if ref)
    return references


def create_chat_completion(body):
"""Create a chat completion (streaming or non-streaming)."""
if not connexion.request.is_json:
Expand Down Expand Up @@ -282,119 +214,6 @@ def _messages_contain_json_word(messages: list) -> bool:
return False


def _create_image_generation_response(
    chat_request: CreateChatCompletionRequest, request_bytes: bytes
):
    """Serve a non-streaming image-generation request as a chat completion.

    The prompt and reference images are derived from the converted messages
    and handed to ``generate_images``. Returned images are attached to the
    assistant message under ``images`` (the key is omitted when none came
    back) while ``content`` stays an empty string.

    The TEE hash is computed from ``request_bytes``, an empty output string
    and the current timestamp, then signed with the enclave keys. Token usage
    is reported as all zeros; ``image_count`` is passed to
    ``compute_session_cost`` and, when a cost is returned, it is attached
    under ``opengradient``.

    Args:
        chat_request: Parsed chat-completion request.
        request_bytes: Request bytes fed into the TEE hash.

    Returns:
        The response as a plain dict.
    """
    messages = convert_messages(chat_request.messages)
    text_prompt = _extract_image_prompt(messages)
    references = _extract_reference_images(messages)
    images, image_count = generate_images(
        chat_request.model,
        text_prompt,
        n=chat_request.n or 1,
        reference_images=references,
    )

    assistant_message: dict[str, Any] = {"role": "assistant", "content": ""}
    if images:
        assistant_message["images"] = images

    # Empty string as the output argument: there is no text completion here.
    signed_at = int(time.time())
    digest, request_hash_hex, output_hash_hex = compute_tee_msg_hash(
        request_bytes, "", signed_at
    )
    keys = get_tee_keys()
    signature = keys.sign_data(digest)

    zero_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    payload: dict[str, Any] = {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": signed_at,
        "model": chat_request.model,
        "choices": [
            {"index": 0, "message": assistant_message, "finish_reason": "stop"}
        ],
        "tee_signature": signature,
        "tee_request_hash": request_hash_hex,
        "tee_output_hash": output_hash_hex,
        "tee_timestamp": signed_at,
        "tee_id": f"0x{keys.get_tee_id()}",
        "usage": zero_usage,
    }

    session_cost = compute_session_cost(
        chat_request.model, zero_usage, image_count=image_count
    )
    if session_cost is not None:
        payload["opengradient"] = session_cost.model_dump(mode="json")

    # Result is discarded; presumably this is a schema check that raises on a
    # malformed payload -- confirm against the model class.
    CreateChatCompletionResponse.from_dict(payload)
    return payload


def _create_image_generation_streaming_response(
    chat_request: CreateChatCompletionRequest, request_bytes: bytes
):
    """Serve an image-generation request over server-sent events.

    Nothing is streamed incrementally: the images are generated in one call
    and delivered on a single final ``data:`` frame, followed by
    ``data: [DONE]``. The frame carries the TEE signature fields, zeroed token
    usage, the images (only when some were returned) and, when available, the
    cost computed from ``image_count``.

    Any exception raised while producing the frame is logged and reported to
    the client as one error frame; no ``[DONE]`` frame follows in that case.

    Args:
        chat_request: Parsed chat-completion request.
        request_bytes: Request bytes fed into the TEE hash.

    Returns:
        A ``text/event-stream`` ``Response`` wrapping the frame generator.
    """
    # Looked up before the generator starts, i.e. when this function is called.
    tee_keys = get_tee_keys()

    def generate():
        try:
            messages = convert_messages(chat_request.messages)
            text_prompt = _extract_image_prompt(messages)
            references = _extract_reference_images(messages)
            images, image_count = generate_images(
                chat_request.model,
                text_prompt,
                n=chat_request.n or 1,
                reference_images=references,
            )

            # Empty string as the output argument: no text is produced.
            signed_at = int(time.time())
            digest, request_hash_hex, output_hash_hex = compute_tee_msg_hash(
                request_bytes, "", signed_at
            )
            signature = tee_keys.sign_data(digest)

            frame: dict[str, Any] = {
                "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}],
                "model": chat_request.model,
                "tee_signature": signature,
                "tee_timestamp": signed_at,
                "tee_request_hash": request_hash_hex,
                "tee_output_hash": output_hash_hex,
                "tee_id": f"0x{tee_keys.get_tee_id()}",
            }
            # Images are attached to the frame but are not an input to the
            # hash computed above.
            if images:
                frame["images"] = images

            zero_usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }
            frame["usage"] = zero_usage
            session_cost = compute_session_cost(
                chat_request.model, zero_usage, image_count=image_count
            )
            if session_cost is not None:
                frame["opengradient"] = session_cost.model_dump(mode="json")

            yield "data: " + json.dumps(frame) + "\n\n"
            yield "data: [DONE]\n\n"
        except Exception as exc:
            logger.error(
                f"Image generation streaming error: {str(exc)}", exc_info=True
            )
            # Only the exception type name is sent to the client, not its text.
            error_frame = {
                "error": "Stream processing failed",
                "exception_type": type(exc).__name__,
            }
            yield "data: " + json.dumps(error_frame) + "\n\n"

    sse_headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
    return Response(
        generate(),
        mimetype="text/event-stream",
        headers=sse_headers,
    )


def _create_non_streaming_response(chat_request: CreateChatCompletionRequest):
"""Handle non-streaming chat completion via direct LangChain call."""
try:
Expand All @@ -414,7 +233,7 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest):
# surfaced through this same endpoint, returning images out-of-band just
# like Gemini's inline-image models.
if cfg.image_generation:
return _create_image_generation_response(chat_request, request_bytes)
return create_image_generation_response(chat_request, request_bytes)

model = get_chat_model_cached(
model=chat_request.model,
Expand Down Expand Up @@ -571,7 +390,7 @@ def _create_streaming_response(chat_request: CreateChatCompletionRequest):
# Image-generation models (xAI Grok, ByteDance Seedream) use a dedicated
# images endpoint; handle them without building a chat model.
if get_model_config(chat_request.model).image_generation:
return _create_image_generation_streaming_response(
return create_image_generation_streaming_response(
chat_request, request_bytes
)

Expand Down
Loading
Loading