diff --git a/bbq/vm/builtin_globals.go b/bbq/vm/builtin_globals.go index c6765e333..312506191 100644 --- a/bbq/vm/builtin_globals.go +++ b/bbq/vm/builtin_globals.go @@ -303,6 +303,8 @@ func init() { registerBuiltinSaturatingArithmeticFunctions() registerBuiltinFixedPointPowFunctions() + + registerBuiltinFixedPointMultiplyDivideFunctions() } func registerBuiltinCommonTypeBoundFunctions() { @@ -444,6 +446,19 @@ func registerBuiltinFixedPointPowFunctions() { } } +func registerBuiltinFixedPointMultiplyDivideFunctions() { + for baseType, funcType := range sema.FixedPointMultiplyDivideFunctionTypes { //nolint:maprange + registerBuiltinTypeBoundFunction( + commons.TypeQualifier(baseType), + NewNativeFunctionValue( + sema.FixedPointNumericTypeMultiplyDivideFunctionName, + funcType, + interpreter.NativeFixedPointMultiplyDivideFunction, + ), + ) + } +} + func newFromStringFunction(typedParser interpreter.TypedStringValueParser) *NativeFunctionValue { functionType := sema.FromStringFunctionType(typedParser.ReceiverType) parser := typedParser.Parser diff --git a/interpreter/fixedpoint_test.go b/interpreter/fixedpoint_test.go index 8ef7deb9f..89dc3e555 100644 --- a/interpreter/fixedpoint_test.go +++ b/interpreter/fixedpoint_test.go @@ -23,6 +23,7 @@ import ( "math" "math/big" "strconv" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -40,6 +41,358 @@ import ( . "github.com/onflow/cadence/test_utils/sema_utils" ) +func fix128BigInt(s string) *big.Int { + // Parse a decimal string like "33.333333333333333333333333" + // into the raw scaled big.Int (removing the decimal point). + // The fractional part must have exactly Fix128Scale (24) digits. + parts := strings.SplitN(s, ".", 2) + if len(parts) == 1 { + // No decimal point — treat as integer, scale up + v, ok := new(big.Int).SetString(s, 10) + if !ok { + panic("invalid fix128 string: " + s) + } + return v.Mul(v, sema.Fix128FactorIntBig) + } + if len(parts[1]) != sema.Fix128Scale { + panic(fmt.Sprintf("expected %d fractional digits, got %d: %s", sema.Fix128Scale, len(parts[1]), s)) + } + raw := parts[0] + parts[1] + v, ok := new(big.Int).SetString(raw, 10) + if !ok { + panic("invalid fix128 string: " + s) + } + return v +} + +func TestInterpretFixedPointMultiplyDivide(t *testing.T) { + + t.Parallel() + + t.Run("UFix64", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + a, b, c string + rounding string + expected uint64 + expectedError bool + } + + // Expected values pre-computed using the fixed-point library's UFix64.FMD(). + testCases := []testCase{ + // Basic: 2*3/1 = 6 + {a: "2.00000000", b: "3.00000000", c: "1.00000000", rounding: "towardZero", expected: 600000000}, + // Rounding modes: 10*10/3 + {a: "10.00000000", b: "10.00000000", c: "3.00000000", rounding: "towardZero", expected: 3333333333}, + {a: "10.00000000", b: "10.00000000", c: "3.00000000", rounding: "awayFromZero", expected: 3333333334}, + {a: "10.00000000", b: "10.00000000", c: "3.00000000", rounding: "nearestHalfAway", expected: 3333333333}, + // Fractional: 0.5*0.5/1.0 = 0.25 + {a: "0.50000000", b: "0.50000000", c: "1.00000000", rounding: "towardZero", expected: 25000000}, + // Zero factor + {a: "0.00000000", b: "5.00000000", c: "2.00000000", rounding: "towardZero", expected: 0}, + {a: "5.00000000", b: "0.00000000", c: "2.00000000", rounding: "towardZero", expected: 0}, + // Larger values: 100*200/50 = 400 + {a: "100.00000000", b: "200.00000000", c: "50.00000000", rounding: "towardZero", expected: 40000000000}, + // 1*1/3 with different rounding + {a: "1.00000000", b: "1.00000000", c: "3.00000000", rounding: "towardZero", expected: 33333333}, + {a: "1.00000000", b: "1.00000000", c: "3.00000000", rounding: "awayFromZero", expected: 33333334}, + // Division by zero + {a: "1.00000000", b: "2.00000000", c: "0.00000000", rounding: "towardZero", expectedError: true}, + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s * %s / %s (%s)", tc.a, tc.b, tc.c, tc.rounding) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): UFix64 { + let a: UFix64 = %s + let b: UFix64 = %s + let c: UFix64 = %s + return a.multiplyDivide(b, c, rounding: RoundingRule.%s) + } + `, + tc.a, tc.b, tc.c, tc.rounding, + ) + + inter := parseCheckAndPrepareWithRoundingRule(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewUnmeteredUFix64Value(tc.expected) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) + + t.Run("Fix64", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + a, b, c string + rounding string + expected int64 + expectedError bool + } + + testCases := []testCase{ + // Basic: 2*3/1 = 6 + {a: "2.00000000", b: "3.00000000", c: "1.00000000", rounding: "towardZero", expected: 600000000}, + // Signed: (-2)*3/1 = -6 + {a: "-2.00000000", b: "3.00000000", c: "1.00000000", rounding: "towardZero", expected: -600000000}, + // (-5)*(-3)/2 = 7.5 + {a: "-5.00000000", b: "-3.00000000", c: "2.00000000", rounding: "towardZero", expected: 750000000}, + // Rounding: 10*10/3 + {a: "10.00000000", b: "10.00000000", c: "3.00000000", rounding: "towardZero", expected: 3333333333}, + {a: "10.00000000", b: "10.00000000", c: "3.00000000", rounding: "awayFromZero", expected: 3333333334}, + // 1/3 truncated + {a: "1.00000000", b: "1.00000000", c: "3.00000000", rounding: "towardZero", expected: 33333333}, + // Division by zero + {a: "1.00000000", b: "2.00000000", c: "0.00000000", rounding: "towardZero", expectedError: true}, + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s * %s / %s (%s)", tc.a, tc.b, tc.c, tc.rounding) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): Fix64 { + let a: Fix64 = %s + let b: Fix64 = %s + let c: Fix64 = %s + return a.multiplyDivide(b, c, rounding: RoundingRule.%s) + } + `, + tc.a, tc.b, tc.c, tc.rounding, + ) + + inter := parseCheckAndPrepareWithRoundingRule(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewUnmeteredFix64Value(tc.expected) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) + + t.Run("UFix128", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + a, b, c string + rounding string + expected string + expectedError bool + } + + testCases := []testCase{ + {a: "2.000000000000000000000000", b: "3.000000000000000000000000", c: "1.000000000000000000000000", rounding: "towardZero", expected: "6.000000000000000000000000"}, + {a: "10.000000000000000000000000", b: "10.000000000000000000000000", c: "3.000000000000000000000000", rounding: "towardZero", expected: "33.333333333333333333333333"}, + {a: "10.000000000000000000000000", b: "10.000000000000000000000000", c: "3.000000000000000000000000", rounding: "awayFromZero", expected: "33.333333333333333333333334"}, + {a: "0.500000000000000000000000", b: "0.500000000000000000000000", c: "1.000000000000000000000000", rounding: "towardZero", expected: "0.250000000000000000000000"}, + {a: "100.000000000000000000000000", b: "200.000000000000000000000000", c: "50.000000000000000000000000", rounding: "towardZero", expected: "400.000000000000000000000000"}, + // Division by zero + {a: "1.000000000000000000000000", b: "2.000000000000000000000000", c: "0.000000000000000000000000", rounding: "towardZero", expectedError: true}, + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s * %s / %s (%s)", tc.a, tc.b, tc.c, tc.rounding) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): UFix128 { + let a: UFix128 = %s + let b: UFix128 = %s + let c: UFix128 = %s + return a.multiplyDivide(b, c, rounding: RoundingRule.%s) + } + `, + tc.a, tc.b, tc.c, tc.rounding, + ) + + inter := parseCheckAndPrepareWithRoundingRule(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewUFix128ValueFromBigInt(nil, fix128BigInt(tc.expected)) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) + + t.Run("Fix128", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + a, b, c string + rounding string + expected string + expectedError bool + } + + testCases := []testCase{ + {a: "2.000000000000000000000000", b: "3.000000000000000000000000", c: "1.000000000000000000000000", rounding: "towardZero", expected: "6.000000000000000000000000"}, + {a: "-2.000000000000000000000000", b: "3.000000000000000000000000", c: "1.000000000000000000000000", rounding: "towardZero", expected: "-6.000000000000000000000000"}, + {a: "-2.000000000000000000000000", b: "-3.000000000000000000000000", c: "1.000000000000000000000000", rounding: "towardZero", expected: "6.000000000000000000000000"}, + {a: "10.000000000000000000000000", b: "10.000000000000000000000000", c: "3.000000000000000000000000", rounding: "towardZero", expected: "33.333333333333333333333333"}, + {a: "10.000000000000000000000000", b: "10.000000000000000000000000", c: "3.000000000000000000000000", rounding: "awayFromZero", expected: "33.333333333333333333333334"}, + // Division by zero + {a: "1.000000000000000000000000", b: "2.000000000000000000000000", c: "0.000000000000000000000000", rounding: "towardZero", expectedError: true}, + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s * %s / %s (%s)", tc.a, tc.b, tc.c, tc.rounding) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): Fix128 { + let a: Fix128 = %s + let b: Fix128 = %s + let c: Fix128 = %s + return a.multiplyDivide(b, c, rounding: RoundingRule.%s) + } + `, + tc.a, tc.b, tc.c, tc.rounding, + ) + + inter := parseCheckAndPrepareWithRoundingRule(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewFix128ValueFromBigInt(nil, fix128BigInt(tc.expected)) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) + + t.Run("default rounding (truncate)", func(t *testing.T) { + + t.Parallel() + + t.Run("UFix64", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepareWithRoundingRule(t, ` + fun test(): UFix64 { + let a: UFix64 = 10.0 + let b: UFix64 = 10.0 + let c: UFix64 = 3.0 + return a.multiplyDivide(b, c) + } + `) + result, err := inter.Invoke("test") + require.NoError(t, err) + + // 10*10/3 truncated = 33.33333333 + expected := interpreter.NewUnmeteredUFix64Value(3333333333) + AssertValuesEqual(t, inter, expected, result) + }) + + t.Run("Fix64", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepareWithRoundingRule(t, ` + fun test(): Fix64 { + let a: Fix64 = 10.0 + let b: Fix64 = 10.0 + let c: Fix64 = 3.0 + return a.multiplyDivide(b, c) + } + `) + result, err := inter.Invoke("test") + require.NoError(t, err) + + // 10*10/3 truncated = 33.33333333 + expected := interpreter.NewUnmeteredFix64Value(3333333333) + AssertValuesEqual(t, inter, expected, result) + }) + + t.Run("UFix128", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepareWithRoundingRule(t, ` + fun test(): UFix128 { + let a: UFix128 = 10.0 + let b: UFix128 = 10.0 + let c: UFix128 = 3.0 + return a.multiplyDivide(b, c) + } + `) + result, err := inter.Invoke("test") + require.NoError(t, err) + + // 10*10/3 truncated = 33.333333333333333333333333 + expected := interpreter.NewUFix128ValueFromBigInt(nil, fix128BigInt("33.333333333333333333333333")) + AssertValuesEqual(t, inter, expected, result) + }) + + t.Run("Fix128", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepareWithRoundingRule(t, ` + fun test(): Fix128 { + let a: Fix128 = 10.0 + let b: Fix128 = 10.0 + let c: Fix128 = 3.0 + return a.multiplyDivide(b, c) + } + `) + result, err := inter.Invoke("test") + require.NoError(t, err) + + // 10*10/3 truncated = 33.333333333333333333333333 + expected := interpreter.NewFix128ValueFromBigInt(nil, fix128BigInt("33.333333333333333333333333")) + AssertValuesEqual(t, inter, expected, result) + }) + }) +} + func TestInterpretFixedPointPow(t *testing.T) { t.Parallel() @@ -205,11 +558,7 @@ func TestInterpretFixedPointPow(t *testing.T) { result, err := inter.Invoke("test") require.NoError(t, err) - expected := parseCheckAndPrepare(t, fmt.Sprintf( - `let expected: UFix128 = %s`, - tc.expected, - )).GetGlobal("expected") - + expected := interpreter.NewUFix128ValueFromBigInt(nil, fix128BigInt(tc.expected)) AssertValuesEqual(t, inter, expected, result) } }) diff --git a/interpreter/value.go b/interpreter/value.go index 5f5fb7132..d12d6af4d 100644 --- a/interpreter/value.go +++ b/interpreter/value.go @@ -23,6 +23,8 @@ import ( "github.com/onflow/atree" + fix "github.com/onflow/fixed-point" + "github.com/onflow/cadence/common" "github.com/onflow/cadence/sema" ) @@ -333,6 +335,12 @@ type FixedPointValue interface { NumberValue IntegerPart() NumberValue Scale() int + MultiplyDivide( + context NumberValueArithmeticContext, + factor FixedPointValue, + divisor FixedPointValue, + rounding fix.RoundingMode, + ) NumberValue } type AuthorizedValue interface { diff --git a/interpreter/value_bool.go b/interpreter/value_bool.go index 5d9705ed6..b72cc7615 100644 --- a/interpreter/value_bool.go +++ b/interpreter/value_bool.go @@ -193,22 +193,33 @@ func (BoolValue) DeepRemove(_ ValueRemoveContext, _ bool) { // NO-OP } -func (v BoolValue) GetMember(context MemberAccessibleContext, name string, memberKind common.DeclarationKind) Value { +func (v BoolValue) GetMember( + context MemberAccessibleContext, + name string, + memberKind common.DeclarationKind, + accessedReference ReferenceValue, +) Value { return GetMember( context, v, + accessedReference, name, memberKind, nil, ) } -func (v BoolValue) GetMethod(context MemberAccessibleContext, name string) FunctionValue { +func (v BoolValue) GetMethod( + context MemberAccessibleContext, + name string, + accessedReference ReferenceValue, +) FunctionValue { switch name { case sema.ToStringFunctionName: return NewBoundHostFunctionValue( context, v, + accessedReference, sema.ToStringFunctionType, NativeBoolValueToStringFunction, ) diff --git a/interpreter/value_fix128.go b/interpreter/value_fix128.go index b04a62a33..c6992de44 100644 --- a/interpreter/value_fix128.go +++ b/interpreter/value_fix128.go @@ -370,6 +370,43 @@ func (v Fix128Value) Mod(context NumberValueArithmeticContext, other NumberValue return NewFix128Value(context, valueGetter) } +func (v Fix128Value) MultiplyDivide( + context NumberValueArithmeticContext, + factor FixedPointValue, + divisor FixedPointValue, + rounding fix.RoundingMode, +) NumberValue { + f, ok := factor.(Fix128Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: factor.StaticType(context), + }) + } + + d, ok := divisor.(Fix128Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: divisor.StaticType(context), + }) + } + + valueGetter := func() fix.Fix128 { + result, err := fix.Fix128(v).FMD( + fix.Fix128(f), + fix.Fix128(d), + rounding, + ) + handleFixedpointError(err) + return result + } + + return NewFix128Value(context, valueGetter) +} + func (v Fix128Value) Less(context ValueComparisonContext, other ComparableValue) BoolValue { o, ok := other.(Fix128Value) if !ok { diff --git a/interpreter/value_fix64.go b/interpreter/value_fix64.go index dd3205069..7b7ea5fd2 100644 --- a/interpreter/value_fix64.go +++ b/interpreter/value_fix64.go @@ -395,6 +395,42 @@ func (v Fix64Value) Mod(context NumberValueArithmeticContext, other NumberValue) return v.Minus(context, truncatedQuotient.Mul(context, o)) } +func (v Fix64Value) MultiplyDivide( + context NumberValueArithmeticContext, + factor FixedPointValue, + divisor FixedPointValue, + rounding fix.RoundingMode, +) NumberValue { + f, ok := factor.(Fix64Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: factor.StaticType(context), + }) + } + + d, ok := divisor.(Fix64Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: divisor.StaticType(context), + }) + } + + valueGetter := func() int64 { + a := fix.Fix64(uint64(v)) + b := fix.Fix64(uint64(f)) + c := fix.Fix64(uint64(d)) + result, err := a.FMD(b, c, rounding) + handleFixedpointError(err) + return int64(result) + } + + return NewFix64Value(context, valueGetter) +} + func (v Fix64Value) Less(context ValueComparisonContext, other ComparableValue) BoolValue { o, ok := other.(Fix64Value) if !ok { diff --git a/interpreter/value_number.go b/interpreter/value_number.go index f6fd3a092..953c63797 100644 --- a/interpreter/value_number.go +++ b/interpreter/value_number.go @@ -21,6 +21,8 @@ package interpreter import ( "math/big" + fix "github.com/onflow/fixed-point" + "github.com/onflow/cadence/common" "github.com/onflow/cadence/errors" "github.com/onflow/cadence/sema" @@ -120,9 +122,23 @@ func getNumberValueFunctionMember( return NewBoundHostFunctionValue( context, v, + accessedReference, funcType, NativeFixedPointPowFunction, ) + + case sema.FixedPointNumericTypeMultiplyDivideFunctionName: + funcType, ok := sema.FixedPointMultiplyDivideFunctionTypes[typ] + if !ok { + return nil + } + return NewBoundHostFunctionValue( + context, + v, + accessedReference, + funcType, + NativeFixedPointMultiplyDivideFunction, + ) } return nil @@ -235,6 +251,26 @@ var NativeNumberSaturatingDivideFunction = NativeFunction( }, ) +var NativeFixedPointMultiplyDivideFunction = NativeFunction( + func( + context NativeFunctionContext, + _ TypeArgumentsIterator, + _ ArgumentTypesIterator, + receiver Value, + args []Value, + ) Value { + factor := AssertValueOfType[FixedPointValue](args[0]) + divisor := AssertValueOfType[FixedPointValue](args[1]) + var rounding fix.RoundingMode + if len(args) > 2 { + rounding = extractRoundingRule(args[2]) + } else { + rounding = fix.RoundTruncate + } + return receiver.(FixedPointValue).MultiplyDivide(context, factor, divisor, rounding) + }, +) + var NativeFixedPointPowFunction = NativeFunction( func( context NativeFunctionContext, diff --git a/interpreter/value_ufix128.go b/interpreter/value_ufix128.go index 022f0f996..b41ff8296 100644 --- a/interpreter/value_ufix128.go +++ b/interpreter/value_ufix128.go @@ -362,6 +362,43 @@ func (v UFix128Value) Mod(context NumberValueArithmeticContext, other NumberValu return NewUFix128Value(context, valueGetter) } +func (v UFix128Value) MultiplyDivide( + context NumberValueArithmeticContext, + factor FixedPointValue, + divisor FixedPointValue, + rounding fix.RoundingMode, +) NumberValue { + f, ok := factor.(UFix128Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: factor.StaticType(context), + }) + } + + d, ok := divisor.(UFix128Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: divisor.StaticType(context), + }) + } + + valueGetter := func() fix.UFix128 { + result, err := fix.UFix128(v).FMD( + fix.UFix128(f), + fix.UFix128(d), + rounding, + ) + handleFixedpointError(err) + return result + } + + return NewUFix128Value(context, valueGetter) +} + func (v UFix128Value) Pow(context NumberValueArithmeticContext, other Fix128Value) NumberValue { valueGetter := func() fix.UFix128 { result, err := fix.UFix128(v).Pow(fix.Fix128(other)) diff --git a/interpreter/value_ufix64.go b/interpreter/value_ufix64.go index 6f040ee84..1701201a4 100644 --- a/interpreter/value_ufix64.go +++ b/interpreter/value_ufix64.go @@ -393,6 +393,42 @@ func (v UFix64Value) Mod(context NumberValueArithmeticContext, other NumberValue return UFix64Value{UFix64Value: result} } +func (v UFix64Value) MultiplyDivide( + context NumberValueArithmeticContext, + factor FixedPointValue, + divisor FixedPointValue, + rounding fix.RoundingMode, +) NumberValue { + f, ok := factor.(UFix64Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: factor.StaticType(context), + }) + } + + d, ok := divisor.(UFix64Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: divisor.StaticType(context), + }) + } + + valueGetter := func() uint64 { + a := fix.UFix64(uint64(v.UFix64Value)) + b := fix.UFix64(uint64(f.UFix64Value)) + c := fix.UFix64(uint64(d.UFix64Value)) + result, err := a.FMD(b, c, rounding) + handleFixedpointError(err) + return uint64(result) + } + + return NewUFix64Value(context, valueGetter) +} + func (v UFix64Value) Pow(context NumberValueArithmeticContext, other Fix64Value) NumberValue { valueGetter := func() uint64 { a := fix.UFix64(uint64(v.UFix64Value)) diff --git a/runtime/rounding_rule_test.go b/runtime/rounding_rule_test.go index c14d8876d..f264d21d1 100644 --- a/runtime/rounding_rule_test.go +++ b/runtime/rounding_rule_test.go @@ -178,17 +178,17 @@ func TestRuntimeFix64ConversionWithRoundingRuleArgument(t *testing.T) { t.Parallel() - runtime := NewTestRuntime() - runtimeInterface := &TestRuntimeInterface{ - OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { - return json.Decode(nil, b) - }, - } - nextScriptLocation := NewScriptLocationGenerator() - t.Run("Fix128 to Fix64 with rounding", func(t *testing.T) { t.Parallel() + runtime := NewTestRuntime() + runtimeInterface := &TestRuntimeInterface{ + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + nextScriptLocation := NewScriptLocationGenerator() + const script = ` access(all) fun main(rule: RoundingRule): Fix64 { let x: Fix128 = 1.000000005000000000000000 @@ -216,6 +216,14 @@ func TestRuntimeFix64ConversionWithRoundingRuleArgument(t *testing.T) { t.Run("UFix128 to UFix64 with rounding", func(t *testing.T) { t.Parallel() + runtime := NewTestRuntime() + runtimeInterface := &TestRuntimeInterface{ + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + nextScriptLocation := NewScriptLocationGenerator() + const script = ` access(all) fun main(rule: RoundingRule): UFix64 { let x: UFix128 = 1.000000005000000000000000 diff --git a/sema/fixedpoint_test.go b/sema/fixedpoint_test.go index 205b90bab..7ac560b82 100644 --- a/sema/fixedpoint_test.go +++ b/sema/fixedpoint_test.go @@ -27,11 +27,111 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/cadence/common" "github.com/onflow/cadence/format" "github.com/onflow/cadence/sema" + "github.com/onflow/cadence/stdlib" . "github.com/onflow/cadence/test_utils/sema_utils" ) +func TestCheckFixedPointMultiplyDivide(t *testing.T) { + + t.Parallel() + + baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation) + for _, value := range stdlib.InterpreterDefaultScriptStandardLibraryValues(nil) { + baseValueActivation.DeclareValue(value) + } + + parseAndCheckWithRoundingRule := func(t *testing.T, code string) (*sema.Checker, error) { + return ParseAndCheckWithOptions(t, + code, + ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + } + + for _, fixedPointType := range common.Concat( + sema.AllSignedFixedPointTypes, + sema.AllUnsignedFixedPointTypes, + ) { + + t.Run(fixedPointType.String(), func(t *testing.T) { + + t.Parallel() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + + checker, err := parseAndCheckWithRoundingRule(t, + fmt.Sprintf( + ` + let a: %[1]s = 2.0 + let b: %[1]s = 3.0 + let c: %[1]s = 1.0 + let result = a.multiplyDivide(b, c, rounding: RoundingRule.towardZero) + `, + fixedPointType, + ), + ) + require.NoError(t, err) + + resultType := RequireGlobalValue(t, checker.Elaboration, "result") + assert.Equal(t, fixedPointType, resultType) + }) + + t.Run("invalid, wrong factor type", func(t *testing.T) { + t.Parallel() + + _, err := parseAndCheckWithRoundingRule(t, + fmt.Sprintf( + ` + let a: %[1]s = 2.0 + let b: Int = 3 + let c: %[1]s = 1.0 + let result = a.multiplyDivide(b, c, rounding: RoundingRule.towardZero) + `, + fixedPointType, + ), + ) + require.Error(t, err) + }) + + t.Run("valid, without rounding", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, + fmt.Sprintf( + ` + let a: %[1]s = 2.0 + let b: %[1]s = 3.0 + let c: %[1]s = 1.0 + let result = a.multiplyDivide(b, c) + `, + fixedPointType, + ), + ) + require.NoError(t, err) + }) + }) + } + + t.Run("not available on integer types", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let a: Int = 2 + let result = a.multiplyDivide(3, 1, rounding: RoundingRule.towardZero) + `) + require.Error(t, err) + }) +} + func TestCheckFixedPointPow(t *testing.T) { t.Parallel() diff --git a/sema/type.go b/sema/type.go index 3aa39a88c..3666e8122 100644 --- a/sema/type.go +++ b/sema/type.go @@ -1439,6 +1439,63 @@ func addFixedPointPowFunction( } } +const FixedPointNumericTypeMultiplyDivideFunctionName = "multiplyDivide" +const fixedPointNumericTypeMultiplyDivideFunctionDocString = ` +Returns self * factor / divisor, without intermediate rounding +` + +var FixedPointMultiplyDivideFunctionTypes = map[Type]*FunctionType{} + +func registerFixedPointMultiplyDivideFunction(t *FixedPointNumericType) { + FixedPointMultiplyDivideFunctionTypes[t] = &FunctionType{ + Purity: FunctionPurityView, + Parameters: []Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "factor", + TypeAnnotation: NewTypeAnnotation(t), + }, + { + Label: ArgumentLabelNotRequired, + Identifier: "divisor", + TypeAnnotation: NewTypeAnnotation(t), + }, + { + Label: "rounding", + Identifier: "rounding", + TypeAnnotation: RoundingRuleTypeAnnotation, + }, + }, + Arity: &Arity{Min: 2, Max: 3}, + ReturnTypeAnnotation: NewTypeAnnotation(t), + } +} + +func addFixedPointMultiplyDivideFunction( + t *FixedPointNumericType, + members map[string]MemberResolver, +) { + functionType := FixedPointMultiplyDivideFunctionTypes[t] + + members[FixedPointNumericTypeMultiplyDivideFunctionName] = MemberResolver{ + Kind: common.DeclarationKindFunction, + Resolve: func( + memoryGauge common.MemoryGauge, + _ string, + _ ast.HasPosition, + _ func(error), + ) *Member { + return NewPublicFunctionMember( + memoryGauge, + t, + FixedPointNumericTypeMultiplyDivideFunctionName, + functionType, + fixedPointNumericTypeMultiplyDivideFunctionDocString, + ) + }, + } +} + // NumericType represent all the types in the integer range // and non-fractional ranged types. type NumericType struct { @@ -1694,9 +1751,11 @@ var _ FractionalRangedType = &FixedPointNumericType{} var _ SaturatingArithmeticType = &FixedPointNumericType{} func NewFixedPointNumericType(typeName string) *FixedPointNumericType { - return &FixedPointNumericType{ + t := &FixedPointNumericType{ name: typeName, } + registerFixedPointMultiplyDivideFunction(t) + return t } func (t *FixedPointNumericType) Tag() TypeTag { @@ -1885,6 +1944,7 @@ func (t *FixedPointNumericType) GetMembers() map[string]MemberResolver { if _, ok := FixedPointPowFunctionTypes[t]; ok { addFixedPointPowFunction(t, computedMembers) } + addFixedPointMultiplyDivideFunction(t, computedMembers) computedMembers = withBuiltinMembers(t, computedMembers) t.memberResolvers.Store(&computedMembers) return computedMembers