diff --git a/osc_sdk_python/__init__.py b/osc_sdk_python/__init__.py index 6a0700b..98c3ad4 100644 --- a/osc_sdk_python/__init__.py +++ b/osc_sdk_python/__init__.py @@ -5,6 +5,7 @@ from .outscale_gateway import LOG_MEMORY from .version import get_version from .problem import Problem, ProblemDecoder +from .limiter import RateLimiter # what to Log from .outscale_gateway import LOG_ALL @@ -23,5 +24,6 @@ "LOG_ALL", "LOG_KEEP_ONLY_LAST_REQ", "Problem", - "ProblemDecoder" + "ProblemDecoder", + "RateLimiter", ] diff --git a/osc_sdk_python/authentication.py b/osc_sdk_python/authentication.py index b319e86..23fd8ac 100644 --- a/osc_sdk_python/authentication.py +++ b/osc_sdk_python/authentication.py @@ -5,16 +5,23 @@ from .version import get_version from .credentials import Profile + VERSION: str = get_version() DEFAULT_USER_AGENT = "osc-sdk-python/" + VERSION + class Authentication: - def __init__(self, credentials: Profile, host: str, - method='POST', service='api', - content_type='application/json; charset=utf-8', - algorithm='OSC4-HMAC-SHA256', - signed_headers = 'content-type;host;x-osc-date', - user_agent = DEFAULT_USER_AGENT): + def __init__( + self, + credentials: Profile, + host: str, + method="POST", + service="api", + content_type="application/json; charset=utf-8", + algorithm="OSC4-HMAC-SHA256", + signed_headers="content-type;host;x-osc-date", + user_agent=DEFAULT_USER_AGENT, + ): self.access_key = credentials.access_key self.secret_key = credentials.secret_key self.login = credentials.login @@ -31,34 +38,37 @@ def __init__(self, credentials: Profile, host: str, def forge_headers_signed(self, uri, request_data): date_iso, date = self.build_dates() - credential_scope = '{}/{}/{}/osc4_request'.format(date, self.region, self.service) + credential_scope = "{}/{}/{}/osc4_request".format( + date, self.region, self.service + ) canonical_request = self.build_canonical_request(date_iso, uri, request_data) - str_to_sign = self.create_string_to_sign(date_iso, credential_scope, canonical_request) + str_to_sign = self.create_string_to_sign( + date_iso, credential_scope, canonical_request + ) signature = self.compute_signature(date, str_to_sign) authorisation = self.build_authorization_header(credential_scope, signature) return { - 'Content-Type': self.content_type, - 'X-Osc-Date': date_iso, - 'Authorization': authorisation, - 'User-Agent': self.user_agent, + "Content-Type": self.content_type, + "X-Osc-Date": date_iso, + "Authorization": authorisation, + "User-Agent": self.user_agent, } def build_dates(self): - '''Return YYYYMMDDTHHmmssZ, YYYYMMDD - ''' + """Return YYYYMMDDTHHmmssZ, YYYYMMDD""" t = datetime.datetime.now(datetime.timezone.utc) - return t.strftime('%Y%m%dT%H%M%SZ'), t.strftime('%Y%m%d') + return t.strftime("%Y%m%dT%H%M%SZ"), t.strftime("%Y%m%d") def sign(self, key, msg): return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() def get_signature_key(self, key, date_stamp_value): - k_date = self.sign(('OSC4' + key).encode('utf-8'), date_stamp_value) + k_date = self.sign(("OSC4" + key).encode("utf-8"), date_stamp_value) k_region = self.sign(k_date, self.region) k_service = self.sign(k_region, self.service) - k_signing = self.sign(k_service, 'osc4_request') + k_signing = self.sign(k_service, "osc4_request") return k_signing def build_canonical_request(self, date_iso, canonical_uri, request_data): @@ -81,27 +91,46 @@ def build_canonical_request(self, date_iso, canonical_uri, request_data): # Step 6: Create payload hash. In this example, the payload (body of # the request) contains the request parameters. # Step 7: Combine elements to create canonical request - canonical_querystring = '' - canonical_headers = 'content-type:' + self.content_type + '\n' \ - + 'host:' + self.host + '\n' \ - + 'x-osc-date:' + date_iso + '\n' - payload_hash = hashlib.sha256(request_data.encode('utf-8')).hexdigest() - return self.method + '\n' \ - + canonical_uri + '\n' \ - + canonical_querystring + '\n' \ - + canonical_headers + '\n' \ - + self.signed_headers + '\n' \ - + payload_hash + canonical_querystring = "" + canonical_headers = ( + "content-type:" + + self.content_type + + "\n" + + "host:" + + self.host + + "\n" + + "x-osc-date:" + + date_iso + + "\n" + ) + payload_hash = hashlib.sha256(request_data.encode("utf-8")).hexdigest() + return ( + self.method + + "\n" + + canonical_uri + + "\n" + + canonical_querystring + + "\n" + + canonical_headers + + "\n" + + self.signed_headers + + "\n" + + payload_hash + ) def create_string_to_sign(self, date_iso, credential_scope, canonical_request): # ************* TASK 2: CREATE THE STRING TO SIGN************* # Match the algorithm to the hashing algorithm you use, either SHA-1 or # SHA-256 (recommended) - return self.algorithm + '\n' \ - + date_iso + '\n' \ - + credential_scope + '\n' \ - + hashlib.sha256(canonical_request.encode('utf-8')).hexdigest() - + return ( + self.algorithm + + "\n" + + date_iso + + "\n" + + credential_scope + + "\n" + + hashlib.sha256(canonical_request.encode("utf-8")).hexdigest() + ) def compute_signature(self, date, string_to_sign): # ************* TASK 3: CALCULATE THE SIGNATURE ************* @@ -109,16 +138,27 @@ def compute_signature(self, date, string_to_sign): signing_key = self.get_signature_key(self.secret_key, date) # Sign the string_to_sign using the signing_key - return hmac.new(signing_key, string_to_sign.encode('utf-8'), - hashlib.sha256).hexdigest() - + return hmac.new( + signing_key, string_to_sign.encode("utf-8"), hashlib.sha256 + ).hexdigest() def build_authorization_header(self, credential_scope, signature): # ************* TASK 4: ADD SIGNING INFORMATION TO THE REQUEST ************* # Put the signature information in a header named Authorization. - return self.algorithm + ' ' + 'Credential=' + self.access_key + '/' + credential_scope + ', ' \ - + 'SignedHeaders=' + self.signed_headers + ', ' \ - + 'Signature=' + signature + return ( + self.algorithm + + " " + + "Credential=" + + self.access_key + + "/" + + credential_scope + + ", " + + "SignedHeaders=" + + self.signed_headers + + ", " + + "Signature=" + + signature + ) def is_basic_auth_configured(self): return self.login is not None and self.password is not None @@ -130,7 +170,7 @@ def get_basic_auth_header(self): b64_creds = str(base64.b64encode(creds.encode("utf-8")), "utf-8") date_iso, _ = self.build_dates() return { - 'Content-Type': self.content_type, - 'X-Osc-Date': date_iso, - 'Authorization': "Basic " + b64_creds + "Content-Type": self.content_type, + "X-Osc-Date": date_iso, + "Authorization": "Basic " + b64_creds, } diff --git a/osc_sdk_python/call.py b/osc_sdk_python/call.py index b67424b..617d309 100644 --- a/osc_sdk_python/call.py +++ b/osc_sdk_python/call.py @@ -6,15 +6,18 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry from urllib3.util import parse_url +from datetime import timedelta +from .limiter import RateLimiter import json import warnings -MAX_RETRIES = "3" +MAX_RETRIES = 3 RETRY_BACKOFF_FACTOR = "1" RETRY_BACKOFF_JITTER = "3" RETRY_BACKOFF_MAX = "30" + class Call(object): def __init__(self, logger=None, limiter=None, **kwargs): self.version = kwargs.pop("version", "latest") @@ -22,13 +25,16 @@ def __init__(self, logger=None, limiter=None, **kwargs): self.ssl = kwargs.pop("_ssl", True) self.user_agent = kwargs.pop("user_agent", DEFAULT_USER_AGENT) self.logger = logger - self.limiter = limiter + self.limiter: RateLimiter | None = limiter + self.adapter = None + self.session = Session() + kwargs = self.update_limiter(**kwargs) kwargs = self.update_adapter(**kwargs) self.update_profile(**kwargs) - self.session = Session() - self.session.mount("https://", self.adapter) - self.session.mount("http://", self.adapter) + if self.adapter: + self.session.mount("https://", self.adapter) + self.session.mount("http://", self.adapter) def update_credentials(self, **kwargs): warnings.warn( @@ -39,16 +45,29 @@ def update_credentials(self, **kwargs): return self.update_profile(**kwargs) def update_adapter(self, **kwargs): - self.adapter = HTTPAdapter( - max_retries=Retry( - total=int(kwargs.pop("max_retries", MAX_RETRIES)), - backoff_factor=float(kwargs.pop("retry_backoff_factor", RETRY_BACKOFF_FACTOR)), - backoff_jitter=float(kwargs.pop("retry_backoff_jitter", RETRY_BACKOFF_JITTER)), - backoff_max=float(kwargs.pop("retry_backoff_max", RETRY_BACKOFF_MAX)), - status_forcelist=(400, 429, 500, 503), - allowed_methods=("POST", "GET"), + max_retries: int | str | None = kwargs.pop("max_retries", None) + if max_retries is not None: + max_retries = int(max_retries) + else: + max_retries = MAX_RETRIES + + if max_retries > 0: + self.adapter = HTTPAdapter( + max_retries=Retry( + total=max_retries, + backoff_factor=float( + kwargs.pop("retry_backoff_factor", RETRY_BACKOFF_FACTOR) + ), + backoff_jitter=float( + kwargs.pop("retry_backoff_jitter", RETRY_BACKOFF_JITTER) + ), + backoff_max=float( + kwargs.pop("retry_backoff_max", RETRY_BACKOFF_MAX) + ), + status_forcelist=(400, 429, 500, 503), + allowed_methods=("POST", "GET"), + ) ) - ) return kwargs def update_profile(self, **kwargs): @@ -58,16 +77,13 @@ def update_profile(self, **kwargs): self.profile.merge(Profile(**kwargs)) return kwargs - def update_limiter( - self, - **kwargs - ): + def update_limiter(self, **kwargs): limiter_window = kwargs.pop("limiter_window", None) - if limiter_window is not None: - self.limiter.window = limiter_window + if limiter_window is not None and self.limiter is not None: + self.limiter.window = timedelta(seconds=int(limiter_window)) limiter_max_requests = kwargs.pop("limiter_max_requests", None) - if limiter_max_requests is not None: + if limiter_max_requests is not None and self.limiter is not None: self.limiter.max_requests = limiter_max_requests return kwargs diff --git a/osc_sdk_python/credentials.py b/osc_sdk_python/credentials.py index 866a44f..7318e7e 100644 --- a/osc_sdk_python/credentials.py +++ b/osc_sdk_python/credentials.py @@ -8,8 +8,6 @@ DEFAULT_PROFILE = "default" - - class Endpoint: def __init__(self, **kwargs): self.api: str = kwargs.pop("api", None) @@ -102,7 +100,8 @@ def from_env() -> "Profile": "x509_client_cert_b64": os.environ.get("OSC_X509_CLIENT_CERT_B64"), "x509_client_key": os.environ.get("OSC_X509_CLIENT_KEY"), "x509_client_key_b64": os.environ.get("OSC_X509_CLIENT_KEY_B64"), - "tls_skip_verify": os.environ.get("OSC_TLS_SKIP_VERIFY", "False").lower() in ("true"), + "tls_skip_verify": os.environ.get("OSC_TLS_SKIP_VERIFY", "False").lower() + in ("true"), "login": os.environ.get("OSC_LOGIN"), "password": os.environ.get("OSC_PASSWORD"), "protocol": os.environ.get("OSC_PROTOCOL"), @@ -176,10 +175,12 @@ def from_standard_configuration(path: str, profile: str) -> "Profile": return merged_profile + class Credentials(Profile): def __init__(self, **kwargs): - warnings.warn("Credentials class is deprecated. Use Profile class instead.", + warnings.warn( + "Credentials class is deprecated. Use Profile class instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) - super().__init__(**kwargs) + super().__init__(**kwargs) diff --git a/osc_sdk_python/limiter.py b/osc_sdk_python/limiter.py index cd90bf2..91529f4 100644 --- a/osc_sdk_python/limiter.py +++ b/osc_sdk_python/limiter.py @@ -3,13 +3,14 @@ class RateLimiter: - def __init__(self, window: int, max_requests: int): - self.window = window - self.max_requests = max_requests + def __init__(self, window: timedelta, max_requests: int, datetime_cls=datetime): + self.datetime_cls = datetime_cls + self.window: timedelta = window + self.max_requests: int = max_requests self.requests = [] def acquire(self): - now = datetime.now(timezone.utc) + now = self.datetime_cls.now(timezone.utc) self.clean_old_requests(now) @@ -18,13 +19,11 @@ def acquire(self): wait_time = self.window - (now - oldest) time.sleep(wait_time.total_seconds()) - now = datetime.now(timezone.utc) + now = self.datetime_cls.now(timezone.utc) self.clean_old_requests(now) self.requests.append(now) def clean_old_requests(self, now): - while len(self.requests) > 0 and self.requests[0] <= now - timedelta( - seconds=self.window - ): + while len(self.requests) > 0 and self.requests[0] <= now - self.window: self.requests.pop(0) diff --git a/osc_sdk_python/outscale_gateway.py b/osc_sdk_python/outscale_gateway.py index f546fca..bed36f3 100644 --- a/osc_sdk_python/outscale_gateway.py +++ b/osc_sdk_python/outscale_gateway.py @@ -5,6 +5,7 @@ import ruamel.yaml from .version import get_version import warnings +from datetime import timedelta type_mapping = {"boolean": "bool", "string": "str", "integer": "int", "array": "list"} @@ -19,7 +20,7 @@ LOG_KEEP_ONLY_LAST_REQ = 1 # Default -DEFAULT_LIMITER_WINDOW = 1 # 1 second +DEFAULT_LIMITER_WINDOW = timedelta(seconds=1) # 1 second DEFAULT_LIMITER_MAX_REQUESTS = 5 # 5 requests / sec @@ -114,7 +115,7 @@ def email(self): stacklevel=2, ) return self.login() - + def login(self): return self.call.profile.login diff --git a/osc_sdk_python/problem.py b/osc_sdk_python/problem.py index 3f5f652..50ba89e 100644 --- a/osc_sdk_python/problem.py +++ b/osc_sdk_python/problem.py @@ -1,5 +1,6 @@ import json + class ProblemDecoder(json.JSONDecoder): def decode(self, s): data = super().decode(s) diff --git a/osc_sdk_python/version.py b/osc_sdk_python/version.py index 2bb9905..4f6025f 100644 --- a/osc_sdk_python/version.py +++ b/osc_sdk_python/version.py @@ -1,7 +1,7 @@ import os + def get_version() -> str: osc_sdk_python_path = os.path.dirname(os.path.abspath(__file__)) with open(os.path.join(osc_sdk_python_path, "VERSION"), "r") as fd: return fd.read().strip() - diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index c726c30..ba50fe3 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,9 +1,11 @@ import unittest import sys + sys.path.append("..") from osc_sdk_python import Gateway from requests.exceptions import RetryError + class TestExcept(unittest.TestCase): def test_listing(self): @@ -12,5 +14,6 @@ def test_listing(self): with self.assertRaises(RetryError): gw.ReadVms(Filters="a") -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_limiter.py b/tests/test_limiter.py new file mode 100644 index 0000000..7eba1b0 --- /dev/null +++ b/tests/test_limiter.py @@ -0,0 +1,65 @@ +import datetime + +from osc_sdk_python import RateLimiter + +i = 0 + + +def test_fast(monkeypatch): + with monkeypatch.context() as m: + was_called = [] + + def mock_sleep(t): + was_called.append(t) + assert t > 0 + return + + i = 0 + + class MockDateTimeFast(datetime.datetime): + @classmethod + def now(cls, tz=None): + global i + i += 1 + return cls(2022, 1, 1, microsecond=i, tzinfo=tz) + + m.setattr("time.sleep", mock_sleep) + + rl = RateLimiter( + datetime.timedelta(seconds=1), 5, datetime_cls=MockDateTimeFast + ) + for i in range(10): + rl.acquire() + + assert len(rl.requests) > 5 + assert len(was_called) > 0 + + +def test_slow(monkeypatch): + with monkeypatch.context() as m: + was_called = [] + + def mock_sleep(t): + was_called.append(t) + assert t > 0 + return + + i = 0 + + class MockDateTimeSlow(datetime.datetime): + @classmethod + def now(cls, tz=None): + global i + i += 1 + return cls(2022 + i, 1, 1, tzinfo=tz) + + m.setattr("time.sleep", mock_sleep) + + rl = RateLimiter( + datetime.timedelta(seconds=1), 5, datetime_cls=MockDateTimeSlow + ) + for i in range(10): + rl.acquire() + + assert len(rl.requests) <= 1 + assert len(was_called) == 0 diff --git a/tests/test_log.py b/tests/test_log.py index 1ca8bff..016a365 100644 --- a/tests/test_log.py +++ b/tests/test_log.py @@ -1,23 +1,27 @@ import unittest import sys + sys.path.append("..") from osc_sdk_python import Gateway, LOG_MEMORY, LOG_KEEP_ONLY_LAST_REQ + class TestLog(unittest.TestCase): def test_listing(self): gw = Gateway() gw.log.config(type=LOG_MEMORY, what=LOG_KEEP_ONLY_LAST_REQ) gw.ReadVms() - self.assertEqual(gw.log.str(), - """uri: /api/v1/ReadVms + self.assertEqual( + gw.log.str(), + """uri: /api/v1/ReadVms payload: -{}""" - ) +{}""", + ) - gw.ReadVms(Filters={'TagKeys': ['test']}) - self.assertEqual(gw.log.str(), - """uri: /api/v1/ReadVms + gw.ReadVms(Filters={"TagKeys": ["test"]}) + self.assertEqual( + gw.log.str(), + """uri: /api/v1/ReadVms payload: { "Filters": { @@ -25,8 +29,9 @@ def test_listing(self): "test" ] } -}""" - ) +}""", + ) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_manual_aksk.py b/tests/test_manual_aksk.py index 3eee38f..d148a14 100644 --- a/tests/test_manual_aksk.py +++ b/tests/test_manual_aksk.py @@ -1,10 +1,12 @@ import unittest import sys import os + sys.path.append("..") from osc_sdk_python import Gateway import copy + class EnvironManager: def __enter__(self): self.env = copy.deepcopy(os.environ) @@ -12,6 +14,7 @@ def __enter__(self): def __exit__(self, *args): os.environ = self.env + class TestLoginManualAkSk(unittest.TestCase): def test_manual_ak_sk(self): with EnvironManager(): @@ -24,5 +27,6 @@ def test_manual_ak_sk(self): self.assertIsInstance(volumes, dict) self.assertIsInstance(volumes.get("Volumes"), list) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_net.py b/tests/test_net.py index 2458de9..8055c6f 100644 --- a/tests/test_net.py +++ b/tests/test_net.py @@ -1,19 +1,21 @@ import unittest import sys import requests + sys.path.append("..") from osc_sdk_python import Gateway + class TestNet(unittest.TestCase): def test_creation_error(self): gw = Gateway() with self.assertRaises(requests.exceptions.HTTPError) as cm: - gw.CreateNet(IpRange='142.42.42.42/32') + gw.CreateNet(IpRange="142.42.42.42/32") e = cm.exception - errors = e.response.json().get('Errors') + errors = e.response.json().get("Errors") self.assertIsNotNone(errors) self.assertIsInstance(errors, list) for error in errors: - code = error.get('Code') - self.assertEqual(code, '9050') + code = error.get("Code") + self.assertEqual(code, "9050") diff --git a/tests/test_password.py b/tests/test_password.py index 71e1266..e67457b 100644 --- a/tests/test_password.py +++ b/tests/test_password.py @@ -1,10 +1,12 @@ import unittest import sys import os + sys.path.append("..") from osc_sdk_python import Gateway import copy + class EnvironManager: def __enter__(self): self.env = copy.deepcopy(os.environ) @@ -12,14 +14,18 @@ def __enter__(self): def __exit__(self, *args): os.environ = self.env + class TestLoginPassword(unittest.TestCase): - @unittest.skipIf(not (os.environ.get('OSC_TEST_LOGIN') and os.environ.get('OSC_TEST_PASSWORD')), "login/password credentials are not available") + @unittest.skipIf( + not (os.environ.get("OSC_TEST_LOGIN") and os.environ.get("OSC_TEST_PASSWORD")), + "login/password credentials are not available", + ) def test_login(self): with EnvironManager(): os.environ.pop("OSC_ACCESS_KEY", None) os.environ.pop("OSC_SECRET_KEY", None) - email = os.getenv('OSC_TEST_LOGIN') - password = os.getenv('OSC_TEST_PASSWORD') + email = os.getenv("OSC_TEST_LOGIN") + password = os.getenv("OSC_TEST_PASSWORD") self.assertIsNotNone(email, None) self.assertIsNotNone(password, None) gw = Gateway(email=email, password=password) @@ -27,5 +33,6 @@ def test_login(self): self.assertIsInstance(keys, dict) self.assertIsInstance(keys.get("AccessKeys"), list) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_problems.py b/tests/test_problems.py index fc9d468..adcf094 100644 --- a/tests/test_problems.py +++ b/tests/test_problems.py @@ -44,5 +44,5 @@ def test_deserialize_problem2(): """, cls=ProblemDecoder, ) - assert isinstance(obj,Problem) + assert isinstance(obj, Problem) assert obj.type == "https://example.net/validation-error" diff --git a/tests/test_vm.py b/tests/test_vm.py index a9e92b8..bbfc0af 100644 --- a/tests/test_vm.py +++ b/tests/test_vm.py @@ -1,8 +1,10 @@ import unittest import sys + sys.path.append("..") from osc_sdk_python import Gateway + class TestVm(unittest.TestCase): def test_listing(self): @@ -10,12 +12,13 @@ def test_listing(self): vms = gw.ReadVms() self.assertEqual(type(vms), dict) self.assertEqual(type(vms.get("Vms")), list) - + def test_listing_with_context_manager(self): with Gateway() as gw: vms = gw.ReadVms() self.assertEqual(type(vms), dict) self.assertEqual(type(vms.get("Vms")), list) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_volume.py b/tests/test_volume.py index fb599f9..cc60c62 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -1,8 +1,10 @@ import unittest import sys + sys.path.append("..") from osc_sdk_python import Gateway + class TestVolume(unittest.TestCase): def test_listing(self): @@ -11,5 +13,6 @@ def test_listing(self): self.assertEqual(type(volumes), dict) self.assertEqual(type(volumes.get("Volumes")), list) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main()