Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .mex import send_request as mex_send_request
from .wstrust_request import send_request as wst_send_request
from .wstrust_response import *
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER, _compute_ext_cache_key
import msal.telemetry
from .region import _detect_region
from .throttled_http_client import ThrottledHttpClient
Expand Down Expand Up @@ -1571,6 +1571,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
key_id = kwargs.get("data", {}).get("key_id")
if key_id: # Some token types (SSH-certs, POP) are bound to a key
query["key_id"] = key_id
ext_cache_key = _compute_ext_cache_key(kwargs.get("data", {}))
if ext_cache_key: # FMI tokens need cache isolation by path
query["ext_cache_key"] = ext_cache_key
now = time.time()
refresh_reason = msal.telemetry.AT_ABSENT
for entry in self.token_cache.search( # A generator allows us to
Expand Down Expand Up @@ -2491,6 +2494,32 @@ def remove_tokens_for_client(self):
self.token_cache.remove_at(at)
# acquire_token_for_client() obtains no RTs, so we have no RT to remove

def acquire_token_for_client_with_fmi_path(self, scopes, fmi_path, claims_challenge=None, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have FMI path as one of the kwargs instead? Creating an entire new API seems pretty heavy.

"""Acquires token for the current confidential client with a Federated Managed Identity (FMI) path.

This is a convenience wrapper around :func:`~acquire_token_for_client`
that attaches the ``fmi_path`` parameter to the token request body.

:param list[str] scopes: (Required)
Scopes requested to access a protected API (a resource).
:param str fmi_path: (Required)
The Federated Managed Identity path to attach to the request.
:param claims_challenge:
The claims_challenge parameter requests specific claims requested by the resource provider
in the form of a claims_challenge directive in the www-authenticate header to be
returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token.
It is a string of a JSON object which contains lists of claims being requested from these locations.

:return: A dict representing the json response from Microsoft Entra:

- A successful response would contain "access_token" key,
- an error response would contain "error" and usually "error_description".
"""
data = kwargs.pop("data", {})
data["fmi_path"] = fmi_path
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the tokens cached like in MSAL .NET and MSAL GO? I don't think they are.

I see a PR here that look at proper caching #759 but I am not sure it is compliant with the rest.

return self.acquire_token_for_client(
scopes, claims_challenge=claims_challenge, data=data, **kwargs)

def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs):
"""Acquires token using on-behalf-of (OBO) flow.

Expand Down
4 changes: 0 additions & 4 deletions msal/authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,9 @@ def __init__(
self._http_client = http_client
self._oidc_authority_url = oidc_authority_url
if oidc_authority_url:
logger.debug("Initializing with OIDC authority: %s", oidc_authority_url)
tenant_discovery_endpoint = self._initialize_oidc_authority(
oidc_authority_url)
else:
logger.debug("Initializing with Entra authority: %s", authority_url)
tenant_discovery_endpoint = self._initialize_entra_authority(
authority_url, validate_authority, instance_discovery)
try:
Expand All @@ -117,8 +115,6 @@ def __init__(
.format(authority_url)
) + " Also please double check your tenant name or GUID is correct."
raise ValueError(error_message)
logger.debug(
'openid_config("%s") = %s', tenant_discovery_endpoint, openid_config)
self._issuer = openid_config.get('issuer')
self.authorization_endpoint = openid_config['authorization_endpoint']
self.token_endpoint = openid_config['token_endpoint']
Expand Down
83 changes: 80 additions & 3 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import base64
import hashlib
import json
import threading
import time
import logging
Expand All @@ -12,6 +14,63 @@
logger = logging.getLogger(__name__)
_GRANT_TYPE_BROKER = "broker"

# Fields in the request data dict that should NOT be included in the extended
# cache key hash. Everything else in data IS included, because those are extra
# body parameters going on the wire and must differentiate cached tokens.
#
# Excluded fields and reasons:
# - "key_id" : Already handled as a separate cache lookup field
# - "token_type" : Used for SSH-cert/POP detection; AT entry stores it separately
# - "req_cnf" : Ephemeral proof-of-possession nonce, changes per request
# - "claims" : Handled separately; its presence forces a token refresh
# - "scope" : Already represented as "target" in the AT cache key;
# also added to data only at wire-time, not at cache-lookup time
# - "username" : Standard ROPC grant parameter, not an extra body parameter
# - "password" : Standard ROPC grant parameter, not an extra body parameter
#
# Included fields (examples — anything NOT in this set is included):
# - "fmi_path" : Federated Managed Identity credential path
# - any future extra body parameter that should isolate cache entries
_EXT_CACHE_KEY_EXCLUDED_FIELDS = frozenset({
"key_id",
"token_type",
"req_cnf",
"claims",
"scope",
"username",
"password",
})


def _compute_ext_cache_key(data):
"""Compute an extended cache key hash from extra body parameters in *data*.

All fields in *data* that go on the wire are included in the hash,
EXCEPT those listed in ``_EXT_CACHE_KEY_EXCLUDED_FIELDS``.
This ensures tokens acquired with different parameter values
(e.g., different FMI paths) are cached separately.

Returns an empty string when *data* has no hashable fields.

The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator):
sorted key+value pairs are concatenated and SHA256 hashed, then base64url encoded.
"""
if not data:
return ""
cache_components = {
k: str(v) for k, v in data.items()
if k not in _EXT_CACHE_KEY_EXCLUDED_FIELDS and v
}
if not cache_components:
return ""
# Sort keys for consistent hashing (matches Go implementation)
key_str = "".join(
k + cache_components[k] for k in sorted(cache_components.keys())
)
hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest()
return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower()


def is_subdict_of(small, big):
return dict(big, **small) == big

Expand Down Expand Up @@ -59,6 +118,7 @@ def __init__(self):
self.CredentialType.ACCESS_TOKEN:
lambda home_account_id=None, environment=None, client_id=None,
realm=None, target=None,
ext_cache_key=None,
# Note: New field(s) can be added here
#key_id=None,
**ignored_payload_from_a_real_token:
Expand All @@ -70,7 +130,8 @@ def __init__(self):
realm or "",
target or "",
#key_id or "", # So ATs of different key_id can coexist
]).lower(),
] + ([ext_cache_key] if ext_cache_key else [])
).lower(),
self.CredentialType.ID_TOKEN:
lambda home_account_id=None, environment=None, client_id=None,
realm=None, **ignored_payload_from_a_real_token:
Expand Down Expand Up @@ -98,6 +159,7 @@ def __init__(self):
def _get_access_token(
self,
home_account_id, environment, client_id, realm, target, # Together they form a compound key
ext_cache_key=None,
default=None,
): # O(1)
return self._get(
Expand All @@ -108,6 +170,7 @@ def _get_access_token(
client_id=client_id,
realm=realm,
target=" ".join(target),
ext_cache_key=ext_cache_key,
),
default=default)

Expand Down Expand Up @@ -153,7 +216,8 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n)
): # Special case for O(1) AT lookup
preferred_result = self._get_access_token(
query["home_account_id"], query["environment"],
query["client_id"], query["realm"], target)
query["client_id"], query["realm"], target,
ext_cache_key=query.get("ext_cache_key"))
if preferred_result and self._is_matching(
preferred_result, query,
# Needs no target_set here because it is satisfied by dict key
Expand All @@ -179,6 +243,13 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n)
if (entry != preferred_result # Avoid yielding the same entry twice
and self._is_matching(entry, query, target_set=target_set)
):
# Cache isolation for extended cache keys (e.g., FMI path).
# Entries with ext_cache_key must not match queries without one.
if (credential_type == self.CredentialType.ACCESS_TOKEN
and "ext_cache_key" in entry
and "ext_cache_key" not in (query or {})
):
continue
yield entry
for at in expired_access_tokens:
self.remove_at(at)
Expand Down Expand Up @@ -278,6 +349,12 @@ def __add(self, event, now=None):
# So that we won't accidentally store a user's password etc.
"key_id", # It happens in SSH-cert or POP scenario
}})
# Compute and store extended cache key for cache isolation
# (e.g., different FMI paths should have separate cache entries)
ext_cache_key = _compute_ext_cache_key(data)

if ext_cache_key:
at["ext_cache_key"] = ext_cache_key
if "refresh_in" in response:
refresh_in = response["refresh_in"] # It is an integer
at["refresh_on"] = str(now + refresh_in) # Schema wants a string
Expand Down
156 changes: 156 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,162 @@ def test_organizations_authority_should_emit_warning(self):
authority="https://login.microsoftonline.com/organizations")


@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK)
class TestAcquireTokenForClientWithFmiPath(unittest.TestCase):
"""Test that acquire_token_for_client_with_fmi_path attaches fmi_path to HTTP body."""

def test_fmi_path_is_included_in_request_body(self):
app = ConfidentialClientApplication(
"client_id", client_credential="secret",
authority="https://login.microsoftonline.com/my_tenant")
fmi_path = "SomeFmiPath/FmiCredentialPath"
captured_data = {}

def mock_post(url, headers=None, data=None, *args, **kwargs):
captured_data.update(data or {})
return MinimalResponse(
status_code=200, text=json.dumps({
"access_token": "an AT",
"expires_in": 3600,
}))

result = app.acquire_token_for_client_with_fmi_path(
["scope"], fmi_path, post=mock_post)
self.assertIn("access_token", result)
self.assertIn("fmi_path", captured_data,
"fmi_path should be present in the HTTP request body")
self.assertEqual(fmi_path, captured_data["fmi_path"],
"fmi_path value should match the input")

def test_fmi_path_coexists_with_other_data(self):
app = ConfidentialClientApplication(
"client_id", client_credential="secret",
authority="https://login.microsoftonline.com/my_tenant")
fmi_path = "another/fmi/path"
captured_data = {}

def mock_post(url, headers=None, data=None, *args, **kwargs):
captured_data.update(data or {})
return MinimalResponse(
status_code=200, text=json.dumps({
"access_token": "an AT",
"expires_in": 3600,
}))

result = app.acquire_token_for_client_with_fmi_path(
["scope"], fmi_path, post=mock_post)
self.assertIn("access_token", result)
self.assertEqual(fmi_path, captured_data["fmi_path"])
self.assertEqual("client_credentials", captured_data.get("grant_type"))

def test_fmi_path_preserves_existing_data_params(self):
app = ConfidentialClientApplication(
"client_id", client_credential="secret",
authority="https://login.microsoftonline.com/my_tenant")
fmi_path = "my/fmi/path"
captured_data = {}

def mock_post(url, headers=None, data=None, *args, **kwargs):
captured_data.update(data or {})
return MinimalResponse(
status_code=200, text=json.dumps({
"access_token": "an AT",
"expires_in": 3600,
}))

result = app.acquire_token_for_client_with_fmi_path(
["scope"], fmi_path,
data={"extra_key": "extra_value"},
post=mock_post)
self.assertIn("access_token", result)
self.assertEqual(fmi_path, captured_data["fmi_path"])
self.assertEqual("extra_value", captured_data.get("extra_key"),
"Pre-existing data params should be preserved")

def test_cached_token_is_returned_on_second_call(self):
app = ConfidentialClientApplication(
"client_id", client_credential="secret",
authority="https://login.microsoftonline.com/my_tenant")
fmi_path = "SomeFmiPath/FmiCredentialPath"
call_count = [0]

def mock_post(url, headers=None, data=None, *args, **kwargs):
call_count[0] += 1
return MinimalResponse(
status_code=200, text=json.dumps({
"access_token": "an AT",
"expires_in": 3600,
}))

result1 = app.acquire_token_for_client_with_fmi_path(
["scope"], fmi_path, post=mock_post)
self.assertIn("access_token", result1)
self.assertEqual(result1[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP)

result2 = app.acquire_token_for_client_with_fmi_path(
["scope"], fmi_path, post=mock_post)
self.assertIn("access_token", result2)
self.assertEqual(result2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE,
"Second call should return token from cache")

def test_different_fmi_paths_are_cached_separately(self):
"""Tokens acquired with different fmi_path values must NOT share cache entries."""
app = ConfidentialClientApplication(
"client_id", client_credential="secret",
authority="https://login.microsoftonline.com/my_tenant")

def mock_post_factory(token_value):
def mock_post(url, headers=None, data=None, *args, **kwargs):
return MinimalResponse(
status_code=200, text=json.dumps({
"access_token": token_value,
"expires_in": 3600,
}))
return mock_post

# Acquire token with path A
result_a = app.acquire_token_for_client_with_fmi_path(
["scope"], "PathA/credential", post=mock_post_factory("AT_for_path_A"))
self.assertEqual("AT_for_path_A", result_a["access_token"])

# Acquire token with path B (should NOT get path A's cached token)
result_b = app.acquire_token_for_client_with_fmi_path(
["scope"], "PathB/credential", post=mock_post_factory("AT_for_path_B"))
self.assertEqual("AT_for_path_B", result_b["access_token"])
self.assertEqual(result_b[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP,
"Different FMI path should NOT return a cached token from another path")

# Verify path A still returns its own cached token
result_a2 = app.acquire_token_for_client_with_fmi_path(
["scope"], "PathA/credential", post=mock_post_factory("should_not_be_used"))
self.assertEqual("AT_for_path_A", result_a2["access_token"])
self.assertEqual(result_a2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE,
"Same FMI path should return cached token")

def test_fmi_token_does_not_interfere_with_non_fmi_token(self):
"""FMI-cached tokens must not be returned for non-FMI acquire_token_for_client."""
app = ConfidentialClientApplication(
"client_id", client_credential="secret",
authority="https://login.microsoftonline.com/my_tenant")

# First, cache a token via FMI path
app.acquire_token_for_client_with_fmi_path(
["scope"], "some/fmi/path",
post=lambda url, **kwargs: MinimalResponse(
status_code=200, text=json.dumps({
"access_token": "FMI_AT", "expires_in": 3600})))

# Now call regular acquire_token_for_client — should NOT get FMI token
result = app.acquire_token_for_client(
["scope"],
post=lambda url, **kwargs: MinimalResponse(
status_code=200, text=json.dumps({
"access_token": "regular_AT", "expires_in": 3600})))
self.assertEqual("regular_AT", result["access_token"])
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP,
"Non-FMI call should not return FMI-cached token")


@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK)
class TestRemoveTokensForClient(unittest.TestCase):
def test_remove_tokens_for_client_should_remove_client_tokens_only(self):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_ccs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,14 @@ def test_acquire_token_silent(self):
"CSS routing info should be derived from home_account_id")

def test_acquire_token_by_username_password(self):
import warnings
app = msal.ClientApplication("client_id")
username = "johndoe@contoso.com"
with patch.object(app.http_client, "post", return_value=MinimalResponse(
status_code=400, text='{"error": "mock"}')) as mocked_method:
app.acquire_token_by_username_password(username, "password", ["scope"])
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
app.acquire_token_by_username_password(username, "password", ["scope"])
self.assertEqual(
"upn:" + username,
mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'),
Expand Down
Loading
Loading