Skip to content
Merged
33 changes: 33 additions & 0 deletions src/app/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
from config.settings import (
JWT_CLAIMS_CACHE_MAX_SIZE,
JWT_CLAIMS_CACHE_TTL_SECONDS,
OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP,
RBAC_CACHE_BACKEND,
RBAC_PERMISSION_CACHE_TTL_SECONDS,
UVICORN_WORKER_COUNT,
clients,
get_openrag_config,
get_openrag_service_token,
)
from services.startup_orchestrator import startup_tasks
from utils.logging_config import get_logger
Expand Down Expand Up @@ -199,6 +201,37 @@ async def run_startup(app: FastAPI):
await mcp_lifespan_ctx.__aenter__()
logger.info("FastMCP lifespan started")

# One-shot OpenSearch security bootstrap driven by the platform's
# service JWT. Runs synchronously (before startup_tasks) so the
# admin role mapping is in place before any other startup work
# talks to OpenSearch. The corresponding call inside startup_tasks
# is suppressed when this flag is on.
if OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP:
service_token = get_openrag_service_token()
if not service_token:
raise RuntimeError(
"OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP is enabled but "
"OPENRAG_SERVICE_TOKEN is not set"
)
from auth.ibm_auth import admin_username_from_service_jwt
from utils.opensearch_init import wait_for_opensearch
from utils.opensearch_utils import setup_opensearch_security

admin_username = admin_username_from_service_jwt(service_token)
if not admin_username:
raise RuntimeError(
"OPENRAG_SERVICE_TOKEN has no 'username' or 'sub' claim; "
"cannot bootstrap OpenSearch security"
)
opensearch_client = clients.create_opensearch_client_from_jwt(service_token)
try:
await wait_for_opensearch(opensearch_client)
logger.info("Bootstrapping OpenSearch security", admin_username=admin_username)
await setup_opensearch_security(opensearch_client, admin_username=admin_username)
logger.info("OpenSearch security bootstrap completed", admin_username=admin_username)
finally:
await opensearch_client.close()

# Start index initialization in background to avoid blocking OIDC endpoints
t1 = asyncio.create_task(startup_tasks(services))
app.state.background_tasks.add(t1)
Expand Down
24 changes: 24 additions & 0 deletions src/auth/ibm_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,30 @@ def decode_ibm_jwt(token: str) -> dict | None:
return None


def admin_username_from_service_jwt(token: str) -> str | None:
    """Return the admin username carried by a platform-issued service JWT.

    Decodes *token* without verifying its signature (the platform issues it;
    only the claims are parsed here) and returns ``username`` if present,
    falling back to ``sub``. Matches the claim precedence used by the auth
    dependency in dependencies.py.

    Surrounding whitespace and an optional ``Bearer `` auth-scheme prefix are
    removed before decoding: the same token value is also handed to the
    OpenSearch client factory, which accepts values that already carry an
    auth scheme, so a scheme-prefixed token must not be reported as
    undecodable here.

    Args:
        token: The service JWT, either bare or prefixed with ``Bearer ``.

    Returns:
        The ``username`` (or ``sub``) claim as a string, or None if the token
        cannot be decoded, has neither claim, or the claim is not a string.
    """
    if isinstance(token, str):
        token = token.strip()
        scheme, separator, credentials = token.partition(" ")
        # A compact JWT never contains a space, so a space can only separate
        # an auth scheme from the credentials that follow it.
        if separator and scheme.lower() == "bearer":
            token = credentials.strip()

    try:
        claims = jwt.decode(token, options={"verify_signature": False})
    except jwt.InvalidTokenError as exc:
        logger.warning("Service JWT decode failed", error=str(exc))
        return None

    value = claims.get("username") or claims.get("sub")
    if isinstance(value, str):
        return value
    if value is not None:
        # A claim exists but has an unusable type; log the type only so the
        # claim value itself is never written to the logs.
        logger.warning(
            "Service JWT username/sub claim is not a string",
            claim_type=type(value).__name__,
        )
    return None

Comment thread
coderabbitai[bot] marked this conversation as resolved.

Comment on lines +55 to +78
async def fetch_ibm_public_key(url: str):
"""Fetch IBM's JWT public key PEM from *url* and cache it."""
global _cached_public_key
Expand Down
28 changes: 26 additions & 2 deletions src/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
IBM_AUTH_ENABLED = os.getenv("IBM_AUTH_ENABLED", "false").lower() in ("true", "1", "yes")
PLATFORM_USERNAME = os.getenv("PLATFORM_USERNAME")
PLATFORM_PASSWORD = os.getenv("PLATFORM_PASSWORD")
OPENRAG_TENANT_ID = os.getenv("OPENRAG_TENANT_ID", "openrag")
IBM_JWT_PUBLIC_KEY_URL = os.getenv("IBM_JWT_PUBLIC_KEY_URL", "")
IBM_SESSION_COOKIE_NAME = os.getenv("IBM_SESSION_COOKIE_NAME", "ibm-openrag-session")
IBM_CREDENTIALS_HEADER = os.getenv("IBM_CREDENTIALS_HEADER", "X-IBM-LH-Credentials")
Expand Down Expand Up @@ -107,6 +108,13 @@ def get_role_claim_viewer() -> str | None:
return os.getenv("OPENRAG_ROLE_CLAIM_VIEWER")


def get_openrag_service_token() -> str | None:
"""Platform-issued service JWT used at startup to bootstrap the OpenSearch
security context (admin role mapping). Read per-call — like the JWT-claim
accessors above — so runtime/test overrides take effect without a restart."""
return os.getenv("OPENRAG_SERVICE_TOKEN")


def get_jwt_auth_header() -> str:
"""HTTP header that may carry a gateway-forwarded JWT for /v1 (API-key)
callers. Read per-call so tests can override via monkeypatch.setenv."""
Expand Down Expand Up @@ -149,6 +157,18 @@ def _resolve_skip_os_security_default() -> str:
"OPENRAG_SKIP_OS_SECURITY_SETUP", _resolve_skip_os_security_default()
).lower() in ("true", "1", "yes")

# Run setup_opensearch_security once during FastAPI lifespan startup,
# using the admin username derived from OPENRAG_SERVICE_TOKEN. Intended
# for platform-managed deployments (saas / on_prem) where the platform
# issues a service token that identifies the admin user that must be
# pinned into the all_access role mapping. Default off.
#
# When this flag is true the corresponding call inside startup_tasks()
# is suppressed — bootstrap is the single source of truth on startup.
OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP = os.getenv(
"OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP", "false"
).lower() in ("true", "1", "yes")

# Enable FastAPI's `debug` mode (verbose tracebacks in HTTP error responses
# on the FastAPI app instance). Named explicitly so it isn't confused with
# logging-level "debug" or other unrelated debug flags.
Expand Down Expand Up @@ -1004,8 +1024,8 @@ async def _update_langflow_global_variable(self, name: str, value: str):
error=str(e),
)

def create_user_opensearch_client(self, jwt_token: str):
"""Create OpenSearch client with user's auth token.
def create_opensearch_client_from_jwt(self, jwt_token: str):
"""Create an OpenSearch client authenticated with a JWT bearer token.

If jwt_token already contains an auth scheme (e.g. "Basic ..." or "Bearer ..."),
it is used verbatim. Otherwise it is wrapped as a Bearer token.
Expand All @@ -1032,6 +1052,10 @@ def create_user_opensearch_client(self, jwt_token: str):
retry_on_timeout=True,
)

def create_user_opensearch_client(self, jwt_token: str):
    """Create an OpenSearch client authenticated with the user's auth token.

    Thin alias that forwards *jwt_token* unchanged to
    ``create_opensearch_client_from_jwt`` and returns whatever it returns.
    """
    user_client = self.create_opensearch_client_from_jwt(jwt_token)
    return user_client


# Component template paths — derived from the centralized flows directory
def _component_path(env_var: str, filename: str) -> str:
Expand Down
9 changes: 8 additions & 1 deletion src/services/startup_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from config.settings import (
DISABLE_INGEST_WITH_LANGFLOW,
FETCH_OPENRAG_DOCS_AT_STARTUP,
OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP,
OPENRAG_SKIP_OS_SECURITY_SETUP,
clients,
get_openrag_config,
Expand Down Expand Up @@ -69,12 +70,18 @@ async def startup_tasks(services):
# Setup OpenSearch security (roles and mappings) after connection is established.
# Skip entirely when the platform manages the security context externally
# (SaaS / CPD): the call would otherwise either fail with 403/401 or
# overwrite a curated config.
# overwrite a curated config. Also skip when the lifespan-level
# bootstrap (driven by OPENRAG_SERVICE_TOKEN) has already handled it.
if OPENRAG_SKIP_OS_SECURITY_SETUP:
logger.info(
"Skipping OpenSearch security setup at startup "
"(OPENRAG_SKIP_OS_SECURITY_SETUP=true)"
)
elif OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP:
logger.info(
"Skipping OpenSearch security setup in startup_tasks "
"(handled by lifespan bootstrap)"
)
Comment on lines +80 to +84
else:
try:
from utils.opensearch_utils import setup_opensearch_security
Expand Down
150 changes: 150 additions & 0 deletions tests/unit/config/test_jwt_issuer_verification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import sys
import time
from pathlib import Path
from unittest.mock import MagicMock, patch

import jwt
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

ROOT = Path(__file__).resolve().parent.parent.parent.parent
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))

from config import utils # noqa: E402


@pytest.fixture(autouse=True)
def clear_public_key_cache():
    """Clear ``utils._ISSUER_PUBLIC_KEY_CACHE`` before and after each test.

    Declared autouse so every test in this module starts and ends with an
    empty cache and no test observes entries left behind by another.
    """

    def _reset() -> None:
        # Look the cache up at call time so the object currently bound on
        # the module is the one that gets cleared.
        utils._ISSUER_PUBLIC_KEY_CACHE.clear()

    _reset()
    yield
    _reset()


def _make_es256_token(issuer: str) -> tuple[str, str]:
    """Build an ES256-signed JWT for *issuer* using a freshly generated P-256 key.

    Returns:
        A ``(token, public_key_pem)`` pair: the encoded JWT and the
        PEM-encoded SubjectPublicKeyInfo of the key that signed it.
    """
    signing_key = ec.generate_private_key(ec.SECP256R1())
    pem_bytes = signing_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    issued_at = int(time.time())
    payload = {
        "iss": issuer,
        "sub": "system:serviceaccount:tenant:wxd-openrag-be",
        "exp": issued_at + 900,  # valid for 15 minutes
        "iat": issued_at,
        "roles": ["access_all"],
    }
    encoded = jwt.encode(
        payload,
        signing_key,
        algorithm="ES256",
        headers={"typ": "JWT"},
    )
    return encoded, pem_bytes.decode()


def test_verify_jwt_from_issuer_fetches_public_key_and_validates_es256_token():
    """A JSON ``{"public_key": <pem>}`` response from the issuer URL validates the token."""
    issuer = "https://authserver-oidc-svc.openrag-control.svc.cluster.local:8082/keys/workload"
    token, public_pem = _make_es256_token(issuer)

    # Fake HTTP response carrying the PEM inside a JSON body.
    http_response = MagicMock(headers={"content-type": "application/json"})
    http_response.json.return_value = {"public_key": public_pem}

    # Fake httpx client usable as a context manager.
    http_client = MagicMock()
    http_client.__enter__.return_value = http_client
    http_client.get.return_value = http_response

    with patch("config.utils.httpx.Client", return_value=http_client):
        claims = utils.verify_jwt_from_issuer(f"Bearer {token}", verify_tls=False)

    assert claims is not None
    assert claims["iss"] == issuer
    assert claims["roles"] == ["access_all"]
    # The key must have been requested from the issuer URL exactly once.
    http_client.get.assert_called_once_with(issuer)


def test_verify_jwt_from_issuer_accepts_standard_jwks_response():
    """A standard JWKS document served by the issuer is accepted as the key source."""
    import base64

    issuer = "https://authserver-oidc-svc.openrag-control.svc.cluster.local:8082/keys/workload"
    private_key = ec.generate_private_key(ec.SECP256R1())
    public_numbers = private_key.public_key().public_numbers()

    # RFC 7518 requires EC JWK coordinates to be encoded at the full
    # coordinate size of the curve (32 bytes for P-256). Sizing from
    # bit_length() drops leading zero bytes, which yields a malformed JWK
    # for a small fraction of random keys and makes the test flaky.
    coordinate_size = (public_numbers.curve.key_size + 7) // 8

    def _b64(value: int) -> str:
        """Return *value* as unpadded base64url of a fixed-width big-endian integer."""
        raw = value.to_bytes(coordinate_size, "big")
        return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()

    jwks = {
        "keys": [
            {
                "alg": "ES256",
                "crv": "P-256",
                "kty": "EC",
                "use": "sig",
                "x": _b64(public_numbers.x),
                "y": _b64(public_numbers.y),
            }
        ]
    }

    now = int(time.time())
    token = jwt.encode(
        {
            "iss": issuer,
            "sub": "system:serviceaccount:tenant:wxd-openrag-be",
            "exp": now + 900,
            "iat": now,
        },
        private_key,
        algorithm="ES256",
    )

    response = MagicMock()
    response.headers = {"content-type": "application/json"}
    response.json.return_value = jwks

    client = MagicMock()
    client.__enter__.return_value = client
    client.get.return_value = response

    with patch("config.utils.httpx.Client", return_value=client):
        claims = utils.verify_jwt_from_issuer(
            token,
            verify_tls=False,
        )

    assert claims is not None
    assert claims["iss"] == issuer


def test_verify_jwt_from_issuer_accepts_raw_pem_response():
    """A non-JSON response whose body is the raw PEM public key validates the token."""
    issuer = "https://authserver-oidc-svc.openrag-control.svc.cluster.local:8082/keys/raw"
    token, public_pem = _make_es256_token(issuer)

    # Fake HTTP response: JSON parsing fails, the PEM is in the text body.
    http_response = MagicMock(
        headers={"content-type": "application/x-pem-file"},
        text=public_pem,
    )
    http_response.json.side_effect = ValueError("not json")

    # Fake httpx client usable as a context manager.
    http_client = MagicMock()
    http_client.__enter__.return_value = http_client
    http_client.get.return_value = http_response

    with patch("config.utils.httpx.Client", return_value=http_client):
        claims = utils.verify_jwt_from_issuer(token, verify_tls=False)

    assert claims is not None
    assert claims["iss"] == issuer
4 changes: 1 addition & 3 deletions tests/unit/test_jwt_claims_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
fixture so tests are fully isolated.
"""

import os
import sys
import time
from pathlib import Path
Expand All @@ -17,9 +18,6 @@
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))

# Patch heavy config imports before importing session_manager
import os

os.environ.setdefault("OPENRAG_JWT_CACHE_TTL", "60")
os.environ.setdefault("OPENRAG_JWT_CACHE_MAXSIZE", "1024")

Expand Down
Loading