Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 50 additions & 11 deletions scalekit/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from typing import TypeVar, Optional, Protocol

import grpc
Expand All @@ -8,9 +9,11 @@
from urllib.parse import urlparse

from cryptography.hazmat.primitives import serialization
from grpc_status import rpc_status
from scalekit._version import __version__ as _sdk_version
from scalekit.common.scalekit import GrantType
from scalekit.common.exceptions import ScalekitServerException, ScalekitException
from scalekit.v1.errdetails.errdetails_pb2 import ErrorInfo

TRequest = TypeVar("TRequest")
TResponse = TypeVar("TResponse")
Expand Down Expand Up @@ -150,28 +153,64 @@ def get_headers(self, headers: Optional[dict] = None) -> dict:
return {**default_headers, **headers}
return default_headers

def _token_needs_refresh(self) -> bool:
if not self.access_token or not self.access_token.strip():
return True
try:
claims = jwt.decode(self.access_token, options={"verify_signature": False})
return time.time() >= (float(claims["exp"]) - 300)
except Exception:
return True

@staticmethod
def _is_tool_error(exp: grpc.RpcError) -> bool:
"""Return True if the error originated from tool execution.

Tool errors always carry a populated tool_error_info in ErrorInfo.
This structural check is resilient to server-side error code renames.
"""
try:
status = rpc_status.from_call(exp)
if status is None:
return False
for detail in status.details:
info = ErrorInfo()
if not detail.Unpack(info):
continue
if info.HasField("tool_error_info"):
return True
except Exception:
pass
return False

def grpc_exec(
self,
func: WithCall,
data: TRequest,
retry=2,
_retry: int = 2,
) -> TResponse:
if self._token_needs_refresh():
try:
self.__authenticate_client()
except Exception:
pass # token still valid within buffer; let call proceed and handle reactively
try:
resp = func(
return func(
data,
metadata=tuple(self.get_headers().items()),
)
return resp
except grpc.RpcError as exp:
if exp.code() == grpc.StatusCode.UNAUTHENTICATED:
try:
self.__authenticate_client()
return self.grpc_exec(func, data, retry=retry-1)
except Exception as refresh_exp:
if self._is_tool_error(exp):
raise ScalekitServerException.promote(exp)
elif retry > 0:
return self.grpc_exec(func, data, retry=retry - 1)
else:
raise ScalekitServerException.promote(exp)
if _retry > 0:
try:
self.__authenticate_client()
except Exception:
raise ScalekitServerException.promote(exp)
return self.grpc_exec(func, data, _retry=_retry - 1)
elif _retry > 0:
return self.grpc_exec(func, data, _retry=_retry - 1)
raise ScalekitServerException.promote(exp)
except Exception as exp:
raise ScalekitException(exp)