Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/bedrock_agentcore/identity/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ def requires_access_token(
provider_name: str,
into: str = "access_token",
scopes: List[str],
resources: Optional[List[str]] = None,
audiences: Optional[List[str]] = None,
on_auth_url: Optional[Callable[[str], Any]] = None,
auth_flow: Literal["M2M", "USER_FEDERATION"],
auth_flow: Literal["M2M", "USER_FEDERATION", "ON_BEHALF_OF_TOKEN_EXCHANGE"],
callback_url: Optional[str] = None,
force_authentication: bool = False,
token_poller: Optional[TokenPoller] = None,
Expand All @@ -38,8 +40,10 @@ def requires_access_token(
provider_name: The credential provider name
into: Parameter name to inject the token into
scopes: OAuth2 scopes to request
resources: OAuth2 resources to request
audiences: OAuth2 audiences to request
on_auth_url: Callback for handling authorization URLs
auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION")
auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION" or "ON_BEHALF_OF_TOKEN_EXCHANGE")
callback_url: OAuth2 callback URL
force_authentication: Force re-authentication
token_poller: Custom token poller implementation
Expand All @@ -60,6 +64,8 @@ async def _get_token() -> str:
provider_name=provider_name,
agent_identity_token=await _get_workload_access_token(client),
scopes=scopes,
resources=resources,
audiences=audiences,
on_auth_url=on_auth_url,
auth_flow=auth_flow,
callback_url=_get_oauth2_callback_url(callback_url),
Expand Down
12 changes: 10 additions & 2 deletions src/bedrock_agentcore/services/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,11 @@ async def get_token(
*,
provider_name: str,
scopes: Optional[List[str]] = None,
resources: Optional[List[str]] = None,
audiences: Optional[List[str]] = None,
agent_identity_token: str,
on_auth_url: Optional[Callable[[str], Any]] = None,
auth_flow: Literal["M2M", "USER_FEDERATION"],
auth_flow: Literal["M2M", "USER_FEDERATION", "ON_BEHALF_OF_TOKEN_EXCHANGE"],
callback_url: Optional[str] = None,
force_authentication: bool = False,
token_poller: Optional[TokenPoller] = None,
Expand All @@ -216,9 +218,11 @@ async def get_token(
Args:
provider_name: The credential provider name
scopes: Optional list of OAuth2 scopes to request
resources: Optional list of OAuth2 resources to request
audiences: Optional list of OAuth2 audiences to request
agent_identity_token: Agent identity token for authentication
on_auth_url: Callback for handling authorization URLs
auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION")
auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION" or "ON_BEHALF_OF_TOKEN_EXCHANGE")
callback_url: OAuth2 callback URL (must be pre-registered)
force_authentication: Force re-authentication even if token exists in the token vault
token_poller: Custom token poller implementation
Expand All @@ -244,6 +248,10 @@ async def get_token(
}

# Add optional parameters
if resources:
req["resources"] = resources
if audiences:
req["audiences"] = audiences
if callback_url:
req["resourceOauth2ReturnUrl"] = callback_url
if force_authentication:
Expand Down
56 changes: 56 additions & 0 deletions tests/bedrock_agentcore/identity/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ async def test_async_func(param1, access_token=None):
provider_name="test-provider",
agent_identity_token="test-agent-token",
scopes=["read", "write"],
resources=None,
audiences=None,
on_auth_url=None,
auth_flow="M2M",
callback_url=None,
Expand Down Expand Up @@ -163,6 +165,8 @@ async def test_func(param1, my_token=None):
provider_name="test-provider",
agent_identity_token="test-agent-token",
scopes=["read"],
resources=None,
audiences=None,
on_auth_url=None,
auth_flow="M2M",
callback_url=None,
Expand Down Expand Up @@ -205,6 +209,8 @@ def on_auth_url(url):
provider_name="test-provider",
into="token",
scopes=["read", "write"],
resources=["https://backend.example.com/api1", "https://backend.example.com/api2"],
audiences=["urn:example:cooperation"],
on_auth_url=on_auth_url,
auth_flow="USER_FEDERATION",
callback_url="https://example.com/callback",
Expand All @@ -223,6 +229,8 @@ async def test_func(token=None):
provider_name="test-provider",
agent_identity_token="test-agent-token",
scopes=["read", "write"],
resources=["https://backend.example.com/api1", "https://backend.example.com/api2"],
audiences=["urn:example:cooperation"],
on_auth_url=on_auth_url,
auth_flow="USER_FEDERATION",
callback_url="https://example.com/callback",
Expand Down Expand Up @@ -250,6 +258,8 @@ async def test_custom_parameters_passed_to_client(self):
@requires_access_token(
provider_name="test-provider",
scopes=["read"],
resources=None,
audiences=None,
auth_flow="USER_FEDERATION",
custom_parameters=custom_params,
)
Expand All @@ -263,6 +273,8 @@ async def test_func(access_token=None):
provider_name="test-provider",
agent_identity_token="test-agent-token",
scopes=["read"],
resources=None,
audiences=None,
auth_flow="USER_FEDERATION",
callback_url=None,
force_authentication=False,
Expand All @@ -272,6 +284,50 @@ async def test_func(access_token=None):
custom_parameters=custom_params,
)

@pytest.mark.asyncio
@pytest.mark.parametrize("auth_flow", ["M2M", "USER_FEDERATION", "ON_BEHALF_OF_TOKEN_EXCHANGE"])
async def test_all_auth_flows(self, auth_flow):
"""Test decorator forwards each supported auth_flow value to the identity client."""
with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class:
mock_client = Mock()
mock_identity_client_class.return_value = mock_client

with patch(
"bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock
) as mock_get_agent_token:
mock_get_agent_token.return_value = "test-agent-token"
mock_client.get_token = AsyncMock(return_value="test-access-token")

with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"):

@requires_access_token(
provider_name="test-provider",
scopes=["read"],
resources=["https://backend.example.com/api"],
audiences=["urn:example:audience"],
auth_flow=auth_flow,
)
async def test_func(access_token=None):
return access_token

result = await test_func()

assert result == "test-access-token"
mock_client.get_token.assert_called_once_with(
provider_name="test-provider",
agent_identity_token="test-agent-token",
scopes=["read"],
resources=["https://backend.example.com/api"],
audiences=["urn:example:audience"],
on_auth_url=None,
auth_flow=auth_flow,
callback_url=None,
force_authentication=False,
token_poller=None,
custom_state=None,
custom_parameters=None,
)


class TestRequiresIamAccessTokenDecorator:
"""Test the requires_iam_access_token decorator."""
Expand Down
178 changes: 178 additions & 0 deletions tests/bedrock_agentcore/services/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,184 @@ async def test_get_token_with_custom_parameters(self):
customParameters=custom_parameters,
)

@pytest.mark.asyncio
async def test_get_token_with_resources(self):
"""Test get_token forwards the resources list to the data plane request."""
region = "us-west-2"

with patch("boto3.client") as mock_boto_client:
mock_client = Mock()
mock_boto_client.return_value = mock_client

identity_client = IdentityClient(region)

provider_name = "test-provider"
scopes = ["read"]
agent_identity_token = "test-agent-token"
resources = ["https://backend.example.com/api1", "https://backend.example.com/api2"]
expected_token = "test-access-token"

mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token}

result = await identity_client.get_token(
provider_name=provider_name,
scopes=scopes,
resources=resources,
agent_identity_token=agent_identity_token,
auth_flow="M2M",
)

assert result == expected_token
mock_client.get_resource_oauth2_token.assert_called_once_with(
resourceCredentialProviderName=provider_name,
scopes=scopes,
oauth2Flow="M2M",
workloadIdentityToken=agent_identity_token,
resources=resources,
)

@pytest.mark.asyncio
async def test_get_token_with_audiences(self):
"""Test get_token forwards the audiences list to the data plane request."""
region = "us-west-2"

with patch("boto3.client") as mock_boto_client:
mock_client = Mock()
mock_boto_client.return_value = mock_client

identity_client = IdentityClient(region)

provider_name = "test-provider"
scopes = ["read"]
agent_identity_token = "test-agent-token"
audiences = ["urn:example:cooperation", "urn:example:other"]
expected_token = "test-access-token"

mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token}

result = await identity_client.get_token(
provider_name=provider_name,
scopes=scopes,
audiences=audiences,
agent_identity_token=agent_identity_token,
auth_flow="M2M",
)

assert result == expected_token
mock_client.get_resource_oauth2_token.assert_called_once_with(
resourceCredentialProviderName=provider_name,
scopes=scopes,
oauth2Flow="M2M",
workloadIdentityToken=agent_identity_token,
audiences=audiences,
)

@pytest.mark.asyncio
async def test_get_token_with_resources_and_audiences(self):
"""Test get_token forwards both resources and audiences together."""
region = "us-west-2"

with patch("boto3.client") as mock_boto_client:
mock_client = Mock()
mock_boto_client.return_value = mock_client

identity_client = IdentityClient(region)

provider_name = "test-provider"
scopes = ["read", "write"]
agent_identity_token = "test-agent-token"
resources = ["https://backend.example.com/api"]
audiences = ["urn:example:cooperation"]
expected_token = "test-access-token"

mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token}

result = await identity_client.get_token(
provider_name=provider_name,
scopes=scopes,
resources=resources,
audiences=audiences,
agent_identity_token=agent_identity_token,
auth_flow="ON_BEHALF_OF_TOKEN_EXCHANGE",
)

assert result == expected_token
mock_client.get_resource_oauth2_token.assert_called_once_with(
resourceCredentialProviderName=provider_name,
scopes=scopes,
oauth2Flow="ON_BEHALF_OF_TOKEN_EXCHANGE",
workloadIdentityToken=agent_identity_token,
resources=resources,
audiences=audiences,
)

@pytest.mark.asyncio
async def test_get_token_empty_resources_and_audiences_not_forwarded(self):
"""Test get_token does not forward empty/None resources or audiences."""
region = "us-west-2"

with patch("boto3.client") as mock_boto_client:
mock_client = Mock()
mock_boto_client.return_value = mock_client

identity_client = IdentityClient(region)

provider_name = "test-provider"
scopes = ["read"]
agent_identity_token = "test-agent-token"
expected_token = "test-access-token"

mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token}

# Pass empty list and None — neither should appear in the request
result = await identity_client.get_token(
provider_name=provider_name,
scopes=scopes,
resources=[],
audiences=None,
agent_identity_token=agent_identity_token,
auth_flow="M2M",
)

assert result == expected_token
call_kwargs = mock_client.get_resource_oauth2_token.call_args.kwargs
assert "resources" not in call_kwargs
assert "audiences" not in call_kwargs

@pytest.mark.asyncio
@pytest.mark.parametrize("auth_flow", ["M2M", "USER_FEDERATION", "ON_BEHALF_OF_TOKEN_EXCHANGE"])
async def test_get_token_all_auth_flows(self, auth_flow):
"""Test get_token forwards each supported auth_flow to the data plane request."""
region = "us-west-2"

with patch("boto3.client") as mock_boto_client:
mock_client = Mock()
mock_boto_client.return_value = mock_client

identity_client = IdentityClient(region)

provider_name = "test-provider"
scopes = ["read"]
agent_identity_token = "test-agent-token"
expected_token = "test-access-token"

mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token}

result = await identity_client.get_token(
provider_name=provider_name,
scopes=scopes,
agent_identity_token=agent_identity_token,
auth_flow=auth_flow,
)

assert result == expected_token
mock_client.get_resource_oauth2_token.assert_called_once_with(
resourceCredentialProviderName=provider_name,
scopes=scopes,
oauth2Flow=auth_flow,
workloadIdentityToken=agent_identity_token,
)

@pytest.mark.asyncio
async def test_get_api_key_success(self):
"""Test successful API key retrieval."""
Expand Down
Loading