Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions rag-service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,14 @@ def internal_token_valid(provided: str | None, expected: str) -> bool:
return bool(expected) and bool(candidate) and secrets.compare_digest(candidate, expected)


def require_internal_rag_token_configured():
if not INTERNAL_RAG_TOKEN:
raise RuntimeError("INTERNAL_RAG_TOKEN must be configured for protected endpoints.")
def require_internal_rag_token_configured() -> bool:
return bool(INTERNAL_RAG_TOKEN)


require_internal_rag_token_configured()
if not require_internal_rag_token_configured():
logger.warning(
"INTERNAL_RAG_TOKEN is not configured; protected endpoints will return 503 until it is set."
)


# How often the background flush thread wakes and writes dirty session metadata
Expand Down Expand Up @@ -405,6 +407,14 @@ async def internal_auth_middleware(request: Request, call_next):
path in PROTECTED_RAG_PATHS
or any(path.startswith(prefix) for prefix in PROTECTED_RAG_PREFIXES)
):
if not INTERNAL_RAG_TOKEN:
logger.warning(
"Protected endpoint unavailable path=%s ip=%s reason=INTERNAL_RAG_TOKEN is not configured",
raw_path,
request.client.host if request.client else "unknown",
)
return standard_error_response(503, "INTERNAL_RAG_TOKEN is not configured")

provided = request.headers.get("X-Internal-Token")
if not internal_token_valid(provided, INTERNAL_RAG_TOKEN):
logger.warning(
Expand Down
24 changes: 9 additions & 15 deletions rag-service/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,15 @@ def test_require_internal_token_config_fails_when_unset(monkeypatch):

monkeypatch.setattr(main_module, "INTERNAL_RAG_TOKEN", "")

with pytest.raises(RuntimeError, match="INTERNAL_RAG_TOKEN"):
require_internal_rag_token_configured()
assert require_internal_rag_token_configured() is False


def test_internal_token_validation_passes_when_configured(monkeypatch):
import main as main_module

monkeypatch.setattr(main_module, "INTERNAL_RAG_TOKEN", "configured-secret")

assert require_internal_rag_token_configured() is None
assert require_internal_rag_token_configured() is True


def test_internal_auth_middleware_protects_validate_session_write():
Expand Down Expand Up @@ -710,28 +709,23 @@ def test_ask_stream_passes_middleware_with_correct_token():
finally:
main_module.INTERNAL_RAG_TOKEN = original

def test_ask_stream_rejected_when_token_is_cleared_after_startup():
@pytest.mark.parametrize("path", ["/process-pdf", "/ask", "/summarize"])
def test_protected_endpoints_rejected_when_token_is_cleared_after_startup(path):
"""Protected endpoints fail closed if token config becomes unavailable."""
import main as main_module

original = main_module.INTERNAL_RAG_TOKEN
main_module.INTERNAL_RAG_TOKEN = ""
try:
client = TestClient(app, raise_server_exceptions=False)
response = client.post(
"/ask/stream",
json={
"question": "What is this document about?",
"session_id": "00000000-0000-0000-0000-000000000004",
"session_secret": "irrelevant",
},
)
response = client.post(path)
# Fail-closed behavior: when INTERNAL_RAG_TOKEN is unset, the
# middleware should still block protected requests with 403.
assert response.status_code == 403, (
"Middleware must block protected requests when INTERNAL_RAG_TOKEN is unset. "
# middleware should reject protected requests with 503.
assert response.status_code == 503, (
"Middleware must reject protected requests when INTERNAL_RAG_TOKEN is unset. "
f"Got {response.status_code}"
)
assert response.json()["detail"] == "INTERNAL_RAG_TOKEN is not configured"
finally:
main_module.INTERNAL_RAG_TOKEN = original

Expand Down
Loading