diff --git a/scalekit/core.py b/scalekit/core.py index e69e60b..a71ee1f 100644 --- a/scalekit/core.py +++ b/scalekit/core.py @@ -1,3 +1,4 @@ +import time from typing import TypeVar, Optional, Protocol import grpc @@ -8,9 +9,11 @@ from urllib.parse import urlparse from cryptography.hazmat.primitives import serialization +from grpc_status import rpc_status from scalekit._version import __version__ as _sdk_version from scalekit.common.scalekit import GrantType from scalekit.common.exceptions import ScalekitServerException, ScalekitException +from scalekit.v1.errdetails.errdetails_pb2 import ErrorInfo TRequest = TypeVar("TRequest") TResponse = TypeVar("TResponse") @@ -150,28 +153,64 @@ def get_headers(self, headers: Optional[dict] = None) -> dict: return {**default_headers, **headers} return default_headers + def _token_needs_refresh(self) -> bool: + if not self.access_token or not self.access_token.strip(): + return True + try: + claims = jwt.decode(self.access_token, options={"verify_signature": False}) + return time.time() >= (float(claims["exp"]) - 300) + except Exception: + return True + + @staticmethod + def _is_tool_error(exp: grpc.RpcError) -> bool: + """Return True if the error originated from tool execution. + + Tool errors always carry a populated tool_error_info in ErrorInfo. + This structural check is resilient to server-side error code renames. + """ + try: + status = rpc_status.from_call(exp) + if status is None: + return False + for detail in status.details: + info = ErrorInfo() + if not detail.Unpack(info): + continue + if info.HasField("tool_error_info"): + return True + except Exception: + pass + return False + def grpc_exec( self, func: WithCall, data: TRequest, - retry=2, + _retry: int = 2, ) -> TResponse: + if self._token_needs_refresh(): + try: + self.__authenticate_client() + except Exception: + pass # token still valid within buffer; let call proceed and handle reactively try: - resp = func( + return func( data, metadata=tuple(self.get_headers().items()), ) - return resp except grpc.RpcError as exp: if exp.code() == grpc.StatusCode.UNAUTHENTICATED: - try: - self.__authenticate_client() - return self.grpc_exec(func, data, retry=retry-1) - except Exception as refresh_exp: + if self._is_tool_error(exp): raise ScalekitServerException.promote(exp) - elif retry > 0: - return self.grpc_exec(func, data, retry=retry - 1) - else: - raise ScalekitServerException.promote(exp) + if _retry > 0: + try: + self.__authenticate_client() + except Exception: + raise ScalekitServerException.promote(exp) + return self.grpc_exec(func, data, _retry=_retry - 1) + elif _retry > 0: + return self.grpc_exec(func, data, _retry=_retry - 1) + raise ScalekitServerException.promote(exp) except Exception as exp: raise ScalekitException(exp)