diff --git a/src/amplitude_experiment/remote/client.py b/src/amplitude_experiment/remote/client.py index 52baece..8db2653 100644 --- a/src/amplitude_experiment/remote/client.py +++ b/src/amplitude_experiment/remote/client.py @@ -1,3 +1,4 @@ +import base64 import json import threading import time @@ -143,6 +144,10 @@ def __do_fetch(self, user, fetch_options: FetchOptions = None): headers['X-Amp-Exp-Track'] = "track" if fetch_options.tracksAssignment else "no-track" if fetch_options and fetch_options.tracksExposure is not None: headers['X-Amp-Exp-Exposure-Track'] = "track" if fetch_options.tracksExposure else "no-track" + if fetch_options and fetch_options.flagKeys: + headers['X-Amp-Exp-Flag-Keys'] = base64.urlsafe_b64encode( + json.dumps(fetch_options.flagKeys, separators=(",", ":")).encode("utf-8") + ).rstrip(b"=").decode("utf-8") conn = self._connection_pool.acquire() body = user_context.to_json().encode('utf8') diff --git a/src/amplitude_experiment/remote/fetch_options.py b/src/amplitude_experiment/remote/fetch_options.py index 8538888..99a3402 100644 --- a/src/amplitude_experiment/remote/fetch_options.py +++ b/src/amplitude_experiment/remote/fetch_options.py @@ -1,14 +1,24 @@ -from typing import Optional +from typing import List, Optional + + class FetchOptions: - def __init__(self, tracksAssignment: Optional[bool] = None, tracksExposure: Optional[bool] = None): + def __init__( + self, + tracksAssignment: Optional[bool] = None, + tracksExposure: Optional[bool] = None, + flagKeys: Optional[List[str]] = None, + ): """ Fetch Options Parameters: tracksAssignment (Optional[bool]): Whether to track the assignment. The default None uses the server's default behavior (track the assignment event). tracksExposure (Optional[bool]): Whether to track the exposure. The default None uses the server's default behavior (don't track the exposure event). + flagKeys (Optional[List[str]]): Specific flag keys to evaluate and set variants for. """ self.tracksAssignment = tracksAssignment self.tracksExposure = tracksExposure + self.flagKeys = flagKeys def __str__(self): - return f"FetchOptions(tracksAssignment={self.tracksAssignment}, tracksExposure={self.tracksExposure})" + return (f"FetchOptions(tracksAssignment={self.tracksAssignment}, " + f"tracksExposure={self.tracksExposure}, flagKeys={self.flagKeys})") diff --git a/tests/remote/client_test.py b/tests/remote/client_test.py index 43fcf24..45c45b8 100644 --- a/tests/remote/client_test.py +++ b/tests/remote/client_test.py @@ -81,6 +81,25 @@ def test_fetch_with_fetch_options(self): 'X-Amp-Exp-Exposure-Track': 'no-track' }) + mock_conn.request.reset_mock() + + variants = client.fetch_v2(user, FetchOptions(flagKeys=['flag-a', 'flag-b'])) + self.assertIn('sdk-ci-test', variants) + mock_conn.request.assert_called_once_with('POST', '/sdk/v2/vardata?v=0', mock.ANY, { + 'Authorization': f"Api-Key {API_KEY}", + 'Content-Type': 'application/json;charset=utf-8', + 'X-Amp-Exp-Flag-Keys': 'WyJmbGFnLWEiLCJmbGFnLWIiXQ' + }) + + mock_conn.request.reset_mock() + + variants = client.fetch_v2(user, FetchOptions(flagKeys=[])) + self.assertIn('sdk-ci-test', variants) + mock_conn.request.assert_called_once_with('POST', '/sdk/v2/vardata?v=0', mock.ANY, { + 'Authorization': f"Api-Key {API_KEY}", + 'Content-Type': 'application/json;charset=utf-8' + }) + @parameterized.expand([ (300, "Fetch Exception 300", True), (400, "Fetch Exception 400", False),