Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import dev.cel.checker.CelCheckerBuilder;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelIssue;
Expand Down Expand Up @@ -316,27 +317,38 @@ public static ImmutableList<Long> genRange(long end) {
return builder.build();
}

private static class RuntimeEqualityObjectWrapper {
private final Object object;
private final int hashCode;
private final RuntimeEquality runtimeEquality;

RuntimeEqualityObjectWrapper(Object object, RuntimeEquality runtimeEquality) {
this.object = object;
this.runtimeEquality = runtimeEquality;
this.hashCode = runtimeEquality.hashCode(object);
}

@Override
public int hashCode() {
return hashCode;
}

@Override
public boolean equals(Object obj) {
if (!(obj instanceof RuntimeEqualityObjectWrapper)) {
return false;
}
return runtimeEquality.objectEquals(object, ((RuntimeEqualityObjectWrapper) obj).object);
}
}

private static ImmutableList<Object> distinct(
Collection<Object> list, RuntimeEquality runtimeEquality) {
// TODO Optimize this method, which currently has the O(N^2) complexity.
int size = list.size();
ImmutableList.Builder<Object> builder = ImmutableList.builderWithExpectedSize(size);
List<Object> theList;
if (list instanceof List) {
theList = (List<Object>) list;
} else {
theList = ImmutableList.copyOf(list);
}
for (int i = 0; i < size; i++) {
Object element = theList.get(i);
boolean found = false;
for (int j = 0; j < i; j++) {
if (runtimeEquality.objectEquals(element, theList.get(j))) {
found = true;
break;
}
}
if (!found) {
Set<RuntimeEqualityObjectWrapper> distinctValues = Sets.newHashSetWithExpectedSize(size);
for (Object element : list) {
if (distinctValues.add(new RuntimeEqualityObjectWrapper(element, runtimeEquality))) {
builder.add(element);
}
}
Expand Down
35 changes: 35 additions & 0 deletions runtime/src/main/java/dev/cel/runtime/RuntimeEquality.java
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,41 @@ public boolean objectEquals(Object x, Object y) {
return Objects.equals(x, y);
}

/**
* Returns the hash code consistent with the {@link #objectEquals(Object, Object)} method. For
* example, {@code hashCode(1) == hashCode(1.0)} since {@code objectEquals(1, 1.0)} is true.
*/
public int hashCode(Object object) {
if (object == null) {
return 0;
}

if (celOptions.disableCelStandardEquality()) {
return Objects.hashCode(object);
}

object = runtimeHelpers.adaptValue(object);
if (object instanceof Number) {
return Double.hashCode(((Number) object).doubleValue());
}
if (object instanceof Iterable) {
int h = 1;
Iterable<?> iter = (Iterable<?>) object;
for (Object elem : iter) {
h = h * 31 + hashCode(elem);
}
return h;
}
if (object instanceof Map) {
int h = 0;
for (Map.Entry<?, ?> entry : ((Map<?, ?>) object).entrySet()) {
h += hashCode(entry.getKey()) ^ hashCode(entry.getValue());
}
return h;
}
return Objects.hashCode(object);
}

private static Optional<UnsignedLong> doubleToUnsignedLossless(Number v) {
Optional<UnsignedLong> conv = RuntimeHelpers.doubleToUnsignedChecked(v.doubleValue());
return conv.map(ul -> ul.longValue() == v.doubleValue() ? ul : null);
Expand Down
27 changes: 27 additions & 0 deletions runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@

package dev.cel.runtime;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.UnsignedLong;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import dev.cel.common.CelOptions;
import dev.cel.expr.conformance.proto2.TestAllTypes;
Expand All @@ -25,6 +29,29 @@
@RunWith(TestParameterInjector.class)
public final class RuntimeEqualityTest {

@Test
public void objectEquals_and_hashCode() {
RuntimeEquality runtimeEquality =
RuntimeEquality.create(RuntimeHelpers.create(), CelOptions.DEFAULT);
assertEqualityAndHashCode(runtimeEquality, 1, 1);
assertEqualityAndHashCode(runtimeEquality, 2, 2L);
assertEqualityAndHashCode(runtimeEquality, 3, 3.0);
assertEqualityAndHashCode(runtimeEquality, 4, UnsignedLong.valueOf(4));
assertEqualityAndHashCode(
runtimeEquality,
ImmutableList.of(1, 2, 3),
ImmutableList.of(1.0, 2L, UnsignedLong.valueOf(3)));
assertEqualityAndHashCode(
runtimeEquality,
ImmutableMap.of("a", 1, "b", 2),
ImmutableMap.of("a", 1L, "b", UnsignedLong.valueOf(2)));
}

private void assertEqualityAndHashCode(RuntimeEquality runtimeEquality, Object obj1, Object obj2) {
assertThat(runtimeEquality.objectEquals(obj1, obj2)).isTrue();
assertThat(runtimeEquality.hashCode(obj1)).isEqualTo(runtimeEquality.hashCode(obj2));
}

@Test
public void objectEquals_messageLite_throws() {
RuntimeEquality runtimeEquality =
Expand Down
Loading