Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ API keys (injected at runtime via POST /v1/keys — do NOT bake into the image):
- `XAI_API_KEY`
- `ARK_API_KEY` (BytePlus / ByteDance ModelArk; injected as `bytedance_api_key`)
- `NOUS_API_KEY` (Nous Research / Nous Portal; injected as `nous_api_key`)
- `ZAI_API_KEY` (Z.ai Model API; injected as `zai_api_key`)

Server configuration:
- `API_SERVER_PORT` (default: 8000)
Expand Down Expand Up @@ -118,13 +119,14 @@ Model name prefixes determine routing:
- **xAI**: grok-2, grok-3, grok-3-mini, grok-4, grok-4-fast, grok-4-1-fast; image generation: grok-2-image
- **ByteDance** (BytePlus ModelArk, OpenAI-compatible, ap-southeast): seed-1.6, seed-1.8, seed-2.0-lite; image generation: seedream-4.0
- **Nous Research** (Nous Portal, OpenAI-compatible): hermes-4-405b, hermes-4-70b

Image generation via xAI (grok-2-image) and ByteDance (seedream-4.0) is served
through a provider `/images/generations` endpoint rather than the chat path, but
is surfaced on `/v1/chat/completions` exactly like Gemini's inline-image models
(images returned out-of-band under the message `images` key). These models are
billed a flat per-image price (see `per_image_price_usd` in `model_registry.py`),
not per token.
- **Z.ai** (Model API, OpenAI-compatible): glm-5.2; image generation: glm-image

Image generation via xAI (grok-2-image), ByteDance (seedream-4.0), and Z.ai
(glm-image) is served through a provider `/images/generations` endpoint rather
than the chat path, but is surfaced on `/v1/chat/completions` exactly like
Gemini's inline-image models (images returned out-of-band under the message
`images` key). xAI and ByteDance images are requested as `b64_json` (inline
image data); Z.ai returns temporary image URLs only. These models are billed a
flat per-image price (see `per_image_price_usd` in `model_registry.py`), not
per token.

## Verification Examples

Expand Down
6 changes: 5 additions & 1 deletion scripts/run-enclave.sh
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ if [ -f "$ENV_FILE" ]; then
XAI_API_KEY="$(grep -E '^XAI_API_KEY=' "$ENV_FILE" | cut -d'=' -f2-)"
ARK_API_KEY="$(grep -E '^ARK_API_KEY=' "$ENV_FILE" | cut -d'=' -f2-)"
NOUS_API_KEY="$(grep -E '^NOUS_API_KEY=' "$ENV_FILE" | cut -d'=' -f2-)"
ZAI_API_KEY="$(grep -E '^ZAI_API_KEY=' "$ENV_FILE" | cut -d'=' -f2-)"

# FACILITATOR_URL is used for both x402 payment verification and the heartbeat relay.
# HEARTBEAT_CONTRACT_ADDRESS and TEE_HEARTBEAT_INTERVAL are optional heartbeat parameters.
Expand All @@ -108,6 +109,7 @@ if [ -f "$ENV_FILE" ]; then
--arg xai "$XAI_API_KEY" \
--arg bytedance "$ARK_API_KEY" \
--arg nous "$NOUS_API_KEY" \
--arg zai "$ZAI_API_KEY" \
--arg hb_contract "$HEARTBEAT_CONTRACT_ADDRESS" \
--arg facilitator "$FACILITATOR_URL" \
--arg hb_interval "$TEE_HEARTBEAT_INTERVAL" \
Expand All @@ -117,7 +119,8 @@ if [ -f "$ENV_FILE" ]; then
anthropic_api_key: $anthropic,
xai_api_key: $xai,
bytedance_api_key: $bytedance,
nous_api_key: $nous
nous_api_key: $nous,
zai_api_key: $zai
}
+ if $hb_contract != "" then {heartbeat_contract_address: $hb_contract} else {} end
+ if $facilitator != "" then {facilitator_url: $facilitator} else {} end
Expand Down Expand Up @@ -148,6 +151,7 @@ if [ -f "$ENV_FILE" ]; then

# Clear key variables from this shell immediately after use
unset OPENAI_API_KEY GOOGLE_API_KEY ANTHROPIC_API_KEY XAI_API_KEY ARK_API_KEY
unset NOUS_API_KEY ZAI_API_KEY
unset HEARTBEAT_CONTRACT_ADDRESS FACILITATOR_URL TEE_HEARTBEAT_INTERVAL
fi
else
Expand Down
5 changes: 5 additions & 0 deletions tee_gateway/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def set_provider_keys():
xai_api_key=body.get("xai_api_key") or None,
bytedance_api_key=body.get("bytedance_api_key") or None,
nous_api_key=body.get("nous_api_key") or None,
zai_api_key=body.get("zai_api_key") or None,
)
set_provider_config(provider_config)

Expand Down Expand Up @@ -456,6 +457,9 @@ def _set(val: str | None) -> str:
logger.info(
" nous_api_key : %s", _set(provider_config.nous_api_key)
)
logger.info(
" zai_api_key : %s", _set(provider_config.zai_api_key)
)
logger.info(" facilitator_url : %s", facilitator_url)
logger.info(
" heartbeat_contract_address : %s",
Expand Down Expand Up @@ -489,6 +493,7 @@ def _set(val: str | None) -> str:
"xai": provider_config.xai_api_key,
"bytedance": provider_config.bytedance_api_key,
"nous": provider_config.nous_api_key,
"zai": provider_config.zai_api_key,
}.items()
if k
]
Expand Down
3 changes: 3 additions & 0 deletions tee_gateway/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class ProviderConfig:
xai_api_key: Optional[str] = None
bytedance_api_key: Optional[str] = None
nous_api_key: Optional[str] = None
zai_api_key: Optional[str] = None

def initialized_providers(self) -> list[str]:
"""Return provider names whose API key is set (non-empty)."""
Expand All @@ -45,6 +46,8 @@ def initialized_providers(self) -> list[str]:
providers.append("bytedance")
if self.nous_api_key:
providers.append("nous")
if self.zai_api_key:
providers.append("zai")
return providers


Expand Down
60 changes: 52 additions & 8 deletions tee_gateway/llm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,19 @@
# Nous Research OpenAI-compatible inference endpoint (Nous Portal).
NOUS_BASE_URL = "https://inference-api.nousresearch.com/v1"

# Z.ai Model API OpenAI-compatible endpoint. The full chat URL is
# https://api.z.ai/api/paas/v4/chat/completions; ChatOpenAI appends
# /chat/completions to this base URL. Do not confuse this paid Model API with
# the subscription Coding Plan endpoint at /api/coding/paas/v4.
ZAI_BASE_URL = "https://api.z.ai/api/paas/v4"

# Shared synchronous HTTP clients for each provider.
# Initialized to None; built by set_provider_config() after key injection.
openai_http_client: Optional[httpx.Client] = None
xai_http_client: Optional[httpx.Client] = None
bytedance_http_client: Optional[httpx.Client] = None
nous_http_client: Optional[httpx.Client] = None
zai_http_client: Optional[httpx.Client] = None


_provider_config: Optional[ProviderConfig] = None
Expand All @@ -67,12 +74,13 @@
def set_provider_config(config: ProviderConfig) -> None:
"""Store the provider config and rebuild HTTP clients. Called once after key injection."""
global _provider_config, openai_http_client, xai_http_client, bytedance_http_client
global nous_http_client
global nous_http_client, zai_http_client

old_openai = openai_http_client
old_xai = xai_http_client
old_bytedance = bytedance_http_client
old_nous = nous_http_client
old_zai = zai_http_client

openai_http_client = httpx.Client(
base_url="https://api.openai.com/v1",
Expand Down Expand Up @@ -106,6 +114,14 @@ def set_provider_config(config: ProviderConfig) -> None:
http2=True,
follow_redirects=False,
)
zai_http_client = httpx.Client(
base_url=ZAI_BASE_URL,
headers={"Authorization": f"Bearer {config.zai_api_key or ''}"},
timeout=_TIMEOUT,
limits=_LIMITS,
http2=True,
follow_redirects=False,
)

get_chat_model_cached.cache_clear()
_provider_config = config
Expand All @@ -118,6 +134,8 @@ def set_provider_config(config: ProviderConfig) -> None:
old_bytedance.close()
if old_nous is not None:
old_nous.close()
if old_zai is not None:
old_zai.close()


def get_provider_config() -> Optional[ProviderConfig]:
Expand Down Expand Up @@ -289,6 +307,24 @@ def get_chat_model_cached(
stream_usage=True,
) # type: ignore [call-arg]

elif provider == "zai":
if not config.zai_api_key:
raise ValueError("zai_api_key not set in ProviderConfig")

if zai_http_client is None:
raise RuntimeError("Z.ai HTTP client has not been initialized")

return ChatOpenAI(
model=api_name,
temperature=effective_temp,
max_tokens=max_tokens,
http_client=zai_http_client,
api_key=SecretStr(config.zai_api_key),
base_url=ZAI_BASE_URL,
streaming=True,
stream_usage=True,
) # type: ignore [call-arg]

else:
raise ValueError(f"Unsupported provider: {provider}")

Expand All @@ -300,11 +336,11 @@ def get_chat_model_cached(
def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int]:
"""Generate images via a provider's OpenAI-compatible images endpoint.

Unlike Gemini's inline-image chat models, xAI (Aurora) and ByteDance
(Seedream) expose image generation through a dedicated
Unlike Gemini's inline-image chat models, xAI (Aurora), ByteDance
(Seedream), and Z.ai (GLM-Image) expose image generation through a dedicated
``POST /images/generations`` endpoint. We request ``b64_json`` so the image
bytes ride inline inside the OHTTP/TEE envelope rather than as external URLs
that leak the content and expire.
bytes ride inline inside the OHTTP/TEE envelope for providers that support
it. Z.ai returns temporary image URLs only.

Returns ``(data_uris, image_count)`` where each entry is a ``data:`` URI. The
count is used for per-image billing. Falls back to provider-returned URLs if
Expand All @@ -317,6 +353,8 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int
client = xai_http_client
elif provider == "bytedance":
client = bytedance_http_client
elif provider == "zai":
client = zai_http_client
else:
raise ValueError(
f"Provider {provider!r} does not support the image-generation endpoint"
Expand All @@ -325,14 +363,20 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int
if client is None:
raise RuntimeError(f"{provider} HTTP client has not been initialized")

# n is clamped to the providers' documented 1..10 range.
# n is clamped to the OpenAI-compatible providers' documented 1..10 range.
# Z.ai's GLM-Image endpoint currently returns exactly one image and does not
# document n/response_format support, so keep its payload to the documented
# fields.
count = max(1, min(int(n), 10))
payload: dict[str, Any] = {
"model": cfg.api_name,
"prompt": prompt,
"n": count,
"response_format": "b64_json",
}
if provider == "zai":
payload["size"] = "1280x1280"
else:
payload["n"] = count
payload["response_format"] = "b64_json"

logger.info(
"Generating %d image(s) - Provider: %s, Model: %s",
Expand Down
24 changes: 23 additions & 1 deletion tee_gateway/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

@dataclass(frozen=True)
class ModelConfig:
provider: str # "openai" | "anthropic" | "google" | "x-ai" | "bytedance" | "nous"
# "openai" | "anthropic" | "google" | "x-ai" | "bytedance" | "nous" | "zai"
provider: str
api_name: str # model name sent to provider API
input_price_usd: Decimal # USD per token
output_price_usd: Decimal # USD per token
Expand Down Expand Up @@ -385,6 +386,24 @@ class SupportedModel(Enum):
output_price_usd=Decimal("0.0000004"),
)

# ── Z.ai (Model API, OpenAI-compatible) ─────────────────────────────
# Z.ai publishes GLM-5.2 prices per 1M tokens: $1.40 input, $4.40 output.
GLM_5_2 = ModelConfig(
provider="zai",
api_name="glm-5.2",
input_price_usd=Decimal("0.0000014"),
output_price_usd=Decimal("0.0000044"),
)
# GLM-Image uses Z.ai's image endpoint and is billed per generated image.
GLM_IMAGE = ModelConfig(
provider="zai",
api_name="glm-image",
input_price_usd=Decimal("0"),
output_price_usd=Decimal("0"),
image_generation=True,
per_image_price_usd=Decimal("0.015"),
)

# ── Legacy models (not in current SDK — retained for older SDK versions) ──
GROK_3_MINI = ModelConfig(
provider="x-ai",
Expand Down Expand Up @@ -465,6 +484,9 @@ class SupportedModel(Enum):
# Nous Research
"hermes-4-405b": SupportedModel.HERMES_4_405B,
"hermes-4-70b": SupportedModel.HERMES_4_70B,
# Z.ai
"glm-5.2": SupportedModel.GLM_5_2,
"glm-image": SupportedModel.GLM_IMAGE,
# Legacy — not in current SDK, retained for older SDK versions
"grok-3-mini-beta": SupportedModel.GROK_3_MINI, # old beta alias
"grok-3-mini": SupportedModel.GROK_3_MINI,
Expand Down
25 changes: 22 additions & 3 deletions tee_gateway/test/test_image_generation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Tests for endpoint-based image generation (xAI Grok, ByteDance Seedream).
"""Tests for endpoint-based image generation (xAI Grok, ByteDance Seedream,
Z.ai GLM-Image).

Unlike Gemini's inline-image chat models (see test_image_billing.py), these
models are served via a dedicated OpenAI-compatible ``/images/generations``
endpoint and billed a flat price per generated image. These tests pin:

1. The request/response handling in ``generate_images`` (b64_json -> data URI,
n clamping, url fallback).
n clamping, url fallback, provider-specific payloads).
2. The flat per-image billing in ``compute_session_cost``.

No network or API key required — the provider HTTP client is mocked and a stub
Expand All @@ -24,6 +25,7 @@

GROK_IMAGE = "grok-2-image"
SEEDREAM = "seedream-4.0"
GLM_IMAGE = "glm-image"


def _mock_response(data: list[dict]) -> MagicMock:
Expand Down Expand Up @@ -70,6 +72,23 @@ def test_url_fallback_when_no_b64(self):
self.assertEqual(count, 1)
self.assertEqual(images, ["https://img/1.jpg"])

def test_zai_glm_image_uses_documented_payload_and_url_response(self):
    """Z.ai GLM-Image sends only model/prompt/size and returns the provider URL.

    The caller asks for ``n=3`` on purpose: the Z.ai payload must not forward
    ``n`` or ``response_format``, and the URL from the mocked response is
    passed through unchanged with an image count of one.
    """
    image_url = "https://z.ai/img.png"
    fake_client = MagicMock()
    fake_client.post.return_value = _mock_response([{"url": image_url}])

    # Swap in the mocked Z.ai client only for the duration of the call.
    with patch.object(llm_backend, "zai_http_client", fake_client):
        returned_images, image_count = generate_images(GLM_IMAGE, "a poster", n=3)

    self.assertEqual(image_count, 1)
    self.assertEqual(returned_images, [image_url])

    # Inspect the JSON body handed to the provider's POST call.
    sent_body = fake_client.post.call_args.kwargs["json"]
    self.assertEqual(sent_body["model"], "glm-image")
    self.assertEqual(sent_body["prompt"], "a poster")
    self.assertEqual(sent_body["size"], "1280x1280")
    self.assertNotIn("n", sent_body)
    self.assertNotIn("response_format", sent_body)

def test_n_is_clamped_to_provider_range(self):
client = MagicMock()
client.post.return_value = _mock_response([{"b64_json": "x"}])
Expand Down Expand Up @@ -109,7 +128,7 @@ def _zero_usage() -> dict:
return {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

def test_single_image_charged_flat_price(self):
for model in (GROK_IMAGE, SEEDREAM):
for model in (GROK_IMAGE, SEEDREAM, GLM_IMAGE):
with self.subTest(model=model):
cfg = get_model_config(model)
cost = compute_session_cost(model, self._zero_usage(), image_count=1)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,22 @@ def test_hermes_4_70b_resolves(self):
self.assertEqual(cfg.input_price_usd, Decimal("0.00000013"))
self.assertEqual(cfg.output_price_usd, Decimal("0.0000004"))

# ── Z.ai (Model API) ───────────────────────────────────────────────────

def test_glm_5_2_resolves(self):
    """``glm-5.2`` resolves to the Z.ai provider with its per-token USD prices."""
    resolved = get_model_config("glm-5.2")

    # Routing identity: provider key and the name sent to the provider API.
    self.assertEqual(resolved.provider, "zai")
    self.assertEqual(resolved.api_name, "glm-5.2")

    # Prices are stored per token as exact Decimals.
    self.assertEqual(resolved.input_price_usd, Decimal("0.0000014"))
    self.assertEqual(resolved.output_price_usd, Decimal("0.0000044"))

def test_glm_image_resolves(self):
    """``glm-image`` resolves to Z.ai as a flat per-image-priced image model."""
    resolved = get_model_config("glm-image")

    # Routing identity: provider key and the name sent to the provider API.
    self.assertEqual(resolved.provider, "zai")
    self.assertEqual(resolved.api_name, "glm-image")

    # Image models are flagged and carry a per-image price.
    self.assertTrue(resolved.image_generation)
    self.assertEqual(resolved.per_image_price_usd, Decimal("0.015"))

# ── Errors ───────────────────────────────────────────────────────────────

def test_unknown_model_raises(self):
Expand Down
Loading