diff --git a/src/app/lifespan.py b/src/app/lifespan.py index 48b03e0f6..7be8e1fc8 100644 --- a/src/app/lifespan.py +++ b/src/app/lifespan.py @@ -14,11 +14,13 @@ from config.settings import ( JWT_CLAIMS_CACHE_MAX_SIZE, JWT_CLAIMS_CACHE_TTL_SECONDS, + OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP, RBAC_CACHE_BACKEND, RBAC_PERMISSION_CACHE_TTL_SECONDS, UVICORN_WORKER_COUNT, clients, get_openrag_config, + get_openrag_service_token, ) from services.startup_orchestrator import startup_tasks from utils.logging_config import get_logger @@ -199,6 +201,37 @@ async def run_startup(app: FastAPI): await mcp_lifespan_ctx.__aenter__() logger.info("FastMCP lifespan started") + # One-shot OpenSearch security bootstrap driven by the platform's + # service JWT. Runs synchronously (before startup_tasks) so the + # admin role mapping is in place before any other startup work + # talks to OpenSearch. The corresponding call inside startup_tasks + # is suppressed when this flag is on. + if OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP: + service_token = get_openrag_service_token() + if not service_token: + raise RuntimeError( + "OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP is enabled but " + "OPENRAG_SERVICE_TOKEN is not set" + ) + from auth.ibm_auth import admin_username_from_service_jwt + from utils.opensearch_init import wait_for_opensearch + from utils.opensearch_utils import setup_opensearch_security + + admin_username = admin_username_from_service_jwt(service_token) + if not admin_username: + raise RuntimeError( + "OPENRAG_SERVICE_TOKEN has no 'username' or 'sub' claim; " + "cannot bootstrap OpenSearch security" + ) + opensearch_client = clients.create_opensearch_client_from_jwt(service_token) + try: + await wait_for_opensearch(opensearch_client) + logger.info("Bootstrapping OpenSearch security", admin_username=admin_username) + await setup_opensearch_security(opensearch_client, admin_username=admin_username) + logger.info("OpenSearch security bootstrap completed", admin_username=admin_username) + finally: + await opensearch_client.close() + # Start index initialization in background to avoid blocking OIDC endpoints t1 = asyncio.create_task(startup_tasks(services)) app.state.background_tasks.add(t1) diff --git a/src/auth/ibm_auth.py b/src/auth/ibm_auth.py index f6af688eb..ca5d58c63 100644 --- a/src/auth/ibm_auth.py +++ b/src/auth/ibm_auth.py @@ -52,6 +52,30 @@ def decode_ibm_jwt(token: str) -> dict | None: return None +def admin_username_from_service_jwt(token: str) -> str | None: + """Return the admin username carried by a platform-issued service JWT. + + Decodes *token* unsigned (the platform issues it; we only parse claims) + and returns `username` if present, falling back to `sub`. Matches the + claim precedence used by the auth dependency in dependencies.py. + Returns None if the token cannot be decoded or has neither claim. + """ + try: + claims = jwt.decode(token, options={"verify_signature": False}) + except jwt.InvalidTokenError as exc: + logger.warning("Service JWT decode failed", error=str(exc)) + return None + value = claims.get("username") or claims.get("sub") + if not isinstance(value, str): + if value is not None: + logger.warning( + "Service JWT username/sub claim is not a string", + claim_type=type(value).__name__, + ) + return None + return value + + async def fetch_ibm_public_key(url: str): """Fetch IBM's JWT public key PEM from *url* and cache it.""" global _cached_public_key diff --git a/src/config/settings.py b/src/config/settings.py index f014d71df..d4227310b 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -67,6 +67,7 @@ IBM_AUTH_ENABLED = os.getenv("IBM_AUTH_ENABLED", "false").lower() in ("true", "1", "yes") PLATFORM_USERNAME = os.getenv("PLATFORM_USERNAME") PLATFORM_PASSWORD = os.getenv("PLATFORM_PASSWORD") +OPENRAG_TENANT_ID = os.getenv("OPENRAG_TENANT_ID", "openrag") IBM_JWT_PUBLIC_KEY_URL = os.getenv("IBM_JWT_PUBLIC_KEY_URL", "") IBM_SESSION_COOKIE_NAME = os.getenv("IBM_SESSION_COOKIE_NAME", "ibm-openrag-session") IBM_CREDENTIALS_HEADER = os.getenv("IBM_CREDENTIALS_HEADER", "X-IBM-LH-Credentials") @@ -107,6 +108,13 @@ def get_role_claim_viewer() -> str | None: return os.getenv("OPENRAG_ROLE_CLAIM_VIEWER") +def get_openrag_service_token() -> str | None: + """Platform-issued service JWT used at startup to bootstrap the OpenSearch + security context (admin role mapping). Read per-call — like the JWT-claim + accessors above — so runtime/test overrides take effect without a restart.""" + return os.getenv("OPENRAG_SERVICE_TOKEN") + + def get_jwt_auth_header() -> str: """HTTP header that may carry a gateway-forwarded JWT for /v1 (API-key) callers. Read per-call so tests can override via monkeypatch.setenv.""" @@ -149,6 +157,18 @@ def _resolve_skip_os_security_default() -> str: "OPENRAG_SKIP_OS_SECURITY_SETUP", _resolve_skip_os_security_default() ).lower() in ("true", "1", "yes") +# Run setup_opensearch_security once during FastAPI lifespan startup, +# using the admin username derived from OPENRAG_SERVICE_TOKEN. Intended +# for platform-managed deployments (saas / on_prem) where the platform +# issues a service token that identifies the admin user that must be +# pinned into the all_access role mapping. Default off. +# +# When this flag is true the corresponding call inside startup_tasks() +# is suppressed — bootstrap is the single source of truth on startup. +OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP = os.getenv( + "OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP", "false" +).lower() in ("true", "1", "yes") + # Enable FastAPI's `debug` mode (verbose tracebacks in HTTP error responses # on the FastAPI app instance). Named explicitly so it isn't confused with # logging-level "debug" or other unrelated debug flags. @@ -1004,8 +1024,8 @@ async def _update_langflow_global_variable(self, name: str, value: str): error=str(e), ) - def create_user_opensearch_client(self, jwt_token: str): - """Create OpenSearch client with user's auth token. + def create_opensearch_client_from_jwt(self, jwt_token: str): + """Create an OpenSearch client authenticated with a JWT bearer token. If jwt_token already contains an auth scheme (e.g. "Basic ..." or "Bearer ..."), it is used verbatim. Otherwise it is wrapped as a Bearer token. @@ -1032,6 +1052,10 @@ def create_user_opensearch_client(self, jwt_token: str): retry_on_timeout=True, ) + def create_user_opensearch_client(self, jwt_token: str): + """Create OpenSearch client with user's auth token.""" + return self.create_opensearch_client_from_jwt(jwt_token) + # Component template paths — derived from the centralized flows directory def _component_path(env_var: str, filename: str) -> str: diff --git a/src/services/startup_orchestrator.py b/src/services/startup_orchestrator.py index f6d7646fd..6a3a94ad8 100644 --- a/src/services/startup_orchestrator.py +++ b/src/services/startup_orchestrator.py @@ -9,6 +9,7 @@ from config.settings import ( DISABLE_INGEST_WITH_LANGFLOW, FETCH_OPENRAG_DOCS_AT_STARTUP, + OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP, OPENRAG_SKIP_OS_SECURITY_SETUP, clients, get_openrag_config, @@ -69,12 +70,18 @@ async def startup_tasks(services): # Setup OpenSearch security (roles and mappings) after connection is established. # Skip entirely when the platform manages the security context externally # (SaaS / CPD): the call would otherwise either fail with 403/401 or - # overwrite a curated config. + # overwrite a curated config. Also skip when the lifespan-level + # bootstrap (driven by OPENRAG_SERVICE_TOKEN) has already handled it. if OPENRAG_SKIP_OS_SECURITY_SETUP: logger.info( "Skipping OpenSearch security setup at startup " "(OPENRAG_SKIP_OS_SECURITY_SETUP=true)" ) + elif OPENRAG_BOOTSTRAP_OS_SECURITY_ON_STARTUP: + logger.info( + "Skipping OpenSearch security setup in startup_tasks " + "(handled by lifespan bootstrap)" + ) else: try: from utils.opensearch_utils import setup_opensearch_security diff --git a/tests/unit/config/test_jwt_issuer_verification.py b/tests/unit/config/test_jwt_issuer_verification.py new file mode 100644 index 000000000..7c35e5b7f --- /dev/null +++ b/tests/unit/config/test_jwt_issuer_verification.py @@ -0,0 +1,150 @@ +import sys +import time +from pathlib import Path +from unittest.mock import MagicMock, patch + +import jwt +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec + +ROOT = Path(__file__).resolve().parent.parent.parent.parent +SRC = ROOT / "src" +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) + +from config import utils # noqa: E402 + + +@pytest.fixture(autouse=True) +def clear_public_key_cache(): + utils._ISSUER_PUBLIC_KEY_CACHE.clear() + yield + utils._ISSUER_PUBLIC_KEY_CACHE.clear() + + +def _make_es256_token(issuer: str) -> tuple[str, str]: + private_key = ec.generate_private_key(ec.SECP256R1()) + public_pem = ( + private_key.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode() + ) + now = int(time.time()) + token = jwt.encode( + { + "iss": issuer, + "sub": "system:serviceaccount:tenant:wxd-openrag-be", + "exp": now + 900, + "iat": now, + "roles": ["access_all"], + }, + private_key, + algorithm="ES256", + headers={"typ": "JWT"}, + ) + return token, public_pem + + +def test_verify_jwt_from_issuer_fetches_public_key_and_validates_es256_token(): + issuer = "https://authserver-oidc-svc.openrag-control.svc.cluster.local:8082/keys/workload" + token, public_pem = _make_es256_token(issuer) + + response = MagicMock() + response.headers = {"content-type": "application/json"} + response.json.return_value = {"public_key": public_pem} + + client = MagicMock() + client.__enter__.return_value = client + client.get.return_value = response + + with patch("config.utils.httpx.Client", return_value=client): + claims = utils.verify_jwt_from_issuer( + f"Bearer {token}", + verify_tls=False, + ) + + assert claims is not None + assert claims["iss"] == issuer + assert claims["roles"] == ["access_all"] + client.get.assert_called_once_with(issuer) + + +def test_verify_jwt_from_issuer_accepts_standard_jwks_response(): + issuer = "https://authserver-oidc-svc.openrag-control.svc.cluster.local:8082/keys/workload" + private_key = ec.generate_private_key(ec.SECP256R1()) + public_numbers = private_key.public_key().public_numbers() + + def _b64(value: int) -> str: + import base64 + + length = (value.bit_length() + 7) // 8 + return base64.urlsafe_b64encode(value.to_bytes(length, "big")).rstrip(b"=").decode() + + jwks = { + "keys": [ + { + "alg": "ES256", + "crv": "P-256", + "kty": "EC", + "use": "sig", + "x": _b64(public_numbers.x), + "y": _b64(public_numbers.y), + } + ] + } + + now = int(time.time()) + token = jwt.encode( + { + "iss": issuer, + "sub": "system:serviceaccount:tenant:wxd-openrag-be", + "exp": now + 900, + "iat": now, + }, + private_key, + algorithm="ES256", + ) + + response = MagicMock() + response.headers = {"content-type": "application/json"} + response.json.return_value = jwks + + client = MagicMock() + client.__enter__.return_value = client + client.get.return_value = response + + with patch("config.utils.httpx.Client", return_value=client): + claims = utils.verify_jwt_from_issuer( + token, + verify_tls=False, + ) + + assert claims is not None + assert claims["iss"] == issuer + + +def test_verify_jwt_from_issuer_accepts_raw_pem_response(): + issuer = "https://authserver-oidc-svc.openrag-control.svc.cluster.local:8082/keys/raw" + token, public_pem = _make_es256_token(issuer) + + response = MagicMock() + response.headers = {"content-type": "application/x-pem-file"} + response.json.side_effect = ValueError("not json") + response.text = public_pem + + client = MagicMock() + client.__enter__.return_value = client + client.get.return_value = response + + with patch("config.utils.httpx.Client", return_value=client): + claims = utils.verify_jwt_from_issuer( + token, + verify_tls=False, + ) + + assert claims is not None + assert claims["iss"] == issuer diff --git a/tests/unit/test_jwt_claims_cache.py b/tests/unit/test_jwt_claims_cache.py index fe70b2929..2dce2cf6f 100644 --- a/tests/unit/test_jwt_claims_cache.py +++ b/tests/unit/test_jwt_claims_cache.py @@ -5,6 +5,7 @@ fixture so tests are fully isolated. """ +import os import sys import time from pathlib import Path @@ -17,9 +18,6 @@ if str(SRC) not in sys.path: sys.path.insert(0, str(SRC)) -# Patch heavy config imports before importing session_manager -import os - os.environ.setdefault("OPENRAG_JWT_CACHE_TTL", "60") os.environ.setdefault("OPENRAG_JWT_CACHE_MAXSIZE", "1024")