From 3d9b6255f95817ddd21fabdbc1c5c8b020cecae0 Mon Sep 17 00:00:00 2001 From: Aidan Daly Date: Thu, 30 Apr 2026 18:19:55 -0400 Subject: [PATCH] feat: support on-behalf-of token exchange and additional parameters --- src/bedrock_agentcore/identity/auth.py | 10 +- src/bedrock_agentcore/services/identity.py | 12 +- tests/bedrock_agentcore/identity/test_auth.py | 56 ++++++ .../services/test_identity.py | 178 ++++++++++++++++++ 4 files changed, 252 insertions(+), 4 deletions(-) diff --git a/src/bedrock_agentcore/identity/auth.py b/src/bedrock_agentcore/identity/auth.py index 4bd99b35..447d9c24 100644 --- a/src/bedrock_agentcore/identity/auth.py +++ b/src/bedrock_agentcore/identity/auth.py @@ -24,8 +24,10 @@ def requires_access_token( provider_name: str, into: str = "access_token", scopes: List[str], + resources: Optional[List[str]] = None, + audiences: Optional[List[str]] = None, on_auth_url: Optional[Callable[[str], Any]] = None, - auth_flow: Literal["M2M", "USER_FEDERATION"], + auth_flow: Literal["M2M", "USER_FEDERATION", "ON_BEHALF_OF_TOKEN_EXCHANGE"], callback_url: Optional[str] = None, force_authentication: bool = False, token_poller: Optional[TokenPoller] = None, @@ -38,8 +40,10 @@ def requires_access_token( provider_name: The credential provider name into: Parameter name to inject the token into scopes: OAuth2 scopes to request + resources: OAuth2 resources to request + audiences: OAuth2 audiences to request on_auth_url: Callback for handling authorization URLs - auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION") + auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION" or "ON_BEHALF_OF_TOKEN_EXCHANGE") callback_url: OAuth2 callback URL force_authentication: Force re-authentication token_poller: Custom token poller implementation @@ -60,6 +64,8 @@ async def _get_token() -> str: provider_name=provider_name, agent_identity_token=await _get_workload_access_token(client), scopes=scopes, + resources=resources, + audiences=audiences, on_auth_url=on_auth_url, auth_flow=auth_flow, callback_url=_get_oauth2_callback_url(callback_url), diff --git a/src/bedrock_agentcore/services/identity.py b/src/bedrock_agentcore/services/identity.py index 0f022986..7bb21da9 100644 --- a/src/bedrock_agentcore/services/identity.py +++ b/src/bedrock_agentcore/services/identity.py @@ -202,9 +202,11 @@ async def get_token( *, provider_name: str, scopes: Optional[List[str]] = None, + resources: Optional[List[str]] = None, + audiences: Optional[List[str]] = None, agent_identity_token: str, on_auth_url: Optional[Callable[[str], Any]] = None, - auth_flow: Literal["M2M", "USER_FEDERATION"], + auth_flow: Literal["M2M", "USER_FEDERATION", "ON_BEHALF_OF_TOKEN_EXCHANGE"], callback_url: Optional[str] = None, force_authentication: bool = False, token_poller: Optional[TokenPoller] = None, @@ -216,9 +218,11 @@ async def get_token( Args: provider_name: The credential provider name scopes: Optional list of OAuth2 scopes to request + resources: Optional list of OAuth2 resources to request + audiences: Optional list of OAuth2 audiences to request agent_identity_token: Agent identity token for authentication on_auth_url: Callback for handling authorization URLs - auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION") + auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION" or "ON_BEHALF_OF_TOKEN_EXCHANGE") callback_url: OAuth2 callback URL (must be pre-registered) force_authentication: Force re-authentication even if token exists in the token vault token_poller: Custom token poller implementation @@ -244,6 +248,10 @@ async def get_token( } # Add optional parameters + if resources: + req["resources"] = resources + if audiences: + req["audiences"] = audiences if callback_url: req["resourceOauth2ReturnUrl"] = callback_url if force_authentication: diff --git a/tests/bedrock_agentcore/identity/test_auth.py b/tests/bedrock_agentcore/identity/test_auth.py index 5bda1dac..5d5053ab 100644 --- a/tests/bedrock_agentcore/identity/test_auth.py +++ b/tests/bedrock_agentcore/identity/test_auth.py @@ -50,6 +50,8 @@ async def test_async_func(param1, access_token=None): provider_name="test-provider", agent_identity_token="test-agent-token", scopes=["read", "write"], + resources=None, + audiences=None, on_auth_url=None, auth_flow="M2M", callback_url=None, @@ -163,6 +165,8 @@ async def test_func(param1, my_token=None): provider_name="test-provider", agent_identity_token="test-agent-token", scopes=["read"], + resources=None, + audiences=None, on_auth_url=None, auth_flow="M2M", callback_url=None, @@ -205,6 +209,8 @@ def on_auth_url(url): provider_name="test-provider", into="token", scopes=["read", "write"], + resources=["https://backend.example.com/api1", "https://backend.example.com/api2"], + audiences=["urn:example:cooperation"], on_auth_url=on_auth_url, auth_flow="USER_FEDERATION", callback_url="https://example.com/callback", @@ -223,6 +229,8 @@ async def test_func(token=None): provider_name="test-provider", agent_identity_token="test-agent-token", scopes=["read", "write"], + resources=["https://backend.example.com/api1", "https://backend.example.com/api2"], + audiences=["urn:example:cooperation"], on_auth_url=on_auth_url, auth_flow="USER_FEDERATION", callback_url="https://example.com/callback", @@ -250,6 +258,8 @@ async def test_custom_parameters_passed_to_client(self): @requires_access_token( provider_name="test-provider", scopes=["read"], + resources=None, + audiences=None, auth_flow="USER_FEDERATION", custom_parameters=custom_params, ) @@ -263,6 +273,8 @@ async def test_func(access_token=None): provider_name="test-provider", agent_identity_token="test-agent-token", scopes=["read"], + resources=None, + audiences=None, auth_flow="USER_FEDERATION", callback_url=None, force_authentication=False, @@ -272,6 +284,50 @@ async def test_func(access_token=None): custom_parameters=custom_params, ) + @pytest.mark.asyncio + @pytest.mark.parametrize("auth_flow", ["M2M", "USER_FEDERATION", "ON_BEHALF_OF_TOKEN_EXCHANGE"]) + async def test_all_auth_flows(self, auth_flow): + """Test decorator forwards each supported auth_flow value to the identity client.""" + with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class: + mock_client = Mock() + mock_identity_client_class.return_value = mock_client + + with patch( + "bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock + ) as mock_get_agent_token: + mock_get_agent_token.return_value = "test-agent-token" + mock_client.get_token = AsyncMock(return_value="test-access-token") + + with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"): + + @requires_access_token( + provider_name="test-provider", + scopes=["read"], + resources=["https://backend.example.com/api"], + audiences=["urn:example:audience"], + auth_flow=auth_flow, + ) + async def test_func(access_token=None): + return access_token + + result = await test_func() + + assert result == "test-access-token" + mock_client.get_token.assert_called_once_with( + provider_name="test-provider", + agent_identity_token="test-agent-token", + scopes=["read"], + resources=["https://backend.example.com/api"], + audiences=["urn:example:audience"], + on_auth_url=None, + auth_flow=auth_flow, + callback_url=None, + force_authentication=False, + token_poller=None, + custom_state=None, + custom_parameters=None, + ) + class TestRequiresIamAccessTokenDecorator: """Test the requires_iam_access_token decorator.""" diff --git a/tests/bedrock_agentcore/services/test_identity.py b/tests/bedrock_agentcore/services/test_identity.py index edcfc210..58479e48 100644 --- a/tests/bedrock_agentcore/services/test_identity.py +++ b/tests/bedrock_agentcore/services/test_identity.py @@ -348,6 +348,184 @@ async def test_get_token_with_custom_parameters(self): customParameters=custom_parameters, ) + @pytest.mark.asyncio + async def test_get_token_with_resources(self): + """Test get_token forwards the resources list to the data plane request.""" + region = "us-west-2" + + with patch("boto3.client") as mock_boto_client: + mock_client = Mock() + mock_boto_client.return_value = mock_client + + identity_client = IdentityClient(region) + + provider_name = "test-provider" + scopes = ["read"] + agent_identity_token = "test-agent-token" + resources = ["https://backend.example.com/api1", "https://backend.example.com/api2"] + expected_token = "test-access-token" + + mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} + + result = await identity_client.get_token( + provider_name=provider_name, + scopes=scopes, + resources=resources, + agent_identity_token=agent_identity_token, + auth_flow="M2M", + ) + + assert result == expected_token + mock_client.get_resource_oauth2_token.assert_called_once_with( + resourceCredentialProviderName=provider_name, + scopes=scopes, + oauth2Flow="M2M", + workloadIdentityToken=agent_identity_token, + resources=resources, + ) + + @pytest.mark.asyncio + async def test_get_token_with_audiences(self): + """Test get_token forwards the audiences list to the data plane request.""" + region = "us-west-2" + + with patch("boto3.client") as mock_boto_client: + mock_client = Mock() + mock_boto_client.return_value = mock_client + + identity_client = IdentityClient(region) + + provider_name = "test-provider" + scopes = ["read"] + agent_identity_token = "test-agent-token" + audiences = ["urn:example:cooperation", "urn:example:other"] + expected_token = "test-access-token" + + mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} + + result = await identity_client.get_token( + provider_name=provider_name, + scopes=scopes, + audiences=audiences, + agent_identity_token=agent_identity_token, + auth_flow="M2M", + ) + + assert result == expected_token + mock_client.get_resource_oauth2_token.assert_called_once_with( + resourceCredentialProviderName=provider_name, + scopes=scopes, + oauth2Flow="M2M", + workloadIdentityToken=agent_identity_token, + audiences=audiences, + ) + + @pytest.mark.asyncio + async def test_get_token_with_resources_and_audiences(self): + """Test get_token forwards both resources and audiences together.""" + region = "us-west-2" + + with patch("boto3.client") as mock_boto_client: + mock_client = Mock() + mock_boto_client.return_value = mock_client + + identity_client = IdentityClient(region) + + provider_name = "test-provider" + scopes = ["read", "write"] + agent_identity_token = "test-agent-token" + resources = ["https://backend.example.com/api"] + audiences = ["urn:example:cooperation"] + expected_token = "test-access-token" + + mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} + + result = await identity_client.get_token( + provider_name=provider_name, + scopes=scopes, + resources=resources, + audiences=audiences, + agent_identity_token=agent_identity_token, + auth_flow="ON_BEHALF_OF_TOKEN_EXCHANGE", + ) + + assert result == expected_token + mock_client.get_resource_oauth2_token.assert_called_once_with( + resourceCredentialProviderName=provider_name, + scopes=scopes, + oauth2Flow="ON_BEHALF_OF_TOKEN_EXCHANGE", + workloadIdentityToken=agent_identity_token, + resources=resources, + audiences=audiences, + ) + + @pytest.mark.asyncio + async def test_get_token_empty_resources_and_audiences_not_forwarded(self): + """Test get_token does not forward empty/None resources or audiences.""" + region = "us-west-2" + + with patch("boto3.client") as mock_boto_client: + mock_client = Mock() + mock_boto_client.return_value = mock_client + + identity_client = IdentityClient(region) + + provider_name = "test-provider" + scopes = ["read"] + agent_identity_token = "test-agent-token" + expected_token = "test-access-token" + + mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} + + # Pass empty list and None — neither should appear in the request + result = await identity_client.get_token( + provider_name=provider_name, + scopes=scopes, + resources=[], + audiences=None, + agent_identity_token=agent_identity_token, + auth_flow="M2M", + ) + + assert result == expected_token + call_kwargs = mock_client.get_resource_oauth2_token.call_args.kwargs + assert "resources" not in call_kwargs + assert "audiences" not in call_kwargs + + @pytest.mark.asyncio + @pytest.mark.parametrize("auth_flow", ["M2M", "USER_FEDERATION", "ON_BEHALF_OF_TOKEN_EXCHANGE"]) + async def test_get_token_all_auth_flows(self, auth_flow): + """Test get_token forwards each supported auth_flow to the data plane request.""" + region = "us-west-2" + + with patch("boto3.client") as mock_boto_client: + mock_client = Mock() + mock_boto_client.return_value = mock_client + + identity_client = IdentityClient(region) + + provider_name = "test-provider" + scopes = ["read"] + agent_identity_token = "test-agent-token" + expected_token = "test-access-token" + + mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} + + result = await identity_client.get_token( + provider_name=provider_name, + scopes=scopes, + agent_identity_token=agent_identity_token, + auth_flow=auth_flow, + ) + + assert result == expected_token + mock_client.get_resource_oauth2_token.assert_called_once_with( + resourceCredentialProviderName=provider_name, + scopes=scopes, + oauth2Flow=auth_flow, + workloadIdentityToken=agent_identity_token, + ) + @pytest.mark.asyncio async def test_get_api_key_success(self): """Test successful API key retrieval."""