From e0580f5ff516dff7722459e860f35f9c5bd02db0 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 20 Feb 2026 11:33:37 +0000 Subject: [PATCH 1/2] Added withFmi method for cca app --- msal/application.py | 26 +++++++ tests/test_application.py | 99 ++++++++++++++++++++++++ tests/test_fmi_e2e.py | 153 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 278 insertions(+) create mode 100644 tests/test_fmi_e2e.py diff --git a/msal/application.py b/msal/application.py index ba16df83..612ca4a8 100644 --- a/msal/application.py +++ b/msal/application.py @@ -2491,6 +2491,32 @@ def remove_tokens_for_client(self): self.token_cache.remove_at(at) # acquire_token_for_client() obtains no RTs, so we have no RT to remove + def acquire_token_for_client_with_fmi_path(self, scopes, fmi_path, claims_challenge=None, **kwargs): + """Acquires token for the current confidential client with a Federated Managed Identity (FMI) path. + + This is a convenience wrapper around :func:`~acquire_token_for_client` + that attaches the ``fmi_path`` parameter to the token request body. + + :param list[str] scopes: (Required) + Scopes requested to access a protected API (a resource). + :param str fmi_path: (Required) + The Federated Managed Identity path to attach to the request. + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :return: A dict representing the json response from Microsoft Entra: + + - A successful response would contain "access_token" key, + - an error response would contain "error" and usually "error_description". + """ + data = kwargs.pop("data", {}) + data["fmi_path"] = fmi_path + return self.acquire_token_for_client( + scopes, claims_challenge=claims_challenge, data=data, **kwargs) + def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs): """Acquires token using on-behalf-of (OBO) flow. diff --git a/tests/test_application.py b/tests/test_application.py index a31c8580..e1822cba 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -707,6 +707,105 @@ def test_organizations_authority_should_emit_warning(self): authority="https://login.microsoftonline.com/organizations") +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestAcquireTokenForClientWithFmiPath(unittest.TestCase): + """Test that acquire_token_for_client_with_fmi_path attaches fmi_path to HTTP body.""" + + def test_fmi_path_is_included_in_request_body(self): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + fmi_path = "SomeFmiPath/FmiCredentialPath" + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an AT", + "expires_in": 3600, + })) + + result = app.acquire_token_for_client_with_fmi_path( + ["scope"], fmi_path, post=mock_post) + self.assertIn("access_token", result) + self.assertIn("fmi_path", captured_data, + "fmi_path should be present in the HTTP request body") + self.assertEqual(fmi_path, captured_data["fmi_path"], + "fmi_path value should match the input") + + def test_fmi_path_coexists_with_other_data(self): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + fmi_path = "another/fmi/path" + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an AT", + "expires_in": 3600, + })) + + result = app.acquire_token_for_client_with_fmi_path( + ["scope"], fmi_path, post=mock_post) + self.assertIn("access_token", result) + self.assertEqual(fmi_path, captured_data["fmi_path"]) + self.assertEqual("client_credentials", captured_data.get("grant_type")) + + def test_fmi_path_preserves_existing_data_params(self): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + fmi_path = "my/fmi/path" + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an AT", + "expires_in": 3600, + })) + + result = app.acquire_token_for_client_with_fmi_path( + ["scope"], fmi_path, + data={"extra_key": "extra_value"}, + post=mock_post) + self.assertIn("access_token", result) + self.assertEqual(fmi_path, captured_data["fmi_path"]) + self.assertEqual("extra_value", captured_data.get("extra_key"), + "Pre-existing data params should be preserved") + + def test_cached_token_is_returned_on_second_call(self): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + fmi_path = "SomeFmiPath/FmiCredentialPath" + call_count = [0] + + def mock_post(url, headers=None, data=None, *args, **kwargs): + call_count[0] += 1 + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an AT", + "expires_in": 3600, + })) + + result1 = app.acquire_token_for_client_with_fmi_path( + ["scope"], fmi_path, post=mock_post) + self.assertIn("access_token", result1) + self.assertEqual(result1[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP) + + result2 = app.acquire_token_for_client_with_fmi_path( + ["scope"], fmi_path, post=mock_post) + self.assertIn("access_token", result2) + self.assertEqual(result2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE, + "Second call should return token from cache") + + @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) class TestRemoveTokensForClient(unittest.TestCase): def test_remove_tokens_for_client_should_remove_client_tokens_only(self): diff --git a/tests/test_fmi_e2e.py b/tests/test_fmi_e2e.py new file mode 100644 index 00000000..d0b0a6ab --- /dev/null +++ b/tests/test_fmi_e2e.py @@ -0,0 +1,153 @@ +"""End-to-end tests for Federated Managed Identity (FMI) functionality. + +These tests verify: +1. Tokens can be acquired using certificate authentication with FMI path +2. Tokens are properly cached and returned from cache on subsequent calls +3. Tokens can be acquired using an assertion callback (RMA pattern) with FMI path + +""" + +import logging +import os +import sys +import unittest + +import msal +from tests.http_client import MinimalHttpClient +from tests.lab_config import get_client_certificate +from tests.test_e2e import LabBasedTestCase + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG if "-v" in sys.argv else logging.INFO) + +# Test configuration +_FMI_TENANT_ID = "f645ad92-e38d-4d1a-b510-d1b09a74a8ca" +_FMI_CLIENT_ID = "4df2cbbb-8612-49c1-87c8-f334d6d065ad" +_FMI_SCOPE = "3091264c-7afb-45d4-b527-39737ee86187/.default" +_FMI_PATH = "SomeFmiPath/FmiCredentialPath" +_FMI_CLIENT_ID_URN = "urn:microsoft:identity:fmi" +_FMI_SCOPE_FOR_RMA = "api://AzureFMITokenExchange/.default" +_AUTHORITY_URL = "https://login.microsoftonline.com/" + _FMI_TENANT_ID + + +def _get_fmi_credential_from_rma(): + """Acquire an FMI token from RMA service using certificate credentials. + + This mirrors the Go function GetFmiCredentialFromRma: + 1. Create a confidential client with certificate credential + 2. Acquire a token for the FMI scope with the FMI path + 3. Return the access token as an assertion string + """ + + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID, + client_credential=get_client_certificate(), + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + result = app.acquire_token_for_client_with_fmi_path( + [_FMI_SCOPE_FOR_RMA], _FMI_PATH) + if "access_token" not in result: + raise RuntimeError( + "Failed to acquire FMI token from RMA: {}: {}".format( + result.get("error"), result.get("error_description"))) + return result["access_token"] + + +class TestFMIBasicFunctionality(LabBasedTestCase): + """Test basic FMI token acquisition with certificate credential. + + Mirrors TestFMIBasicFunctionality from Go: + 1. Acquire token by credential with FMI path + 2. Verify silent (cached) token acquisition works + 3. Validate tokens match (proving cache was used) + """ + + def test_acquire_and_cache_with_fmi_path(self): + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID, + client_credential=get_client_certificate(), + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + scopes = [_FMI_SCOPE] + + # 1. Acquire token by credential with FMI path + result = app.acquire_token_for_client_with_fmi_path(scopes, _FMI_PATH) + self.assertIn("access_token", result, + "acquire_token_for_client_with_fmi_path() failed: {}: {}".format( + result.get("error"), result.get("error_description"))) + self.assertNotEqual("", result["access_token"], + "acquire_token_for_client_with_fmi_path() returned empty access token") + + first_token = result["access_token"] + + # 2. Verify silent token acquisition works (should retrieve from cache) + cache_result = app.acquire_token_for_client_with_fmi_path(scopes, _FMI_PATH) + self.assertIn("access_token", cache_result, + "Second call failed: {}: {}".format( + cache_result.get("error"), cache_result.get("error_description"))) + self.assertNotEqual("", cache_result["access_token"], + "Second call returned empty access token") + self.assertEqual( + cache_result.get("token_source"), "cache", + "Second call should return token from cache") + + # 3. Validate tokens match (proving cache was used) + self.assertEqual(first_token, cache_result["access_token"], + "Token comparison failed - tokens don't match, " + "cache might not be working correctly") + +class TestFMIIntegration(LabBasedTestCase): + """Test FMI with assertion callback (RMA pattern). + + Mirrors TestFMIIntegration from Go: + 1. Get credentials from RMA via assertion callback + 2. Acquire token by credential with FMI path + 3. Verify cached token acquisition works + 4. Compare tokens to verify cache was used + """ + + def test_acquire_with_assertion_callback_and_fmi_path(self): + # Create credential from assertion callback (mirrors Go's NewCredFromAssertionCallback) + client_credential = { + "client_assertion": lambda: _get_fmi_credential_from_rma(), + } + + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID_URN, + client_credential=client_credential, + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + scopes = [_FMI_SCOPE] + fmi_path = "SomeFmiPath/Path" + + # 1. Acquire token by credential with FMI path + result = app.acquire_token_for_client_with_fmi_path(scopes, fmi_path) + self.assertIn("access_token", result, + "acquire_token_for_client_with_fmi_path() failed: {}: {}".format( + result.get("error"), result.get("error_description"))) + self.assertNotEqual("", result["access_token"], + "acquire_token_for_client_with_fmi_path() returned empty access token") + first_token = result["access_token"] + + # 2. Verify cached token acquisition works + cache_result = app.acquire_token_for_client_with_fmi_path(scopes, fmi_path) + self.assertIn("access_token", cache_result, + "Second call failed: {}: {}".format( + cache_result.get("error"), cache_result.get("error_description"))) + self.assertNotEqual("", cache_result["access_token"], + "Second call returned empty access token") + self.assertEqual( + cache_result.get("token_source"), "cache", + "Second call should return token from cache") + + # 3. Compare tokens to verify cache was used + self.assertEqual(first_token, cache_result["access_token"], + "Token comparison failed - tokens don't match, " + "cache might not be working correctly") + + +if __name__ == "__main__": + unittest.main() From 791161dd214803714cf1c7711681741f9d8e94ab Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 23 Feb 2026 10:58:27 +0000 Subject: [PATCH 2/2] Added Cache support for fmi keys --- msal/application.py | 5 +- msal/authority.py | 4 -- msal/token_cache.py | 83 +++++++++++++++++++++++- tests/test_application.py | 57 +++++++++++++++++ tests/test_ccs.py | 5 +- tests/test_fmi_e2e.py | 130 ++++++++++++++++++++++++++++++++++++++ tests/test_token_cache.py | 123 +++++++++++++++++++++++++++++++++++- 7 files changed, 397 insertions(+), 10 deletions(-) diff --git a/msal/application.py b/msal/application.py index 612ca4a8..c08e358e 100644 --- a/msal/application.py +++ b/msal/application.py @@ -15,7 +15,7 @@ from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request from .wstrust_response import * -from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER +from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER, _compute_ext_cache_key import msal.telemetry from .region import _detect_region from .throttled_http_client import ThrottledHttpClient @@ -1571,6 +1571,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( key_id = kwargs.get("data", {}).get("key_id") if key_id: # Some token types (SSH-certs, POP) are bound to a key query["key_id"] = key_id + ext_cache_key = _compute_ext_cache_key(kwargs.get("data", {})) + if ext_cache_key: # FMI tokens need cache isolation by path + query["ext_cache_key"] = ext_cache_key now = time.time() refresh_reason = msal.telemetry.AT_ABSENT for entry in self.token_cache.search( # A generator allows us to diff --git a/msal/authority.py b/msal/authority.py index b114831f..4a3a56ee 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -92,11 +92,9 @@ def __init__( self._http_client = http_client self._oidc_authority_url = oidc_authority_url if oidc_authority_url: - logger.debug("Initializing with OIDC authority: %s", oidc_authority_url) tenant_discovery_endpoint = self._initialize_oidc_authority( oidc_authority_url) else: - logger.debug("Initializing with Entra authority: %s", authority_url) tenant_discovery_endpoint = self._initialize_entra_authority( authority_url, validate_authority, instance_discovery) try: @@ -117,8 +115,6 @@ def __init__( .format(authority_url) ) + " Also please double check your tenant name or GUID is correct." raise ValueError(error_message) - logger.debug( - 'openid_config("%s") = %s', tenant_discovery_endpoint, openid_config) self._issuer = openid_config.get('issuer') self.authorization_endpoint = openid_config['authorization_endpoint'] self.token_endpoint = openid_config['token_endpoint'] diff --git a/msal/token_cache.py b/msal/token_cache.py index 846c8132..b32fdfef 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -1,4 +1,6 @@ -import json +import base64 +import hashlib +import json import threading import time import logging @@ -12,6 +14,63 @@ logger = logging.getLogger(__name__) _GRANT_TYPE_BROKER = "broker" +# Fields in the request data dict that should NOT be included in the extended +# cache key hash. Everything else in data IS included, because those are extra +# body parameters going on the wire and must differentiate cached tokens. +# +# Excluded fields and reasons: +# - "key_id" : Already handled as a separate cache lookup field +# - "token_type" : Used for SSH-cert/POP detection; AT entry stores it separately +# - "req_cnf" : Ephemeral proof-of-possession nonce, changes per request +# - "claims" : Handled separately; its presence forces a token refresh +# - "scope" : Already represented as "target" in the AT cache key; +# also added to data only at wire-time, not at cache-lookup time +# - "username" : Standard ROPC grant parameter, not an extra body parameter +# - "password" : Standard ROPC grant parameter, not an extra body parameter +# +# Included fields (examples — anything NOT in this set is included): +# - "fmi_path" : Federated Managed Identity credential path +# - any future extra body parameter that should isolate cache entries +_EXT_CACHE_KEY_EXCLUDED_FIELDS = frozenset({ + "key_id", + "token_type", + "req_cnf", + "claims", + "scope", + "username", + "password", +}) + + +def _compute_ext_cache_key(data): + """Compute an extended cache key hash from extra body parameters in *data*. + + All fields in *data* that go on the wire are included in the hash, + EXCEPT those listed in ``_EXT_CACHE_KEY_EXCLUDED_FIELDS``. + This ensures tokens acquired with different parameter values + (e.g., different FMI paths) are cached separately. + + Returns an empty string when *data* has no hashable fields. + + The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator): + sorted key+value pairs are concatenated and SHA256 hashed, then base64url encoded. + """ + if not data: + return "" + cache_components = { + k: str(v) for k, v in data.items() + if k not in _EXT_CACHE_KEY_EXCLUDED_FIELDS and v + } + if not cache_components: + return "" + # Sort keys for consistent hashing (matches Go implementation) + key_str = "".join( + k + cache_components[k] for k in sorted(cache_components.keys()) + ) + hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest() + return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower() + + def is_subdict_of(small, big): return dict(big, **small) == big @@ -59,6 +118,7 @@ def __init__(self): self.CredentialType.ACCESS_TOKEN: lambda home_account_id=None, environment=None, client_id=None, realm=None, target=None, + ext_cache_key=None, # Note: New field(s) can be added here #key_id=None, **ignored_payload_from_a_real_token: @@ -70,7 +130,8 @@ def __init__(self): realm or "", target or "", #key_id or "", # So ATs of different key_id can coexist - ]).lower(), + ] + ([ext_cache_key] if ext_cache_key else []) + ).lower(), self.CredentialType.ID_TOKEN: lambda home_account_id=None, environment=None, client_id=None, realm=None, **ignored_payload_from_a_real_token: @@ -98,6 +159,7 @@ def __init__(self): def _get_access_token( self, home_account_id, environment, client_id, realm, target, # Together they form a compound key + ext_cache_key=None, default=None, ): # O(1) return self._get( @@ -108,6 +170,7 @@ def _get_access_token( client_id=client_id, realm=realm, target=" ".join(target), + ext_cache_key=ext_cache_key, ), default=default) @@ -153,7 +216,8 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n) ): # Special case for O(1) AT lookup preferred_result = self._get_access_token( query["home_account_id"], query["environment"], - query["client_id"], query["realm"], target) + query["client_id"], query["realm"], target, + ext_cache_key=query.get("ext_cache_key")) if preferred_result and self._is_matching( preferred_result, query, # Needs no target_set here because it is satisfied by dict key @@ -179,6 +243,13 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n) if (entry != preferred_result # Avoid yielding the same entry twice and self._is_matching(entry, query, target_set=target_set) ): + # Cache isolation for extended cache keys (e.g., FMI path). + # Entries with ext_cache_key must not match queries without one. + if (credential_type == self.CredentialType.ACCESS_TOKEN + and "ext_cache_key" in entry + and "ext_cache_key" not in (query or {}) + ): + continue yield entry for at in expired_access_tokens: self.remove_at(at) @@ -278,6 +349,12 @@ def __add(self, event, now=None): # So that we won't accidentally store a user's password etc. "key_id", # It happens in SSH-cert or POP scenario }}) + # Compute and store extended cache key for cache isolation + # (e.g., different FMI paths should have separate cache entries) + ext_cache_key = _compute_ext_cache_key(data) + + if ext_cache_key: + at["ext_cache_key"] = ext_cache_key if "refresh_in" in response: refresh_in = response["refresh_in"] # It is an integer at["refresh_on"] = str(now + refresh_in) # Schema wants a string diff --git a/tests/test_application.py b/tests/test_application.py index e1822cba..040ecdff 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -805,6 +805,63 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): self.assertEqual(result2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE, "Second call should return token from cache") + def test_different_fmi_paths_are_cached_separately(self): + """Tokens acquired with different fmi_path values must NOT share cache entries.""" + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + + def mock_post_factory(token_value): + def mock_post(url, headers=None, data=None, *args, **kwargs): + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": token_value, + "expires_in": 3600, + })) + return mock_post + + # Acquire token with path A + result_a = app.acquire_token_for_client_with_fmi_path( + ["scope"], "PathA/credential", post=mock_post_factory("AT_for_path_A")) + self.assertEqual("AT_for_path_A", result_a["access_token"]) + + # Acquire token with path B (should NOT get path A's cached token) + result_b = app.acquire_token_for_client_with_fmi_path( + ["scope"], "PathB/credential", post=mock_post_factory("AT_for_path_B")) + self.assertEqual("AT_for_path_B", result_b["access_token"]) + self.assertEqual(result_b[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP, + "Different FMI path should NOT return a cached token from another path") + + # Verify path A still returns its own cached token + result_a2 = app.acquire_token_for_client_with_fmi_path( + ["scope"], "PathA/credential", post=mock_post_factory("should_not_be_used")) + self.assertEqual("AT_for_path_A", result_a2["access_token"]) + self.assertEqual(result_a2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE, + "Same FMI path should return cached token") + + def test_fmi_token_does_not_interfere_with_non_fmi_token(self): + """FMI-cached tokens must not be returned for non-FMI acquire_token_for_client.""" + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + + # First, cache a token via FMI path + app.acquire_token_for_client_with_fmi_path( + ["scope"], "some/fmi/path", + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "FMI_AT", "expires_in": 3600}))) + + # Now call regular acquire_token_for_client — should NOT get FMI token + result = app.acquire_token_for_client( + ["scope"], + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "regular_AT", "expires_in": 3600}))) + self.assertEqual("regular_AT", result["access_token"]) + self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP, + "Non-FMI call should not return FMI-cached token") + @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) class TestRemoveTokensForClient(unittest.TestCase): diff --git a/tests/test_ccs.py b/tests/test_ccs.py index 8b801773..9bbc2787 100644 --- a/tests/test_ccs.py +++ b/tests/test_ccs.py @@ -61,11 +61,14 @@ def test_acquire_token_silent(self): "CSS routing info should be derived from home_account_id") def test_acquire_token_by_username_password(self): + import warnings app = msal.ClientApplication("client_id") username = "johndoe@contoso.com" with patch.object(app.http_client, "post", return_value=MinimalResponse( status_code=400, text='{"error": "mock"}')) as mocked_method: - app.acquire_token_by_username_password(username, "password", ["scope"]) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + app.acquire_token_by_username_password(username, "password", ["scope"]) self.assertEqual( "upn:" + username, mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), diff --git a/tests/test_fmi_e2e.py b/tests/test_fmi_e2e.py index d0b0a6ab..6ba2f09f 100644 --- a/tests/test_fmi_e2e.py +++ b/tests/test_fmi_e2e.py @@ -149,5 +149,135 @@ def test_acquire_with_assertion_callback_and_fmi_path(self): "cache might not be working correctly") +class TestFMICacheIsolation(LabBasedTestCase): + """Test that tokens acquired with different FMI paths are cached separately. + + This verifies the cache key extensibility: two calls with different fmi_path + values should NOT return each other's cached tokens. + """ + + def test_different_fmi_paths_are_cached_separately(self): + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID, + client_credential=get_client_certificate(), + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + scopes = [_FMI_SCOPE] + + # Acquire token with path A + result_a = app.acquire_token_for_client_with_fmi_path( + scopes, "PathA/credential") + self.assertIn("access_token", result_a, + "Path A acquisition failed: {}: {}".format( + result_a.get("error"), result_a.get("error_description"))) + + # Acquire token with path B — should NOT get path A's cached token + result_b = app.acquire_token_for_client_with_fmi_path( + scopes, "PathB/credential") + self.assertIn("access_token", result_b, + "Path B acquisition failed: {}: {}".format( + result_b.get("error"), result_b.get("error_description"))) + self.assertNotEqual( + result_b.get("token_source"), "cache", + "Different FMI path should NOT return cached token from another path") + + # Verify path A still returns its own cached token + result_a2 = app.acquire_token_for_client_with_fmi_path( + scopes, "PathA/credential") + self.assertIn("access_token", result_a2) + self.assertEqual( + result_a2.get("token_source"), "cache", + "Same FMI path should return cached token") + self.assertEqual(result_a["access_token"], result_a2["access_token"]) + + def test_fmi_token_does_not_interfere_with_non_fmi_token(self): + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID, + client_credential=get_client_certificate(), + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + scopes = [_FMI_SCOPE] + + # Cache a token via FMI path + fmi_result = app.acquire_token_for_client_with_fmi_path(scopes, _FMI_PATH) + self.assertIn("access_token", fmi_result) + + # Regular acquire_token_for_client should NOT get the FMI token + regular_result = app.acquire_token_for_client(scopes) + self.assertIn("access_token", regular_result, + "Regular call failed: {}: {}".format( + regular_result.get("error"), regular_result.get("error_description"))) + self.assertNotEqual( + regular_result.get("token_source"), "cache", + "Non-FMI call should not return FMI-cached token") + + +class TestFMICacheInspection(LabBasedTestCase): + """Acquire tokens with two different FMI paths and inspect the underlying + cache to verify the entries are correctly isolated.""" + + def test_two_fmi_paths_produce_separate_cache_entries(self): + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID, + client_credential=get_client_certificate(), + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + scopes = [_FMI_SCOPE] + path_a = "PathAlpha/Credential" + path_b = "PathBeta/Credential" + + # 1. Acquire token with path A + result_a = app.acquire_token_for_client_with_fmi_path(scopes, path_a) + self.assertIn("access_token", result_a, + "Path A acquisition failed: {}: {}".format( + result_a.get("error"), result_a.get("error_description"))) + token_a = result_a["access_token"] + + # 2. Acquire token with path B + result_b = app.acquire_token_for_client_with_fmi_path(scopes, path_b) + self.assertIn("access_token", result_b, + "Path B acquisition failed: {}: {}".format( + result_b.get("error"), result_b.get("error_description"))) + token_b = result_b["access_token"] + + # Tokens should be different (different paths go to different resources) + self.assertNotEqual(token_a, token_b, + "Tokens for different FMI paths should differ") + + # 3. Inspect cache: there should be exactly 2 AccessToken entries + cache = app.token_cache._cache + at_entries = cache.get("AccessToken", {}) + # Filter to our client_id + scope to avoid noise + our_entries = { + k: v for k, v in at_entries.items() + if v.get("client_id") == _FMI_CLIENT_ID + and _FMI_SCOPE.split("/")[0] in v.get("target", "") + } + self.assertEqual(2, len(our_entries), + "Cache should contain exactly 2 AT entries for our client, " + "got {}: {}".format(len(our_entries), list(our_entries.keys()))) + + # 4. Each entry must have a non-empty ext_cache_key, and they must differ + ext_keys = [v.get("ext_cache_key") for v in our_entries.values()] + for ek in ext_keys: + self.assertTrue(ek, "Each FMI cache entry must have a non-empty ext_cache_key") + self.assertNotEqual(ext_keys[0], ext_keys[1], + "ext_cache_key values for different FMI paths must differ") + + # 5. Verify each path still returns its own cached token + cached_a = app.acquire_token_for_client_with_fmi_path(scopes, path_a) + self.assertEqual("cache", cached_a.get("token_source")) + self.assertEqual(token_a, cached_a["access_token"], + "Path A should return its own cached token") + + cached_b = app.acquire_token_for_client_with_fmi_path(scopes, path_b) + self.assertEqual("cache", cached_b.get("token_source")) + self.assertEqual(token_b, cached_b["access_token"], + "Path B should return its own cached token") + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 5310b789..8c464760 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -4,7 +4,7 @@ import time import warnings -from msal.token_cache import TokenCache, SerializableTokenCache +from msal.token_cache import TokenCache, SerializableTokenCache, _compute_ext_cache_key from tests import unittest @@ -321,3 +321,124 @@ def tearDown(self): output.get("AccessToken", {}).get("an-entry"), {"foo": "bar"}, "Undefined token keys and their values should be intact") + +class TestComputeExtCacheKey(unittest.TestCase): + """Tests for the _compute_ext_cache_key hash function.""" + + def test_empty_data_returns_empty_string(self): + self.assertEqual("", _compute_ext_cache_key(None)) + self.assertEqual("", _compute_ext_cache_key({})) + + def test_excluded_fields_are_ignored(self): + self.assertEqual("", _compute_ext_cache_key({"key_id": "k1", "token_type": "ssh-cert", "req_cnf": "nonce", "claims": "{}"}), + "Fields in _EXT_CACHE_KEY_EXCLUDED_FIELDS should produce an empty hash") + + def test_fmi_path_produces_non_empty_hash(self): + result = _compute_ext_cache_key({"fmi_path": "SomePath/Credential"}) + self.assertNotEqual("", result) + self.assertIsInstance(result, str) + + def test_same_input_produces_same_hash(self): + h1 = _compute_ext_cache_key({"fmi_path": "path/a"}) + h2 = _compute_ext_cache_key({"fmi_path": "path/a"}) + self.assertEqual(h1, h2) + + def test_different_fmi_paths_produce_different_hashes(self): + h1 = _compute_ext_cache_key({"fmi_path": "path/a"}) + h2 = _compute_ext_cache_key({"fmi_path": "path/b"}) + self.assertNotEqual(h1, h2) + + def test_empty_fmi_path_value_is_ignored(self): + self.assertEqual("", _compute_ext_cache_key({"fmi_path": ""})) + + def test_excluded_fields_dont_affect_hash(self): + h1 = _compute_ext_cache_key({"fmi_path": "path/a"}) + h2 = _compute_ext_cache_key({"fmi_path": "path/a", "key_id": "k1", "req_cnf": "nonce"}) + self.assertEqual(h1, h2, "Excluded fields should not affect the hash") + + def test_non_excluded_fields_are_included_in_hash(self): + h1 = _compute_ext_cache_key({"fmi_path": "path/a"}) + h2 = _compute_ext_cache_key({"fmi_path": "path/a", "custom_param": "val"}) + self.assertNotEqual(h1, h2, "Non-excluded fields should change the hash") + + +class TestExtCacheKeyIsolation(unittest.TestCase): + """Tests that ext_cache_key provides proper cache isolation in TokenCache.""" + + def _build_event(self, client_id, scope, token_endpoint, access_token, data=None, **kwargs): + return { + "client_id": client_id, + "scope": scope, + "token_endpoint": token_endpoint, + "response": build_response(access_token=access_token, expires_in=3600), + "data": data or {}, + **kwargs, + } + + def test_at_key_includes_ext_cache_key_when_present(self): + cache = TokenCache() + key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] + key_without = key_maker( + home_account_id="", environment="env", client_id="cid", + realm="realm", target="scope") + key_with = key_maker( + home_account_id="", environment="env", client_id="cid", + realm="realm", target="scope", ext_cache_key="somehash") + self.assertNotEqual(key_without, key_with, + "Keys with and without ext_cache_key should differ") + self.assertIn("somehash", key_with) + + def test_different_ext_cache_keys_produce_different_at_keys(self): + cache = TokenCache() + key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] + key_a = key_maker( + home_account_id="", environment="env", client_id="cid", + realm="realm", target="scope", ext_cache_key="hash_a") + key_b = key_maker( + home_account_id="", environment="env", client_id="cid", + realm="realm", target="scope", ext_cache_key="hash_b") + self.assertNotEqual(key_a, key_b) + + def test_fmi_tokens_are_stored_with_ext_cache_key(self): + cache = TokenCache() + event = self._build_event( + "cid", ["s1"], "https://login.example.com/tenant/v2/token", + "fmi_token", data={"fmi_path": "some/path"}) + cache.add(event) + at_entries = list(cache.search(TokenCache.CredentialType.ACCESS_TOKEN, target=["s1"])) + self.assertEqual(0, len(at_entries), + "FMI tokens should NOT be found by a query without ext_cache_key") + + def test_fmi_tokens_found_with_matching_ext_cache_key_query(self): + cache = TokenCache() + ext_key = _compute_ext_cache_key({"fmi_path": "some/path"}) + event = self._build_event( + "cid", ["s1"], "https://login.example.com/tenant/v2/token", + "fmi_token", data={"fmi_path": "some/path"}) + cache.add(event) + at_entries = list(cache.search( + TokenCache.CredentialType.ACCESS_TOKEN, target=["s1"], + query={"client_id": "cid", "environment": "login.example.com", + "realm": "tenant", "home_account_id": None, + "ext_cache_key": ext_key})) + self.assertEqual(1, len(at_entries)) + self.assertEqual("fmi_token", at_entries[0]["secret"]) + + def test_non_fmi_tokens_not_affected_by_fmi_cache(self): + cache = TokenCache() + # Add FMI token + cache.add(self._build_event( + "cid", ["s1"], "https://login.example.com/tenant/v2/token", + "fmi_token", data={"fmi_path": "some/path"})) + # Add regular token + cache.add(self._build_event( + "cid", ["s1"], "https://login.example.com/tenant/v2/token", + "regular_token")) + # Search without ext_cache_key should find only regular token + at_entries = list(cache.search( + TokenCache.CredentialType.ACCESS_TOKEN, target=["s1"], + query={"client_id": "cid", "environment": "login.example.com", + "realm": "tenant", "home_account_id": None})) + self.assertEqual(1, len(at_entries)) + self.assertEqual("regular_token", at_entries[0]["secret"]) +