diff --git a/src/error/ErrDecimalFloat.sol b/src/error/ErrDecimalFloat.sol index 60f814b9..eef6de0a 100644 --- a/src/error/ErrDecimalFloat.sol +++ b/src/error/ErrDecimalFloat.sol @@ -2,6 +2,8 @@ // SPDX-FileCopyrightText: Copyright (c) 2020 thedavidmeister pragma solidity ^0.8.25; +import {Float} from "../lib/LibDecimalFloat.sol"; + /// @dev Thrown when a coefficient overflows. error CoefficientOverflow(int256 signedCoefficient, int256 exponent); @@ -25,3 +27,6 @@ error LossyConversionToFloat(int256 signedCoefficient, int256 exponent); /// @dev Thrown when converting a float to some value when the conversion /// is lossy. error LossyConversionFromFloat(int256 signedCoefficient, int256 exponent); + +/// @dev Thrown when attempting to exponentiate 0^b where b is negative. +error ZeroNegativePower(Float b); diff --git a/src/lib/LibDecimalFloat.sol b/src/lib/LibDecimalFloat.sol index 7fb83e8a..812655f8 100644 --- a/src/lib/LibDecimalFloat.sol +++ b/src/lib/LibDecimalFloat.sol @@ -14,7 +14,8 @@ import { Log10Zero, NegativeFixedDecimalConversion, LossyConversionFromFloat, - LossyConversionToFloat + LossyConversionToFloat, + ZeroNegativePower } from "../error/ErrDecimalFloat.sol"; import { LibDecimalFloatImplementation, @@ -89,6 +90,10 @@ library LibDecimalFloat { address constant LOG_TABLES_ADDRESS = 0x295180b25A5059a2e7eC64272ba4F85047B4146A; + Float constant FLOAT_ZERO = Float.wrap(0); + + Float constant FLOAT_ONE = Float.wrap(bytes32(uint256(1))); + /// type(int224).max, type(int32).max Float constant FLOAT_MAX_VALUE = Float.wrap(bytes32(uint256(0x7fffffff7fffffffffffffffffffffffffffffffffffffffffffffffffffffff))); @@ -640,6 +645,18 @@ library LibDecimalFloat { /// logarithm tables. function pow(Float a, Float b, address tablesDataContract) internal view returns (Float) { (int256 signedCoefficientA, int256 exponentA) = a.unpack(); + if (b.isZero()) { + return FLOAT_ONE; + } else if (signedCoefficientA == 0) { + if (b.lt(FLOAT_ZERO)) { + // If b is negative, and a is 0, so we revert. + revert ZeroNegativePower(b); + } + + // If a is zero, then a^b is always zero, regardless of b. + // This is a special case because log10(0) is undefined. + return FLOAT_ZERO; + } (int256 signedCoefficientC, int256 exponentC) = LibDecimalFloatImplementation.log10(tablesDataContract, signedCoefficientA, exponentA); diff --git a/test/src/lib/LibDecimalFloat.constants.t.sol b/test/src/lib/LibDecimalFloat.constants.t.sol index 0813d690..cff9c063 100644 --- a/test/src/lib/LibDecimalFloat.constants.t.sol +++ b/test/src/lib/LibDecimalFloat.constants.t.sol @@ -26,4 +26,16 @@ contract LibDecimalFloatConstantsTest is Test { ); assertEq(Float.unwrap(e), Float.unwrap(expected)); } + + function testFloatZero() external pure { + Float zero = LibDecimalFloat.FLOAT_ZERO; + Float expected = LibDecimalFloat.packLossless(0, 0); + assertEq(Float.unwrap(zero), Float.unwrap(expected)); + } + + function testFloatOne() external pure { + Float one = LibDecimalFloat.FLOAT_ONE; + Float expected = LibDecimalFloat.packLossless(1, 0); + assertEq(Float.unwrap(one), Float.unwrap(expected)); + } } diff --git a/test/src/lib/LibDecimalFloat.pow.t.sol b/test/src/lib/LibDecimalFloat.pow.t.sol index a4fed971..76da2cd5 100644 --- a/test/src/lib/LibDecimalFloat.pow.t.sol +++ b/test/src/lib/LibDecimalFloat.pow.t.sol @@ -4,7 +4,7 @@ pragma solidity =0.8.25; import {LogTest} from "../../abstract/LogTest.sol"; import {LibDecimalFloat, Float} from "src/lib/LibDecimalFloat.sol"; - +import {ZeroNegativePower} from "src/error/ErrDecimalFloat.sol"; import {console2} from "forge-std/Test.sol"; contract LibDecimalFloatPowTest is LogTest { @@ -44,6 +44,33 @@ contract LibDecimalFloatPowTest is LogTest { checkPow(1785215562, 0, 18, 0, 3388, 163); } + /// a^0 = 1 for all a including 0^0. + function testPowBZero(Float a, int32 exponentB) external { + Float b = LibDecimalFloat.packLossless(0, exponentB); + // If b is zero then the result is always 1. + address tables = logTables(); + Float c = a.pow(b, tables); + assertTrue(c.eq(LibDecimalFloat.packLossless(1, 0)), "c is not 1"); + } + + /// 0^b is defined as 0 for all b > 0. + function testPowAZero(int32 exponentA, Float b) external { + // 0^0 is defined as 1. + vm.assume(b.gt(LibDecimalFloat.FLOAT_ZERO)); + // If a is zero then the result is always zero. + Float a = LibDecimalFloat.packLossless(0, exponentA); + address tables = logTables(); + Float c = a.pow(b, tables); + assertTrue(c.isZero(), "c is not zero"); + } + + /// 0^a is error for all a < 0. + function testPowAZeroNegative(Float b) external { + vm.assume(b.lt(LibDecimalFloat.FLOAT_ZERO)); + vm.expectRevert(abi.encodeWithSelector(ZeroNegativePower.selector, b)); + this.powExternal(LibDecimalFloat.FLOAT_ZERO, b); + } + function checkRoundTrip(int256 signedCoefficientA, int256 exponentA, int256 signedCoefficientB, int256 exponentB) internal {