Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The repo must provide a stable AWS Nitro PCR when the code doesn't change in ord
├── tee_gateway/ # Main application package (Flask/connexion)
│ ├── __main__.py # Entry point: app factory, x402 middleware setup, key injection
│ ├── llm_backend.py # LLM provider routing via LangChain, HTTP client management
│ ├── image_generation.py # Endpoint-based image gen (/images/generations): request shaping, URL→inline-bytes, signed responses
│ ├── tee_manager.py # TEE key generation, nitriding registration, response signing
│ ├── model_registry.py # Model config and per-token pricing
│ ├── definitions.py # On-chain addresses, network IDs, payment amounts
Expand Down Expand Up @@ -123,11 +124,19 @@ Model name prefixes determine routing:

Image generation via xAI (grok-2-image), ByteDance (seedream-4.0,
seedream-5.0-lite, seedance-4.5), and Z.ai (glm-image) is served through a
provider `/images/generations` endpoint rather than the chat path, but is
surfaced on `/v1/chat/completions` exactly like Gemini's inline-image models
(images returned out-of-band under the message `images` key). These models are
billed a flat per-image price (see `per_image_price_usd` in
`model_registry.py`), not per token.
provider `/images/generations` endpoint rather than the chat path (see
`image_generation.py`), but is surfaced on `/v1/chat/completions` exactly like
Gemini's inline-image models (images returned out-of-band under the message
`images` key). The client always receives inline bytes: when a provider hands
back a hosted URL (Z.ai, Seedance, Seedream 5.0 Lite), the image is fetched
inside the enclave and inlined as a `data:` URI. The fetch is guarded: http(s)
only, non-public IP hosts rejected, redirects and size capped, and it is only
ever called on provider-response URLs, never on client input. For image-to-image
editing, the client sends the prior image back inline (a `data:` URI /
`image_url` content part on the latest user turn), and it is forwarded to
providers that support it via the endpoint's `image` field.
Per-provider request quirks (response format, `n`, size/watermark, reference
support) live in `model_registry.py`. These models are billed a flat per-image
price (see `per_image_price_usd`), not per token.

## Verification Examples

Expand Down
131 changes: 6 additions & 125 deletions tee_gateway/controllers/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@
extract_web_search_count,
convert_messages,
extract_usage,
generate_images,
validate_attachments,
AttachmentValidationError,
canonical_user_content,
)
from tee_gateway.image_generation import (
create_image_generation_response,
create_image_generation_streaming_response,
)
from tee_gateway.model_registry import get_model_config
from tee_gateway.pricing import compute_session_cost

Expand Down Expand Up @@ -77,23 +80,6 @@ def _split_text_and_images(content: Any) -> tuple[str, list[str]]:
return ("".join(text_parts), images)


def _extract_image_prompt(langchain_messages: list) -> str:
    """Collapse the user-turn text into a single image-generation prompt.

    Image-generation models (xAI Grok, ByteDance Seedream) take a single text
    prompt rather than a chat transcript, so the text of every human message
    is joined with newlines. System/assistant/tool turns are ignored.

    Multi-part content (a list of parts) contributes only its text: plain
    string parts and ``{"type": "text", "text": ...}`` parts. Non-text parts
    such as ``image_url`` are skipped, so an attached image (possibly a large
    ``data:`` URI) never ends up stringified inside the prompt. Any other
    non-string content still falls back to ``str(content)``.

    Args:
        langchain_messages: LangChain message objects for the conversation.

    Returns:
        The non-empty human-message texts joined by ``"\\n"``; ``""`` when
        there is no human text at all.
    """
    from langchain_core.messages import HumanMessage

    def _text_of(content: Any) -> str:
        # Plain-string content is used verbatim (unchanged behavior).
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces: list[str] = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    text = part.get("text")
                    if isinstance(text, str):
                        pieces.append(text)
            # Parts are concatenated without a separator, matching how
            # _split_text_and_images joins text parts in this module.
            return "".join(pieces)
        return str(content)

    texts = (
        _text_of(m.content)
        for m in langchain_messages
        if isinstance(m, HumanMessage)
    )
    return "\n".join(t for t in texts if t)


def create_chat_completion(body):
"""Create a chat completion (streaming or non-streaming)."""
if not connexion.request.is_json:
Expand Down Expand Up @@ -228,111 +214,6 @@ def _messages_contain_json_word(messages: list) -> bool:
return False


def _create_image_generation_response(
    chat_request: CreateChatCompletionRequest, request_bytes: bytes
):
    """Build the non-streaming chat-completion payload for an image model.

    The generated images are attached to the assistant message under the
    ``images`` key (only when at least one was produced), the same way
    Gemini's inline-image models are surfaced. The message text is empty, so
    the TEE hash/signature is computed over the request bytes and an empty
    output string; the images themselves are not part of the signed hash.
    Usage is reported as zero tokens and the image count is handed to
    ``compute_session_cost`` for pricing.

    Args:
        chat_request: The parsed chat-completion request.
        request_bytes: Raw request bytes fed into the TEE message hash.

    Returns:
        The OpenAI-style response dict, including the ``tee_*`` fields.
    """
    prompt = _extract_image_prompt(convert_messages(chat_request.messages))
    images, image_count = generate_images(
        chat_request.model, prompt, n=chat_request.n or 1
    )

    # "images" is only present when the provider actually returned something.
    assistant_message: dict[str, Any] = {
        "role": "assistant",
        "content": "",
        **({"images": images} if images else {}),
    }

    created_at = int(time.time())
    digest, request_hash_hex, output_hash_hex = compute_tee_msg_hash(
        request_bytes, "", created_at
    )
    tee_keys = get_tee_keys()
    signature = tee_keys.sign_data(digest)

    # Image generation consumes no tokens; pricing is driven by image_count.
    zero_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    payload: dict[str, Any] = {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": created_at,
        "model": chat_request.model,
        "choices": [
            {"index": 0, "message": assistant_message, "finish_reason": "stop"}
        ],
        "tee_signature": signature,
        "tee_request_hash": request_hash_hex,
        "tee_output_hash": output_hash_hex,
        "tee_timestamp": created_at,
        "tee_id": f"0x{tee_keys.get_tee_id()}",
        "usage": zero_usage,
    }

    session_cost = compute_session_cost(
        chat_request.model, zero_usage, image_count=image_count
    )
    if session_cost is not None:
        payload["opengradient"] = session_cost.model_dump(mode="json")

    # Result is discarded; presumably this call only validates the payload
    # against the response model before it is returned.
    CreateChatCompletionResponse.from_dict(payload)
    return payload


def _create_image_generation_streaming_response(
    chat_request: CreateChatCompletionRequest, request_bytes: bytes
):
    """Serve an image-generation request over SSE.

    Image generation is not a token stream, so the provider is invoked once
    and the whole result is emitted on a single final ``data:`` frame followed
    by ``data: [DONE]`` (mirrors the Gemini path). If anything raises while
    producing the frame, the error is logged and one error frame is emitted
    instead.

    Args:
        chat_request: The parsed chat-completion request.
        request_bytes: Raw request bytes fed into the TEE message hash.

    Returns:
        A ``text/event-stream`` ``Response`` wrapping the frame generator.
    """
    # Fetched up front, before the response body starts being generated.
    tee_keys = get_tee_keys()

    def _final_frame() -> dict[str, Any]:
        """Generate the images and assemble the signed final SSE payload."""
        prompt = _extract_image_prompt(convert_messages(chat_request.messages))
        images, image_count = generate_images(
            chat_request.model, prompt, n=chat_request.n or 1
        )

        signed_at = int(time.time())
        digest, request_hash_hex, output_hash_hex = compute_tee_msg_hash(
            request_bytes, "", signed_at
        )
        signature = tee_keys.sign_data(digest)

        frame: dict[str, Any] = {
            "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}],
            "model": chat_request.model,
            "tee_signature": signature,
            "tee_timestamp": signed_at,
            "tee_request_hash": request_hash_hex,
            "tee_output_hash": output_hash_hex,
            "tee_id": f"0x{tee_keys.get_tee_id()}",
        }
        # Images travel out-of-band on the final frame; not part of the hash.
        if images:
            frame["images"] = images

        zero_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        frame["usage"] = zero_usage
        session_cost = compute_session_cost(
            chat_request.model, zero_usage, image_count=image_count
        )
        if session_cost is not None:
            frame["opengradient"] = session_cost.model_dump(mode="json")
        return frame

    def event_stream():
        try:
            yield f"data: {json.dumps(_final_frame())}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"Image generation streaming error: {str(e)}", exc_info=True)
            failure = {
                "error": "Stream processing failed",
                "exception_type": type(e).__name__,
            }
            yield f"data: {json.dumps(failure)}\n\n"

    return Response(
        event_stream(),
        mimetype="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )


def _create_non_streaming_response(chat_request: CreateChatCompletionRequest):
"""Handle non-streaming chat completion via direct LangChain call."""
try:
Expand All @@ -352,7 +233,7 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest):
# surfaced through this same endpoint, returning images out-of-band just
# like Gemini's inline-image models.
if cfg.image_generation:
return _create_image_generation_response(chat_request, request_bytes)
return create_image_generation_response(chat_request, request_bytes)

model = get_chat_model_cached(
model=chat_request.model,
Expand Down Expand Up @@ -509,7 +390,7 @@ def _create_streaming_response(chat_request: CreateChatCompletionRequest):
# Image-generation models (xAI Grok, ByteDance Seedream) use a dedicated
# images endpoint; handle them without building a chat model.
if get_model_config(chat_request.model).image_generation:
return _create_image_generation_streaming_response(
return create_image_generation_streaming_response(
chat_request, request_bytes
)

Expand Down
Loading
Loading