Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ Gemini's inline-image models (images returned out-of-band under the message
`images` key). These models are billed a flat per-image price (see
`per_image_price_usd` in `model_registry.py`), not per token.

Image output resolution is selected with a single `quality` request field —
`low` / `medium` / `high` — which each model maps to its own native size via
`image_quality_sizes` in `model_registry.py` (Seedream/Seedance/Gemini image:
`1K`/`2K`/`4K`; Z.ai GLM-Image: `1024`/`1280`/`2048` px). Models with no
resolution control (xAI Grok, Gemini 2.5 Flash Image) ignore it; `medium`
mirrors each model's previous default, so omitting `quality` is unchanged.

## Verification Examples

- `examples/verify_attestation.py` — Validates AWS Nitro attestation documents against the root CA
Expand Down
25 changes: 22 additions & 3 deletions tee_gateway/controllers/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
AttachmentValidationError,
canonical_user_content,
)
from tee_gateway.model_registry import get_model_config
from tee_gateway.model_registry import get_model_config, VALID_IMAGE_QUALITIES
from tee_gateway.pricing import compute_session_cost

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -112,6 +112,17 @@ def create_chat_completion(body):
except AttachmentValidationError as e:
return {"error": "Invalid attachment", "message": str(e)}, 400

# Validate the optional image-quality tier up front for a clean 400. Models
# without resolution control silently ignore it (handled downstream).
if chat_request.quality is not None and (
not isinstance(chat_request.quality, str)
or chat_request.quality.strip().lower() not in VALID_IMAGE_QUALITIES
):
return {
"error": "Invalid quality",
"message": (f"quality must be one of {', '.join(VALID_IMAGE_QUALITIES)}."),
}, 400

if chat_request.stream:
return _create_streaming_response(chat_request)
else:
Expand Down Expand Up @@ -241,7 +252,7 @@ def _create_image_generation_response(
langchain_messages = convert_messages(chat_request.messages)
prompt = _extract_image_prompt(langchain_messages)
images, image_count = generate_images(
chat_request.model, prompt, n=chat_request.n or 1
chat_request.model, prompt, n=chat_request.n or 1, quality=chat_request.quality
)

message_dict: dict[str, Any] = {"role": "assistant", "content": ""}
Expand Down Expand Up @@ -290,7 +301,10 @@ def generate():
langchain_messages = convert_messages(chat_request.messages)
prompt = _extract_image_prompt(langchain_messages)
images, image_count = generate_images(
chat_request.model, prompt, n=chat_request.n or 1
chat_request.model,
prompt,
n=chat_request.n or 1,
quality=chat_request.quality,
)

timestamp = int(time.time())
Expand Down Expand Up @@ -361,6 +375,7 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest):
else 0.0,
max_tokens=chat_request.max_tokens or 4096,
web_search=bool(chat_request.web_search),
image_quality=chat_request.quality,
)

# Bind user tools and/or the native web search tool if requested.
Expand Down Expand Up @@ -520,6 +535,7 @@ def _create_streaming_response(chat_request: CreateChatCompletionRequest):
else 0.0,
max_tokens=chat_request.max_tokens or 4096,
web_search=bool(chat_request.web_search),
image_quality=chat_request.quality,
)

# Bind user tools and/or the native web search tool if requested.
Expand Down Expand Up @@ -1002,6 +1018,8 @@ def _chat_request_to_dict(chat_request: CreateChatCompletionRequest) -> dict:
d["response_format"] = _normalize_response_format(chat_request.response_format)
if chat_request.web_search:
d["web_search"] = True
if chat_request.quality:
d["quality"] = chat_request.quality
return d


Expand All @@ -1025,6 +1043,7 @@ def _parse_chat_request(chat_request_dict: dict) -> CreateChatCompletionRequest:
tool_choice=chat_request_dict.get("tool_choice"),
user=chat_request_dict.get("user"),
web_search=chat_request_dict.get("web_search", False),
quality=chat_request_dict.get("quality"),
)


Expand Down
44 changes: 37 additions & 7 deletions tee_gateway/llm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from langchain_xai import ChatXAI

from tee_gateway.config import ProviderConfig
from tee_gateway.model_registry import get_model_config
from tee_gateway.model_registry import get_model_config, resolve_image_size

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -150,17 +150,26 @@ def get_provider_from_model(model: str) -> str:

@lru_cache(maxsize=64)
def get_chat_model_cached(
model: str, temperature: float, max_tokens: int, web_search: bool = False
model: str,
temperature: float,
max_tokens: int,
web_search: bool = False,
image_quality: Optional[str] = None,
):
"""Get cached chat model instance using the injected ProviderConfig.

Models are cached by (model, temperature, max_tokens, web_search) tuple.
Cache is cleared by set_provider_config() after key injection.
Models are cached by (model, temperature, max_tokens, web_search,
image_quality) tuple. Cache is cleared by set_provider_config() after key
injection.

When ``web_search`` is True, provider-specific native web search is enabled.
Some providers (OpenAI, xAI) require search configuration at construction
time; others (Anthropic, Google) enable it by binding a tool — see
``get_web_search_tool``. Providers without native web search ignore the flag.

``image_quality`` ("low" | "medium" | "high") selects the output resolution
for inline image-output models (Gemini "nano banana"); it is ignored by
text models and by image models without resolution control.
"""
config = _provider_config
if config is None:
Expand All @@ -183,12 +192,18 @@ def get_chat_model_cached(
# thinking budget; ask for both TEXT and IMAGE modalities so the model
# may caption alongside the generated image.
if cfg.image_output:
# Map the requested quality tier to the model's resolution. Models
# without resolution control (e.g. Gemini 2.5 Flash Image) return
# None and use the provider default.
image_size = resolve_image_size(model, image_quality)
image_config = {"image_size": image_size} if image_size else None
return ChatGoogleGenerativeAI(
model=api_name,
google_api_key=config.google_api_key,
temperature=effective_temp,
max_output_tokens=max_tokens,
response_modalities=[Modality.TEXT, Modality.IMAGE],
image_config=image_config,
)

return ChatGoogleGenerativeAI(
Expand Down Expand Up @@ -333,7 +348,9 @@ def get_chat_model_cached(
_IMAGE_GENERATION_PATH = "/images/generations"


def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int]:
def generate_images(
model: str, prompt: str, n: int = 1, quality: Optional[str] = None
) -> tuple[list[str], int]:
"""Generate images via a provider's OpenAI-compatible images endpoint.

Unlike Gemini's inline-image chat models, xAI (Aurora), ByteDance
Expand All @@ -342,13 +359,22 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int
bytes ride inline inside the OHTTP/TEE envelope for providers that support
it. Z.ai returns temporary image URLs only.

``quality`` ("low" | "medium" | "high") selects the output resolution where
the model supports it (Seedream/Seedance: 1K/2K/4K, Z.ai: pixel dimensions).
xAI Grok exposes no resolution control and ignores the option.

Returns ``(data_uris, image_count)`` where each entry is a ``data:`` URI. The
count is used for per-image billing. Falls back to provider-returned URLs if
a provider ignores ``b64_json``.
"""
cfg = get_model_config(model)
provider = cfg.provider

# Map the requested quality tier to a provider-specific size string. None
# means "no quality requested" or "model has no resolution control"; in both
# cases the provider default below is left in place.
size_override = resolve_image_size(model, quality)

if provider == "x-ai":
client = xai_http_client
elif provider == "bytedance":
Expand All @@ -372,17 +398,21 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int
"prompt": prompt,
}
if provider == "zai":
payload["size"] = "1280x1280"
payload["size"] = size_override or "1280x1280"
elif provider == "bytedance" and cfg.api_name.startswith("ep-"):
# Seedance deployment endpoints use URL format and require extra params.
payload["response_format"] = "url"
payload["sequential_image_generation"] = "disabled"
payload["watermark"] = False
payload["size"] = "2K"
payload["size"] = size_override or "2K"
payload["stream"] = False
else:
payload["n"] = count
payload["response_format"] = "b64_json"
# Seedream accepts a size preset; xAI Grok has no resolution control, so
# size_override is None there and the provider default is used.
if size_override:
payload["size"] = size_override

logger.info(
"Generating %d image(s) - Provider: %s, Model: %s",
Expand Down
50 changes: 50 additions & 0 deletions tee_gateway/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ class ModelConfig:
# means "use the provider default" (see WEB_SEARCH_PRICE_USD_BY_PROVIDER);
# set an explicit value here to override a single model's web-search price.
web_search_price_usd: Optional[Decimal] = None
# For image models (``image_generation`` endpoint or ``image_output`` inline),
# maps the simple quality tier — ``"low"`` | ``"medium"`` | ``"high"`` — to the
# provider-specific size value sent to the API (e.g. "1K"/"2K"/"4K" for
# ByteDance/Gemini, "1024x1024" pixel dims for Z.ai). ``None`` means the model
# exposes no resolution control and the ``quality`` request option is ignored
# (e.g. xAI Grok and Gemini 2.5 Flash Image, which only emit one resolution).
# The ``"medium"`` tier mirrors each model's previous default size, so omitting
# ``quality`` and passing ``"medium"`` produce the same request.
image_quality_sizes: Optional[dict[str, str]] = None


# Default per-search USD price charged when a model uses native web search.
Expand Down Expand Up @@ -266,6 +275,8 @@ class SupportedModel(Enum):
output_price_usd=Decimal("0.000003"),
image_output=True,
image_output_price_usd=Decimal("0.00006"),
# Nano Banana 2 supports 1K/2K/4K via image_config.image_size (default 1K).
image_quality_sizes={"low": "1K", "medium": "2K", "high": "4K"},
)
GEMINI_3_5_FLASH = ModelConfig(
provider="google",
Expand Down Expand Up @@ -375,6 +386,9 @@ class SupportedModel(Enum):
output_price_usd=Decimal("0"),
image_generation=True,
per_image_price_usd=Decimal("0.03"),
# ModelArk accepts 1K/2K/4K size presets (one dimension fixed to
# 1024/2048/4096); pixels stay within the documented 1280x720..4096x4096.
image_quality_sizes={"low": "1K", "medium": "2K", "high": "4K"},
)
# Seedance 4.5 image generation via a ModelArk deployment endpoint.
# Uses URL response format and seedance-specific request params
Expand All @@ -386,6 +400,9 @@ class SupportedModel(Enum):
output_price_usd=Decimal("0"),
image_generation=True,
per_image_price_usd=Decimal("0.05"),
# Seedance accepts the same 1K/2K/4K size presets as Seedream; medium (2K)
# matches the endpoint's previous hardcoded default.
image_quality_sizes={"low": "1K", "medium": "2K", "high": "4K"},
)

# ── Nous Research (Nous Portal, OpenAI-compatible) ──────────────────
Expand Down Expand Up @@ -425,6 +442,13 @@ class SupportedModel(Enum):
output_price_usd=Decimal("0"),
image_generation=True,
per_image_price_usd=Decimal("0.015"),
# Z.ai takes explicit pixel dimensions (each 512..2048, multiple of 32);
# medium (1280x1280) matches the endpoint's previous hardcoded default.
image_quality_sizes={
"low": "1024x1024",
"medium": "1280x1280",
"high": "2048x2048",
},
)

# ── Legacy models (not in current SDK — retained for older SDK versions) ──
Expand Down Expand Up @@ -568,3 +592,29 @@ def get_web_search_price_usd(model: str) -> Decimal:
def provider_supports_web_search(provider: str) -> bool:
"""Whether the given provider has native web search the gateway can enable."""
return provider in WEB_SEARCH_PRICE_USD_BY_PROVIDER


# The simple quality tiers a caller may request for image generation. Each model
# maps these to its own provider-specific resolution (see `image_quality_sizes`).
VALID_IMAGE_QUALITIES = ("low", "medium", "high")


def resolve_image_size(model: str, quality: Optional[str]) -> Optional[str]:
"""Map a requested ``quality`` tier to a model's provider-specific size string.

Returns ``None`` when no quality is requested or the model exposes no
resolution control, in which case the provider default is left in place.
Raises ValueError on an unknown model or an invalid quality value.
"""
if quality is None:
return None
normalized = quality.strip().lower()
if normalized not in VALID_IMAGE_QUALITIES:
raise ValueError(
f"Unsupported quality: {quality!r}. "
f"Must be one of {', '.join(VALID_IMAGE_QUALITIES)}."
)
cfg = get_model_config(model)
if not cfg.image_quality_sizes:
return None
return cfg.image_quality_sizes.get(normalized)
3 changes: 3 additions & 0 deletions tee_gateway/models/create_chat_completion_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
function_call=None,
functions=None,
web_search=False,
quality=None,
):
self.messages = messages
self.model = model
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
self.function_call = function_call
self.functions = functions
self.web_search = web_search
self.quality = quality

@classmethod
def from_dict(cls, dikt) -> "CreateChatCompletionRequest":
Expand Down Expand Up @@ -103,5 +105,6 @@ def from_dict(cls, dikt) -> "CreateChatCompletionRequest":
"function_call",
"functions",
"web_search",
"quality",
}
return cls(**{k: v for k, v in dikt.items() if k in known})
16 changes: 16 additions & 0 deletions tee_gateway/openapi/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2895,6 +2895,22 @@ components:
nullable: true
title: web_search
type: boolean
quality:
description: |
Output resolution for image-generation models, as a simple tier.
Each model maps the tier to its own native resolution: ByteDance
Seedream/Seedance and Gemini image models use `1K`/`2K`/`4K`; Z.ai
GLM-Image uses pixel dimensions (1024/1280/2048 px). Models with no
resolution control (xAI Grok, Gemini 2.5 Flash Image) ignore this
option. Has no effect on text models. When omitted, each model's
default resolution (equivalent to `medium`) is used.
enum:
- low
- medium
- high
nullable: true
title: quality
type: string
store:
default: false
description: "Whether or not to store the output of this chat completion\
Expand Down
Loading
Loading