diff --git a/crates/float/src/lib.rs b/crates/float/src/lib.rs index 208f0fe2..1d1dc972 100644 --- a/crates/float/src/lib.rs +++ b/crates/float/src/lib.rs @@ -12,7 +12,7 @@ use revm::interpreter::interpreter::EthInterpreter; use revm::primitives::{address, fixed_bytes}; use revm::{Context, MainBuilder, MainContext, SystemCallEvm}; use std::cell::RefCell; -use std::ops::{Add, Sub}; +use std::ops::{Add, Div, Mul, Sub}; use std::thread::AccessError; use thiserror::Error; @@ -188,6 +188,7 @@ impl Float { }) } + // NOTE: LibFormatDecimalFloat.toDecimalString currently uses 18 decimal places pub fn format(self) -> Result { let Float(a) = self; let calldata = DecimalFloat::formatCall { a }.abi_encode(); @@ -262,9 +263,40 @@ impl Sub for Float { } } +impl Mul for Float { + type Output = Result; + + fn mul(self, b: Self) -> Self::Output { + let Float(a) = self; + let Float(b) = b; + let calldata = DecimalFloat::mulCall { a, b }.abi_encode(); + + execute_call(Bytes::from(calldata), |output| { + let decoded = DecimalFloat::mulCall::abi_decode_returns(output.as_ref())?; + Ok(Float(decoded)) + }) + } +} + +impl Div for Float { + type Output = Result; + + fn div(self, b: Self) -> Self::Output { + let Float(a) = self; + let Float(b) = b; + let calldata = DecimalFloat::divCall { a, b }.abi_encode(); + + execute_call(Bytes::from(calldata), |output| { + let decoded = DecimalFloat::divCall::abi_decode_returns(output.as_ref())?; + Ok(Float(decoded)) + }) + } +} + #[cfg(test)] mod tests { use super::*; + use core::str::FromStr; use proptest::prelude::*; prop_compose! { @@ -294,8 +326,6 @@ mod tests { #[test] fn test_parse_and_format() { let float = Float::parse("1.1341234234625468391".to_string()).unwrap(); - // NOTE: LibFormatDecimalFloat.toDecimalString currently uses 18 decimal places - // TODO: make this fail on a separate PR let err = float.format().unwrap_err(); assert!(matches!( @@ -304,6 +334,23 @@ mod tests { )); } + #[test] + fn test_parse_empty_string_error() { + let err = Float::parse("".to_string()).unwrap_err(); + // We don't know the exact selector here, just ensure the error path is hit. + assert!(matches!(err, FloatError::DecimalFloatSelector(_))); + } + + #[test] + fn test_parse_exponent_overflow_error() { + // Extremely large exponent expected to overflow (exponent >> i32::MAX). + let err = Float::parse("1e3000000000".to_string()).unwrap_err(); + assert!(matches!( + err, + FloatError::DecimalFloat(DecimalFloatErrors::ExponentOverflow(_)) + )); + } + #[test] fn test_parse_edge_cases() { let err = Float::parse("1.2.3".to_string()).unwrap_err(); @@ -330,6 +377,39 @@ mod tests { } } + #[test] + fn test_add_exponent_overflow_error() { + let max_coeff_str = "13479973333575319897333507543509815336818572211270286240551805124607"; + let large_coeff_i224 = I224::from_str(max_coeff_str).unwrap(); + let exponent_max = i32::MAX; + + let a = Float::pack_lossless(large_coeff_i224, exponent_max).unwrap(); + + let err = (a + a).unwrap_err(); + + assert!(matches!( + err, + FloatError::DecimalFloat(DecimalFloatErrors::ExponentOverflow(_)) + )); + } + + #[test] + fn test_sub_exponent_overflow_error() { + let max_coeff_str = "13479973333575319897333507543509815336818572211270286240551805124607"; + let large_coeff_i224 = I224::from_str(max_coeff_str).unwrap(); + let exponent_max = i32::MAX; + + let a = Float::pack_lossless(large_coeff_i224, exponent_max).unwrap(); + let b = Float::pack_lossless(-large_coeff_i224, exponent_max).unwrap(); + + let err = (b - a).unwrap_err(); + + assert!(matches!( + err, + FloatError::DecimalFloat(DecimalFloatErrors::ExponentOverflow(_)) + )); + } + proptest! { #[test] fn test_add(a in reasonable_float(), b in reasonable_float()) { @@ -415,4 +495,101 @@ mod tests { prop_assert!(!(lt && gt), "both less than and greater than: a: {a_str}, b: {b_str}"); } } + + proptest! { + #[test] + fn test_mul(a in reasonable_float(), b in reasonable_float()) { + (a * b).unwrap(); + } + } + + proptest! { + #[test] + fn test_div(a in reasonable_float(), b in reasonable_float()) { + let zero = Float::parse("0".to_string()).unwrap(); + prop_assume!(!b.eq(zero).unwrap()); + + (a / b).unwrap(); + } + } + + prop_compose! { + fn small_int_float()(int_part in -1_000_000_000_000i128..1_000_000_000_000i128) -> Float { + Float::parse(int_part.to_string()).unwrap() + } + } + + proptest! { + #[test] + fn test_mul_div_int(a in small_int_float(), b in small_int_float()) { + let zero = Float::parse("0".to_string()).unwrap(); + prop_assume!(!b.eq(zero).unwrap()); + + let product = (a * b).unwrap(); + let quotient = (product / b).unwrap(); + + prop_assert!( + a.eq(quotient).unwrap(), + "a: {}, quotient: {}, b: {}", + a.show_unpacked().unwrap(), + quotient.show_unpacked().unwrap(), + b.show_unpacked().unwrap() + ); + } + } + + #[test] + fn test_mul_div_manual() { + let two = Float::parse("2".to_string()).unwrap(); + let three = Float::parse("3".to_string()).unwrap(); + let six = Float::parse("6".to_string()).unwrap(); + + assert!(two.eq((six / three).unwrap()).unwrap()); + assert!(six.eq((two * three).unwrap()).unwrap()); + } + + #[test] + fn test_divide_by_zero_error() { + let one = Float::parse("1".to_string()).unwrap(); + let zero = Float::parse("0".to_string()).unwrap(); + let err = (one / zero).unwrap_err(); + + assert!(matches!(err, FloatError::Revert(_))); + } + + #[test] + fn test_mul_exponent_overflow_error() { + let near_max_exp = Float::parse("1e2147483646".to_string()).unwrap(); + let one_e_two = Float::parse("1e2".to_string()).unwrap(); + + let err = (near_max_exp * one_e_two).unwrap_err(); + assert!(matches!( + err, + FloatError::DecimalFloat(DecimalFloatErrors::ExponentOverflow(_)) + )); + } + + #[test] + fn test_div_exponent_overflow_error() { + let near_max_exp = Float::parse("1e2147483646".to_string()).unwrap(); + let one_e_neg_hundred = Float::parse("1e-100".to_string()).unwrap(); + + let err = (near_max_exp / one_e_neg_hundred).unwrap_err(); + assert!(matches!( + err, + FloatError::DecimalFloat(DecimalFloatErrors::ExponentOverflow(_)) + )); + } + + #[test] + fn test_mul_exponent_underflow_error() { + let near_min_exp = Float::parse("1e-2147483646".to_string()).unwrap(); + let one_e_neg_three = Float::parse("1e-3".to_string()).unwrap(); + + let err = (near_min_exp * one_e_neg_three).unwrap_err(); + assert!(matches!( + err, + FloatError::DecimalFloat(DecimalFloatErrors::ExponentOverflow(_)) + )); + } }