Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions hub_adapter/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,20 +160,22 @@ async def verify_idp_token(
) from Exception


async def _get_internal_token(oidc_config, settings: Annotated[Settings, Depends(get_settings)]) -> dict | None:
"""If the Hub Adapter is set up to use an external IDP, it needs to retrieve a JWT from the internal keycloak
to make requests to the PO."""
async def _get_internal_token(settings: Annotated[Settings, Depends(get_settings)]) -> dict | None:
"""If the Hub Adapter is set up to use an external IDP for user auth, it needs to retrieve a JWT from the
internal keycloak to make requests to the PO."""

payload = {
"grant_type": "client_credentials",
"client_id": settings.api_client_id,
"client_secret": settings.api_client_secret,
}

with httpx.Client(verify=get_ssl_context(settings)) as client:
resp = client.post(oidc_config.token_endpoint, data=payload)
resp.raise_for_status()
token_data = resp.json()
svc_oidc_config = get_svc_oidc_config()
int_token_ep = svc_oidc_config.token_endpoint

resp = httpx.post(int_token_ep, data=payload)
resp.raise_for_status()
token_data = resp.json()

token = Token(**token_data)
return {"Authorization": f"Bearer {token.access_token}"}
Expand All @@ -182,11 +184,11 @@ async def _get_internal_token(oidc_config, settings: Annotated[Settings, Depends
async def _add_internal_token_if_missing(request: Request) -> Request:
"""Adds a JWT from the internal IDP is not present in the request."""
settings = get_settings()
configs_match, oidc_config = check_oidc_configs_match()
configs_match, _ = check_oidc_configs_match()

if not configs_match:
logger.debug("External IDP different from internal, retrieving JWT from internal keycloak")
internal_token = await _get_internal_token(oidc_config, settings)
internal_token = await _get_internal_token(settings)
if internal_token:
updated_headers = MutableHeaders(request.headers)
updated_headers.update(internal_token)
Expand Down
3 changes: 1 addition & 2 deletions hub_adapter/autostart.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,7 @@ async def pod_running(self, analysis_id: str) -> bool | None:
async def fetch_token_header(self) -> dict | None:
"""Append OIDC token to headers."""
try:
_, oidc_config = check_oidc_configs_match()
token = await _get_internal_token(oidc_config, self.settings)
token = await _get_internal_token(self.settings)
return token

except (HTTPException, HTTPStatusError) as e:
Expand Down
3 changes: 1 addition & 2 deletions hub_adapter/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ async def get_auth_token(self) -> dict:
if self.token:
return self.token

oidc_config = get_svc_oidc_config()
self.token = await _get_internal_token(oidc_config, self.settings)
self.token = await _get_internal_token(self.settings)
logger.info("Successfully obtained authentication token")
return self.token

Expand Down
3 changes: 2 additions & 1 deletion hub_adapter/routers/kong.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import uuid
from typing import Annotated
from uuid import UUID

import httpx
import kong_admin_client
Expand Down Expand Up @@ -687,7 +688,7 @@ async def create_and_connect_analysis_to_project(
@catch_kong_errors
async def delete_analysis(
settings: Annotated[Settings, Depends(get_settings)],
analysis_id: Annotated[str, Path(description="UUID or unique name of the analysis.")],
analysis_id: Annotated[str | UUID, Path(description="UUID or unique name of the analysis.")],
):
"""Delete the listed analysis."""
configuration = kong_admin_client.Configuration(host=settings.kong_admin_service_url)
Expand Down
3 changes: 1 addition & 2 deletions hub_adapter/routers/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ async def terminate_analysis(
"""
await delete_analysis(analysis_id=analysis_id, settings=settings)

configs_match, oidc_config = check_oidc_configs_match()
headers = await _get_internal_token(oidc_config, settings)
headers = await _get_internal_token(settings)

microsvc_path = f"{settings.podorc_service_url}/po/delete/{analysis_id}"

Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tests.constants import (
FAKE_USER,
TEST_MOCK_NODE_CLIENT_ID,
TEST_SVC_URL,
TEST_URL,
)

Expand Down Expand Up @@ -75,7 +76,7 @@ def test_settings() -> Settings:
api_client_secret="notASecret",
http_proxy="http://squid.proxy:3128",
https_proxy="http://squid.proxy:3128",
node_svc_oidc_url=TEST_URL,
node_svc_oidc_url=TEST_SVC_URL,
postgres_event_db="test_db",
postgres_event_user="test_user",
postgres_event_password="test_password",
Expand Down
11 changes: 4 additions & 7 deletions tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,19 @@
TEST_OIDC = OIDCConfiguration(
issuer=TEST_URL,
authorization_endpoint=TEST_URL,
token_endpoint=TEST_URL,
token_endpoint=f"{TEST_URL}/protocol/openid-connect/token",
jwks_uri=TEST_URL,
userinfo_endpoint=TEST_URL,
)
TEST_SVC_URL = "https://service.example"
TEST_SVC_OIDC = OIDCConfiguration(
issuer=TEST_SVC_URL,
authorization_endpoint=TEST_SVC_URL,
token_endpoint=TEST_SVC_URL,
token_endpoint=f"{TEST_SVC_URL}/protocol/openid-connect/token",
jwks_uri=TEST_SVC_URL,
userinfo_endpoint=TEST_SVC_URL,
)


TEST_MOCK_ANALYSIS_ID = "1c9cb547-4afc-4398-bcb6-954bc61a1bb1"
TEST_MOCK_PROJECT_ID = "9cbefefe-2420-4b8e-8ac1-f48148a9fd40"
TEST_MOCK_NODE_ID = "9c521144-364d-4cdc-8ec4-cb62a537f10c"
Expand Down Expand Up @@ -102,7 +101,6 @@
"analysis": MOCK_ANALYSIS,
}


ANALYSIS_NODES_RESP = [
{
# Shouldn't start because executed
Expand Down Expand Up @@ -208,15 +206,15 @@
"authorization_endpoint": TEST_URL,
"issuer": TEST_URL,
"jwks_uri": TEST_URL,
"token_endpoint": TEST_URL,
"token_endpoint": f"{TEST_URL}/protocol/openid-connect/token",
"userinfo_endpoint": TEST_URL,
}

TEST_OIDC_SVC_RESPONSE = {
"authorization_endpoint": TEST_SVC_URL,
"issuer": TEST_SVC_URL,
"jwks_uri": TEST_SVC_URL,
"token_endpoint": TEST_SVC_URL,
"token_endpoint": f"{TEST_SVC_URL}/protocol/openid-connect/token",
"userinfo_endpoint": TEST_SVC_URL,
}

Expand Down Expand Up @@ -320,7 +318,6 @@
"acl": ACL().__dict__,
}


FAKE_USER = {
"acr": "1",
"allowed-origins": ["/*"],
Expand Down
9 changes: 6 additions & 3 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
TEST_OIDC,
TEST_RESEARCHER_DECRYPTED_JWT,
TEST_STEWARD_DECRYPTED_JWT,
TEST_SVC_OIDC,
)


Expand Down Expand Up @@ -89,18 +90,20 @@ async def test_verify_idp_token_errors(self, mock_decode, mock_user_oidc, mock_s
assert random_error.value.status_code == status.HTTP_401_UNAUTHORIZED
assert missing_claim_error.value.detail["message"] == "Unable to parse authentication token"

@patch("hub_adapter.auth.get_svc_oidc_config")
@pytest.mark.asyncio
async def test_get_internal_token(self, httpx_mock, test_settings):
async def test_get_internal_token(self, mock_svc_oidc, httpx_mock, test_settings):
"""Test the get_internal_token method."""
mock_svc_oidc.return_value = TEST_SVC_OIDC
fake_token_resp = {
"access_token": TEST_JWT,
"token_type": "Bearer",
"expires_in": 7200,
"refresh_token": TEST_JWT,
"refresh_expires_in": 1800,
}
httpx_mock.add_response(url=TEST_OIDC.token_endpoint, json=fake_token_resp, status_code=200)
assert await _get_internal_token(TEST_OIDC, test_settings) == {"Authorization": f"Bearer {TEST_JWT}"}
httpx_mock.add_response(url=TEST_SVC_OIDC.token_endpoint, json=fake_token_resp, status_code=200)
assert await _get_internal_token(test_settings) == {"Authorization": f"Bearer {TEST_JWT}"}

@patch("hub_adapter.auth._get_internal_token")
@patch("hub_adapter.auth.check_oidc_configs_match")
Expand Down
23 changes: 12 additions & 11 deletions tests/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,25 @@ def test_basic_oidc_fetching(self, mock_settings, httpx_mock, test_settings):

fake_oidc_url = f"{TEST_URL}/.well-known/openid-configuration"
fake_oidc_svc_url = f"{TEST_SVC_URL}/.well-known/openid-configuration"
mock_settings.return_value = test_settings
mock_settings.return_value = test_settings # Initializes with different URLs
httpx_mock.add_response(url=fake_oidc_url, json=TEST_OIDC_RESPONSE, status_code=200)

# Same OIDC
match_check, match_config = check_oidc_configs_match()
assert match_check
assert match_config == TEST_OIDC

# Different OIDC URLs
different_oidc_settings = test_settings.model_copy(update={"node_svc_oidc_url": TEST_SVC_URL})
mock_settings.return_value = different_oidc_settings

httpx_mock.add_response(url=fake_oidc_svc_url, json=TEST_OIDC_SVC_RESPONSE, status_code=200)

# Different OIDC URLs
diff_check, diff_config = check_oidc_configs_match()
assert not diff_check
assert diff_config == TEST_SVC_OIDC

# Same OIDC
matching_oidc_settings = test_settings.model_copy(update={"node_svc_oidc_url": fake_oidc_url})
mock_settings.return_value = matching_oidc_settings

httpx_mock.add_response(url=fake_oidc_url, json=TEST_OIDC_RESPONSE, status_code=200)

match_check, match_config = check_oidc_configs_match()
assert match_check
assert match_config == TEST_OIDC

@patch("hub_adapter.oidc.logger")
def test_fetch_openid_config_errors(self, mock_logger, httpx_mock):
"""Test the fetch_openid_config method for error handling."""
Expand Down
Loading