diff --git a/tests/core/types/test_currency_types.py b/tests/core/types/test_currency_types.py new file mode 100644 index 0000000000..110bdef493 --- /dev/null +++ b/tests/core/types/test_currency_types.py @@ -0,0 +1,37 @@ +from web3.types import ( + Gwei, + Wei, +) + + +def test_wei_arithmetic_preserves_type() -> None: + value = Wei(5) + + value += Wei(1) + + assert value == 6 + assert isinstance(value, Wei) + assert isinstance(value + 1, Wei) + assert isinstance(1 + value, Wei) + assert isinstance(value - 1, Wei) + assert isinstance(10 - value, Wei) + assert isinstance(value * 2, Wei) + assert isinstance(2 * value, Wei) + assert isinstance(value // 2, Wei) + assert isinstance(value % 4, Wei) + assert isinstance(-value, Wei) + assert isinstance(+value, Wei) + assert isinstance(abs(value), Wei) + quotient, remainder = divmod(value, 4) + assert isinstance(quotient, Wei) + assert isinstance(remainder, Wei) + + +def test_gwei_arithmetic_preserves_type() -> None: + value = Gwei(5) + + value -= Gwei(1) + + assert value == 4 + assert isinstance(value, Gwei) + assert isinstance(value + 1, Gwei) diff --git a/web3/types.py b/web3/types.py index c3ae0b4834..a7d3ebf50b 100644 --- a/web3/types.py +++ b/web3/types.py @@ -32,6 +32,7 @@ ) from web3._utils.compat import ( NotRequired, + Self, ) if TYPE_CHECKING: @@ -76,8 +77,51 @@ Nonce = NewType("Nonce", int) RPCEndpoint = NewType("RPCEndpoint", str) Timestamp = NewType("Timestamp", int) -Wei = NewType("Wei", int) -Gwei = NewType("Gwei", int) +class _IntegerType(int): + def __add__(self, other: int) -> Self: + return self.__class__(int(self) + other) + + def __radd__(self, other: int) -> Self: + return self.__class__(other + int(self)) + + def __sub__(self, other: int) -> Self: + return self.__class__(int(self) - other) + + def __rsub__(self, other: int) -> Self: + return self.__class__(other - int(self)) + + def __mul__(self, other: int) -> Self: + return self.__class__(int(self) * other) + + def __rmul__(self, other: int) -> Self: + return self.__class__(other * int(self)) + + def __floordiv__(self, other: int) -> Self: + return self.__class__(int(self) // other) + + def __mod__(self, other: int) -> Self: + return self.__class__(int(self) % other) + + def __divmod__(self, other: int) -> tuple[Self, Self]: + quotient, remainder = divmod(int(self), other) + return self.__class__(quotient), self.__class__(remainder) + + def __neg__(self) -> Self: + return self.__class__(-int(self)) + + def __pos__(self) -> Self: + return self.__class__(+int(self)) + + def __abs__(self) -> Self: + return self.__class__(abs(int(self))) + + +class Wei(_IntegerType): + pass + + +class Gwei(_IntegerType): + pass Formatters = dict[RPCEndpoint, Callable[..., Any]]