From 79bb0d24ee3ce0393c3f461a1f14295c134264c1 Mon Sep 17 00:00:00 2001 From: Cesare Naldi <3353250+cesarenaldi@users.noreply.github.com> Date: Sat, 20 Jun 2026 14:42:54 +0200 Subject: [PATCH] fix(client): split rejected rpc batches --- .../_internal/actions/relayer/approvals.py | 68 +++++--- src/polymarket/_internal/eoa/rpc.py | 165 +++++++++++++++--- tests/unit/_relayer_helpers.py | 126 +++++++++---- tests/unit/test_eoa_broadcast.py | 49 ++++-- tests/unit/test_eoa_rpc.py | 80 +++++++++ 5 files changed, 384 insertions(+), 104 deletions(-) create mode 100644 tests/unit/test_eoa_rpc.py diff --git a/src/polymarket/_internal/actions/relayer/approvals.py b/src/polymarket/_internal/actions/relayer/approvals.py index 34c6709..a3b2795 100644 --- a/src/polymarket/_internal/actions/relayer/approvals.py +++ b/src/polymarket/_internal/actions/relayer/approvals.py @@ -35,16 +35,29 @@ async def resolve_missing_trading_approval_calls( rpc: JsonRpcClient, *, wallet: EvmAddress, environment: Environment ) -> list[TransactionCall]: erc20, erc1155 = _required_trading_approvals(environment) - erc20_missing: list[TransactionCall] = [] - for approval in erc20: - check = erc20_allowance_call( + erc20_checks = [ + erc20_allowance_call( token_address=approval.token_address, owner=wallet, spender=approval.spender, ) - allowance = decode_erc20_allowance_result( - await rpc.eth_call(to=str(check.to), data=check.data) + for approval in erc20 + ] + erc1155_checks = [ + erc1155_is_approved_for_all_call( + token_address=approval.token_address, + owner=wallet, + operator=approval.operator, ) + for approval in erc1155 + ] + results = await rpc.eth_call_batch( + [(str(check.to), check.data) for check in [*erc20_checks, *erc1155_checks]] + ) + + erc20_missing: list[TransactionCall] = [] + for approval, result in zip(erc20, results[: len(erc20)], strict=True): + allowance = decode_erc20_allowance_result(result) if allowance < approval.amount: erc20_missing.append( erc20_approval_call( @@ -55,15 +68,8 @@ async def resolve_missing_trading_approval_calls( ) erc1155_missing: list[TransactionCall] = [] - for approval in erc1155: - check = erc1155_is_approved_for_all_call( - token_address=approval.token_address, - owner=wallet, - operator=approval.operator, - ) - approved = decode_erc1155_is_approved_for_all_result( - await rpc.eth_call(to=str(check.to), data=check.data) - ) + for approval, result in zip(erc1155, results[len(erc20) :], strict=True): + approved = decode_erc1155_is_approved_for_all_result(result) if not approved: erc1155_missing.append( erc1155_set_approval_for_all_call( @@ -80,14 +86,29 @@ def resolve_missing_trading_approval_calls_sync( rpc: SyncJsonRpcClient, *, wallet: EvmAddress, environment: Environment ) -> list[TransactionCall]: erc20, erc1155 = _required_trading_approvals(environment) - erc20_missing: list[TransactionCall] = [] - for approval in erc20: - check = erc20_allowance_call( + erc20_checks = [ + erc20_allowance_call( token_address=approval.token_address, owner=wallet, spender=approval.spender, ) - allowance = decode_erc20_allowance_result(rpc.eth_call(to=str(check.to), data=check.data)) + for approval in erc20 + ] + erc1155_checks = [ + erc1155_is_approved_for_all_call( + token_address=approval.token_address, + owner=wallet, + operator=approval.operator, + ) + for approval in erc1155 + ] + results = rpc.eth_call_batch( + [(str(check.to), check.data) for check in [*erc20_checks, *erc1155_checks]] + ) + + erc20_missing: list[TransactionCall] = [] + for approval, result in zip(erc20, results[: len(erc20)], strict=True): + allowance = decode_erc20_allowance_result(result) if allowance < approval.amount: erc20_missing.append( erc20_approval_call( @@ -98,15 +119,8 @@ def resolve_missing_trading_approval_calls_sync( ) erc1155_missing: list[TransactionCall] = [] - for approval in erc1155: - check = erc1155_is_approved_for_all_call( - token_address=approval.token_address, - owner=wallet, - operator=approval.operator, - ) - approved = decode_erc1155_is_approved_for_all_result( - rpc.eth_call(to=str(check.to), data=check.data) - ) + for approval, result in zip(erc1155, results[len(erc20) :], strict=True): + approved = decode_erc1155_is_approved_for_all_result(result) if not approved: erc1155_missing.append( erc1155_set_approval_for_all_call( diff --git a/src/polymarket/_internal/eoa/rpc.py b/src/polymarket/_internal/eoa/rpc.py index dbccaf3..e5e854e 100644 --- a/src/polymarket/_internal/eoa/rpc.py +++ b/src/polymarket/_internal/eoa/rpc.py @@ -1,13 +1,16 @@ from __future__ import annotations +import asyncio import json as _json -from typing import Any, cast +from collections.abc import Sequence +from typing import Any, TypeAlias, cast from polymarket.clients._transport import AsyncTransport, SyncTransport from polymarket.errors import RequestRejectedError, UnexpectedResponseError, UserInputError _JSON_RPC_REVERT_CODES = frozenset({3, -32_000, -32_003, -32_015, -32_603}) _JSON_RPC_REVERT_TOKENS = ("execution reverted", "revert", "invalid opcode") +EthCallBatchRequest: TypeAlias = tuple[str, str] class JsonRpcCallError(RequestRejectedError): @@ -65,22 +68,18 @@ async def verify_chain_id(self, expected: int) -> None: self._verified_chain_id = actual async def _call(self, method: str, params: list[Any]) -> Any: + envelope = self._build_envelope(method, params) + raw = await self._transport.post_json("", json=envelope) + return _parse_rpc_response(method, raw) + + def _build_envelope(self, method: str, params: list[Any]) -> dict[str, Any]: self._id += 1 - envelope = { + return { "jsonrpc": "2.0", "id": self._id, "method": method, "params": params, } - raw = await self._transport.post_json("", json=envelope) - if not isinstance(raw, dict): - raise UnexpectedResponseError(f"JSON-RPC {method} returned a non-object response") - response = cast(dict[str, Any], raw) - if "error" in response: - err: Any = response["error"] - code, message, data = _extract_error_fields(err) - raise JsonRpcCallError(method=method, code=code, message=message, data=data) - return response.get("result") async def eth_chain_id(self) -> int: result = await self._call("eth_chainId", []) @@ -88,9 +87,44 @@ async def eth_chain_id(self) -> int: async def eth_call(self, *, to: str, data: str, block: str = "latest") -> str: result = await self._call("eth_call", [{"to": to, "data": data}, block]) - if not isinstance(result, str) or not _is_rpc_hex_string(result): - raise UnexpectedResponseError("eth_call did not return a hex string") - return result + return _parse_eth_call_result(result) + + async def eth_call_batch( + self, requests: Sequence[EthCallBatchRequest], *, block: str = "latest" + ) -> list[str]: + if not requests: + return [] + return await self._eth_call_batch_with_split(requests, block=block) + + async def _eth_call_batch_with_split( + self, requests: Sequence[EthCallBatchRequest], *, block: str + ) -> list[str]: + if len(requests) == 1: + to, data = requests[0] + return [await self.eth_call(to=to, data=data, block=block)] + + try: + return await self._post_eth_call_batch(requests, block=block) + except RequestRejectedError as error: + if error.status < 500: + raise + + midpoint = (len(requests) + 1) // 2 + left, right = await asyncio.gather( + self._eth_call_batch_with_split(requests[:midpoint], block=block), + self._eth_call_batch_with_split(requests[midpoint:], block=block), + ) + return [*left, *right] + + async def _post_eth_call_batch( + self, requests: Sequence[EthCallBatchRequest], *, block: str + ) -> list[str]: + envelopes = [ + self._build_envelope("eth_call", [{"to": to, "data": data}, block]) + for to, data in requests + ] + raw = await self._transport.post_json("", json=envelopes) + return _parse_eth_call_batch_response(raw, envelopes) async def eth_get_transaction_count(self, address: str, block: str = "pending") -> int: result = await self._call("eth_getTransactionCount", [address, block]) @@ -145,22 +179,18 @@ def verify_chain_id(self, expected: int) -> None: self._verified_chain_id = actual def _call(self, method: str, params: list[Any]) -> Any: + envelope = self._build_envelope(method, params) + raw = self._transport.post_json("", json=envelope) + return _parse_rpc_response(method, raw) + + def _build_envelope(self, method: str, params: list[Any]) -> dict[str, Any]: self._id += 1 - envelope = { + return { "jsonrpc": "2.0", "id": self._id, "method": method, "params": params, } - raw = self._transport.post_json("", json=envelope) - if not isinstance(raw, dict): - raise UnexpectedResponseError(f"JSON-RPC {method} returned a non-object response") - response = cast(dict[str, Any], raw) - if "error" in response: - err: Any = response["error"] - code, message, data = _extract_error_fields(err) - raise JsonRpcCallError(method=method, code=code, message=message, data=data) - return response.get("result") def eth_chain_id(self) -> int: result = self._call("eth_chainId", []) @@ -168,9 +198,42 @@ def eth_chain_id(self) -> int: def eth_call(self, *, to: str, data: str, block: str = "latest") -> str: result = self._call("eth_call", [{"to": to, "data": data}, block]) - if not isinstance(result, str) or not _is_rpc_hex_string(result): - raise UnexpectedResponseError("eth_call did not return a hex string") - return result + return _parse_eth_call_result(result) + + def eth_call_batch( + self, requests: Sequence[EthCallBatchRequest], *, block: str = "latest" + ) -> list[str]: + if not requests: + return [] + return self._eth_call_batch_with_split(requests, block=block) + + def _eth_call_batch_with_split( + self, requests: Sequence[EthCallBatchRequest], *, block: str + ) -> list[str]: + if len(requests) == 1: + to, data = requests[0] + return [self.eth_call(to=to, data=data, block=block)] + + try: + return self._post_eth_call_batch(requests, block=block) + except RequestRejectedError as error: + if error.status < 500: + raise + + midpoint = (len(requests) + 1) // 2 + left = self._eth_call_batch_with_split(requests[:midpoint], block=block) + right = self._eth_call_batch_with_split(requests[midpoint:], block=block) + return [*left, *right] + + def _post_eth_call_batch( + self, requests: Sequence[EthCallBatchRequest], *, block: str + ) -> list[str]: + envelopes = [ + self._build_envelope("eth_call", [{"to": to, "data": data}, block]) + for to, data in requests + ] + raw = self._transport.post_json("", json=envelopes) + return _parse_eth_call_batch_response(raw, envelopes) def eth_get_transaction_count(self, address: str, block: str = "pending") -> int: result = self._call("eth_getTransactionCount", [address, block]) @@ -210,6 +273,53 @@ def _extract_error_fields(err: object) -> tuple[int, str, object]: return 0, str(err), None +def _parse_rpc_response(method: str, raw: object) -> Any: + if not isinstance(raw, dict): + raise UnexpectedResponseError(f"JSON-RPC {method} returned a non-object response") + response = cast(dict[str, Any], raw) + if "error" in response: + err: Any = response["error"] + code, message, data = _extract_error_fields(err) + raise JsonRpcCallError(method=method, code=code, message=message, data=data) + return response.get("result") + + +def _parse_eth_call_batch_response(raw: object, envelopes: Sequence[dict[str, Any]]) -> list[str]: + if not isinstance(raw, list): + raise UnexpectedResponseError("JSON-RPC eth_call batch returned a non-array response") + + responses_by_id: dict[int, dict[str, Any]] = {} + for item in cast(list[object], raw): + if not isinstance(item, dict): + raise UnexpectedResponseError("JSON-RPC eth_call batch returned a non-object item") + response = cast(dict[str, Any], item) + raw_id = response.get("id") + if not isinstance(raw_id, int) or isinstance(raw_id, bool): + raise UnexpectedResponseError( + "JSON-RPC eth_call batch response is missing a numeric id" + ) + if raw_id in responses_by_id: + raise UnexpectedResponseError("JSON-RPC eth_call batch returned a duplicate id") + responses_by_id[raw_id] = response + + results: list[str] = [] + for envelope in envelopes: + raw_id = envelope["id"] + if not isinstance(raw_id, int) or isinstance(raw_id, bool): + raise RuntimeError("JSON-RPC request id must be an integer") + response = responses_by_id.get(raw_id) + if response is None: + raise UnexpectedResponseError("JSON-RPC eth_call batch response is missing an id") + results.append(_parse_eth_call_result(_parse_rpc_response("eth_call", response))) + return results + + +def _parse_eth_call_result(result: Any) -> str: + if not isinstance(result, str) or not _is_rpc_hex_string(result): + raise UnexpectedResponseError("eth_call did not return a hex string") + return result + + def _is_rpc_hex_string(value: str) -> bool: if not value.startswith("0x"): return False @@ -227,6 +337,7 @@ def _hex_to_int(value: Any, method: str) -> int: __all__ = [ + "EthCallBatchRequest", "JsonRpcCallError", "JsonRpcClient", "SyncJsonRpcClient", diff --git a/tests/unit/_relayer_helpers.py b/tests/unit/_relayer_helpers.py index 54a332f..c368b18 100644 --- a/tests/unit/_relayer_helpers.py +++ b/tests/unit/_relayer_helpers.py @@ -3,7 +3,7 @@ import dataclasses import json -from collections.abc import Callable +from collections.abc import Callable, Iterator from typing import Any, cast from urllib.parse import urlparse @@ -108,7 +108,7 @@ def make_rpc_handler( receipt_responses: list[dict[str, object] | None] | None = None, chain_id: int = 137, ) -> Callable[[httpx.Request], httpx.Response]: - captured: list[dict[str, object]] = [] + captured: list[object] = [] receipt_iter = iter(receipt_responses or []) def handler(request: httpx.Request) -> httpx.Response: @@ -157,54 +157,108 @@ def trading_approval_rpc_handler( receipt_responses: list[dict[str, object] | None] | None = None, chain_id: int = 137, ) -> Callable[[httpx.Request], httpx.Response]: - captured: list[dict[str, object]] = [] + captured: list[object] = [] receipt_iter = iter(receipt_responses or []) - allowance_selector = "0x" + keccak(b"allowance(address,address)")[:4].hex() - approved_selector = "0x" + keccak(b"isApprovedForAll(address,address)")[:4].hex() def handler(request: httpx.Request) -> httpx.Response: - body = json.loads(request.content.decode("utf-8")) + body = cast(object, json.loads(request.content.decode("utf-8"))) captured.append(body) - method = body["method"] - if method == "eth_call": - data = body["params"][0]["data"] - if data.startswith(allowance_selector): - result: object = "0x" + hex(allowance)[2:].rjust(64, "0") - elif data.startswith(approved_selector): - result = "0x" + ("1" if approved else "0").rjust(64, "0") - else: - result = "0x" + "0" * 64 - elif method == "eth_chainId": - result = hex(chain_id) - elif method == "eth_getTransactionCount": - result = hex(nonce) - elif method == "eth_gasPrice": - result = hex(gas_price) - elif method == "eth_estimateGas": - result = hex(gas_estimate) - elif method == "eth_sendRawTransaction": - result = send_response or ("0x" + "ab" * 32) - elif method == "eth_getTransactionReceipt": - try: - result = next(receipt_iter) - except StopIteration: - result = None - else: + + if isinstance(body, list): return httpx.Response( 200, - json={"jsonrpc": "2.0", "id": body["id"], "error": {"message": "unmocked"}}, + json=[ + _trading_approval_rpc_response( + item, + allowance=allowance, + approved=approved, + nonce=nonce, + gas_price=gas_price, + gas_estimate=gas_estimate, + send_response=send_response, + receipt_iter=receipt_iter, + chain_id=chain_id, + ) + for item in cast(list[object], body) + ], request=request, ) - return httpx.Response( - 200, - json={"jsonrpc": "2.0", "id": body["id"], "result": result}, - request=request, + + if not isinstance(body, dict): + return httpx.Response(400, request=request) + + response = _trading_approval_rpc_response( + cast(dict[str, object], body), + allowance=allowance, + approved=approved, + nonce=nonce, + gas_price=gas_price, + gas_estimate=gas_estimate, + send_response=send_response, + receipt_iter=receipt_iter, + chain_id=chain_id, ) + if "error" in response: + return httpx.Response(200, json=response, request=request) + return httpx.Response(200, json=response, request=request) handler.captured = captured # type: ignore[attr-defined] return handler +def _trading_approval_rpc_response( + body: object, + *, + allowance: int, + approved: bool, + nonce: int, + gas_price: int, + gas_estimate: int, + send_response: str | None, + receipt_iter: Iterator[dict[str, object] | None], + chain_id: int, +) -> dict[str, object]: + allowance_selector = "0x" + keccak(b"allowance(address,address)")[:4].hex() + approved_selector = "0x" + keccak(b"isApprovedForAll(address,address)")[:4].hex() + + if not isinstance(body, dict): + return {"jsonrpc": "2.0", "id": None, "error": {"message": "malformed"}} + + request = cast(dict[str, Any], body) + method = request["method"] + if method == "eth_call": + params = cast(list[dict[str, Any]], request["params"]) + data = params[0]["data"] + if data.startswith(allowance_selector): + result: object = "0x" + hex(allowance)[2:].rjust(64, "0") + elif data.startswith(approved_selector): + result = "0x" + ("1" if approved else "0").rjust(64, "0") + else: + result = "0x" + "0" * 64 + elif method == "eth_chainId": + result = hex(chain_id) + elif method == "eth_getTransactionCount": + result = hex(nonce) + elif method == "eth_gasPrice": + result = hex(gas_price) + elif method == "eth_estimateGas": + result = hex(gas_estimate) + elif method == "eth_sendRawTransaction": + result = send_response or ("0x" + "ab" * 32) + elif method == "eth_getTransactionReceipt": + try: + result = next(receipt_iter) + except StopIteration: + result = None + else: + return { + "jsonrpc": "2.0", + "id": request["id"], + "error": {"message": "unmocked"}, + } + return {"jsonrpc": "2.0", "id": request["id"], "result": result} + + async def make_safe_client() -> AsyncSecureClient: from eth_account import Account diff --git a/tests/unit/test_eoa_broadcast.py b/tests/unit/test_eoa_broadcast.py index e9173ad..d7c1d65 100644 --- a/tests/unit/test_eoa_broadcast.py +++ b/tests/unit/test_eoa_broadcast.py @@ -1,6 +1,7 @@ # pyright: reportPrivateUsage=false import asyncio import dataclasses +from typing import cast import httpx import pytest @@ -8,7 +9,6 @@ make_eoa_client_with_rpc, make_rpc_handler, ) -from eth_utils.crypto import keccak from polymarket import TransactionCall from polymarket.errors import TimeoutError, TransactionFailedError, UserInputError @@ -140,24 +140,33 @@ def test_eoa_setup_trading_approvals_submits_and_waits_for_required_calls_sequen send_iter = iter(send_hashes) receipts: list[dict[str, object] | None] = [{"status": "0x1"} for _ in range(16)] receipt_iter = iter(receipts) - calls: list[dict[str, object]] = [] + calls: list[object] = [] def handler(request: httpx.Request) -> httpx.Response: import json - body = json.loads(request.content.decode("utf-8")) + body = cast(object, json.loads(request.content.decode("utf-8"))) calls.append(body) + if isinstance(body, list): + return httpx.Response( + 200, + json=[ + { + "jsonrpc": "2.0", + "id": item["id"], + "result": _approval_check_result(item), + } + for item in cast(list[dict[str, object]], body) + ], + request=request, + ) + + body = cast(dict[str, object], body) method = body["method"] if method == "eth_chainId": result: object = hex(137) elif method == "eth_call": - data = body["params"][0]["data"] - allowance_selector = "0x" + keccak(b"allowance(address,address)")[:4].hex() - approved_selector = "0x" + keccak(b"isApprovedForAll(address,address)")[:4].hex() - if data.startswith(allowance_selector) or data.startswith(approved_selector): - result = "0x" + "0" * 64 - else: - result = "0x" + "0" * 64 + result = _approval_check_result(body) elif method == "eth_getTransactionCount": result = hex(7) elif method == "eth_gasPrice": @@ -179,6 +188,9 @@ def handler(request: httpx.Request) -> httpx.Response: 200, json={"jsonrpc": "2.0", "id": body["id"], "result": result}, request=request ) + def _approval_check_result(body: dict[str, object]) -> str: + return "0x" + "0" * 64 + async def run() -> object: client = await make_eoa_client_with_rpc(rpc_handler=handler) client._ctx = dataclasses.replace( @@ -193,10 +205,19 @@ async def run() -> object: handle = asyncio.run(run()) assert isinstance(handle, DeprecatedTransactionHandle) assert handle.transaction_hash is None - send_methods = [c for c in calls if c["method"] == "eth_sendRawTransaction"] - assert len(send_methods) == 16 - receipt_methods = [c for c in calls if c["method"] == "eth_getTransactionReceipt"] - assert len(receipt_methods) >= 16 + assert _rpc_method_count(calls, "eth_sendRawTransaction") == 16 + assert _rpc_method_count(calls, "eth_getTransactionReceipt") >= 16 + + +def _rpc_method_count(calls: list[object], method: str) -> int: + total = 0 + for call in calls: + if not isinstance(call, dict): + continue + rpc_call = cast(dict[str, object], call) + if rpc_call.get("method") == method: + total += 1 + return total def test_rpc_client_closes_with_client() -> None: diff --git a/tests/unit/test_eoa_rpc.py b/tests/unit/test_eoa_rpc.py new file mode 100644 index 0000000..25ba7b4 --- /dev/null +++ b/tests/unit/test_eoa_rpc.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import asyncio +import json + +import httpx + +from polymarket._internal.eoa.rpc import JsonRpcClient, SyncJsonRpcClient +from polymarket.clients._transport import AsyncTransport, SyncTransport + + +def test_async_eth_call_batch_splits_rejected_batches_preserving_order() -> None: + bodies: list[object] = [] + + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content.decode("utf-8")) + bodies.append(body) + if isinstance(body, list): + return httpx.Response(500, request=request) + return httpx.Response( + 200, + json={ + "jsonrpc": "2.0", + "id": body["id"], + "result": body["params"][0]["data"], + }, + request=request, + ) + + async def run() -> list[str]: + http_client = httpx.AsyncClient( + base_url="https://rpc.test", transport=httpx.MockTransport(handler) + ) + client = JsonRpcClient(AsyncTransport(base_url="https://rpc.test", client=http_client)) + try: + return await client.eth_call_batch(_batch_requests()) + finally: + await http_client.aclose() + + assert asyncio.run(run()) == ["0x11111111", "0x22222222", "0x33333333", "0x44444444"] + assert len(bodies) == 7 + + +def test_sync_eth_call_batch_splits_rejected_batches_preserving_order() -> None: + bodies: list[object] = [] + + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content.decode("utf-8")) + bodies.append(body) + if isinstance(body, list): + return httpx.Response(500, request=request) + return httpx.Response( + 200, + json={ + "jsonrpc": "2.0", + "id": body["id"], + "result": body["params"][0]["data"], + }, + request=request, + ) + + http_client = httpx.Client(base_url="https://rpc.test", transport=httpx.MockTransport(handler)) + client = SyncJsonRpcClient(SyncTransport(base_url="https://rpc.test", client=http_client)) + try: + result = client.eth_call_batch(_batch_requests()) + finally: + http_client.close() + + assert result == ["0x11111111", "0x22222222", "0x33333333", "0x44444444"] + assert len(bodies) == 7 + + +def _batch_requests() -> list[tuple[str, str]]: + to = "0x0000000000000000000000000000000000000001" + return [ + (to, "0x11111111"), + (to, "0x22222222"), + (to, "0x33333333"), + (to, "0x44444444"), + ]