diff --git a/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java index 385023da7..fb46d4747 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import dev.cel.checker.CelCheckerBuilder; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelIssue; @@ -316,27 +317,38 @@ public static ImmutableList genRange(long end) { return builder.build(); } + private static class RuntimeEqualityObjectWrapper { + private final Object object; + private final int hashCode; + private final RuntimeEquality runtimeEquality; + + RuntimeEqualityObjectWrapper(Object object, RuntimeEquality runtimeEquality) { + this.object = object; + this.runtimeEquality = runtimeEquality; + this.hashCode = runtimeEquality.hashCode(object); + } + + @Override + public int hashCode() { + return hashCode; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof RuntimeEqualityObjectWrapper)) { + return false; + } + return runtimeEquality.objectEquals(object, ((RuntimeEqualityObjectWrapper) obj).object); + } + } + private static ImmutableList distinct( Collection list, RuntimeEquality runtimeEquality) { - // TODO Optimize this method, which currently has the O(N^2) complexity. int size = list.size(); ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(size); - List theList; - if (list instanceof List) { - theList = (List) list; - } else { - theList = ImmutableList.copyOf(list); - } - for (int i = 0; i < size; i++) { - Object element = theList.get(i); - boolean found = false; - for (int j = 0; j < i; j++) { - if (runtimeEquality.objectEquals(element, theList.get(j))) { - found = true; - break; - } - } - if (!found) { + Set distinctValues = Sets.newHashSetWithExpectedSize(size); + for (Object element : list) { + if (distinctValues.add(new RuntimeEqualityObjectWrapper(element, runtimeEquality))) { builder.add(element); } } diff --git a/runtime/src/main/java/dev/cel/runtime/RuntimeEquality.java b/runtime/src/main/java/dev/cel/runtime/RuntimeEquality.java index d0b2fcd6b..912462e00 100644 --- a/runtime/src/main/java/dev/cel/runtime/RuntimeEquality.java +++ b/runtime/src/main/java/dev/cel/runtime/RuntimeEquality.java @@ -223,6 +223,41 @@ public boolean objectEquals(Object x, Object y) { return Objects.equals(x, y); } + /** + * Returns the hash code consistent with the {@link #objectEquals(Object, Object)} method. For + * example, {@code hashCode(1) == hashCode(1.0)} since {@code objectEquals(1, 1.0)} is true. + */ + public int hashCode(Object object) { + if (object == null) { + return 0; + } + + if (celOptions.disableCelStandardEquality()) { + return Objects.hashCode(object); + } + + object = runtimeHelpers.adaptValue(object); + if (object instanceof Number) { + return Double.hashCode(((Number) object).doubleValue()); + } + if (object instanceof Iterable) { + int h = 1; + Iterable iter = (Iterable) object; + for (Object elem : iter) { + h = h * 31 + hashCode(elem); + } + return h; + } + if (object instanceof Map) { + int h = 0; + for (Map.Entry entry : ((Map) object).entrySet()) { + h += hashCode(entry.getKey()) ^ hashCode(entry.getValue()); + } + return h; + } + return Objects.hashCode(object); + } + private static Optional doubleToUnsignedLossless(Number v) { Optional conv = RuntimeHelpers.doubleToUnsignedChecked(v.doubleValue()); return conv.map(ul -> ul.longValue() == v.doubleValue() ? ul : null); diff --git a/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java b/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java index d942295ba..00e55873c 100644 --- a/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java +++ b/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java @@ -14,8 +14,12 @@ package dev.cel.runtime; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.UnsignedLong; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import dev.cel.common.CelOptions; import dev.cel.expr.conformance.proto2.TestAllTypes; @@ -25,6 +29,29 @@ @RunWith(TestParameterInjector.class) public final class RuntimeEqualityTest { + @Test + public void objectEquals_and_hashCode() { + RuntimeEquality runtimeEquality = + RuntimeEquality.create(RuntimeHelpers.create(), CelOptions.DEFAULT); + assertEqualityAndHashCode(runtimeEquality, 1, 1); + assertEqualityAndHashCode(runtimeEquality, 2, 2L); + assertEqualityAndHashCode(runtimeEquality, 3, 3.0); + assertEqualityAndHashCode(runtimeEquality, 4, UnsignedLong.valueOf(4)); + assertEqualityAndHashCode( + runtimeEquality, + ImmutableList.of(1, 2, 3), + ImmutableList.of(1.0, 2L, UnsignedLong.valueOf(3))); + assertEqualityAndHashCode( + runtimeEquality, + ImmutableMap.of("a", 1, "b", 2), + ImmutableMap.of("a", 1L, "b", UnsignedLong.valueOf(2))); + } + + private void assertEqualityAndHashCode(RuntimeEquality runtimeEquality, Object obj1, Object obj2) { + assertThat(runtimeEquality.objectEquals(obj1, obj2)).isTrue(); + assertThat(runtimeEquality.hashCode(obj1)).isEqualTo(runtimeEquality.hashCode(obj2)); + } + @Test public void objectEquals_messageLite_throws() { RuntimeEquality runtimeEquality =