Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,15 @@ public boolean equals(Object obj) {
return name.equals(other.name);
}
}

public boolean isInternal() {
return name.startsWith("#");
}

public int getCounter() {
if (!isInternal())
throw new IllegalStateException("Cannot get counter of non-internal variable");
int lastUnderscore = name.lastIndexOf('_');
return Integer.parseInt(name.substring(lastUnderscore + 1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;

public class ConstantFolding {
public class ExpressionFolding {

/**
* Performs constant folding on a derivation node by evaluating nodes with constant values. Returns a new derivation
* node representing the folding steps taken
* Performs expression folding on a derivation node by evaluating nodes when possible. Returns a new derivation node
* representing the folding steps taken
*/
public static ValDerivationNode fold(ValDerivationNode node) {
Expression exp = node.getValue();
Expand All @@ -35,7 +35,7 @@ public static ValDerivationNode fold(ValDerivationNode node) {
}

/**
* Folds a binary expression node if both children are constant values (e.g. 1 + 2 => 3)
* Folds a binary expression node (e.g. 1 + 2 => 3)
*/
private static ValDerivationNode foldBinary(ValDerivationNode node) {
BinaryExpression binExp = (BinaryExpression) node.getValue();
Expand Down Expand Up @@ -148,7 +148,7 @@ else if (left instanceof LiteralBoolean && right instanceof LiteralBoolean) {
}

/**
* Folds a unary expression node if the child (operand) is a constant value (e.g. !true => false)
* Folds a unary expression node (e.g. !true => false)
*/
private static ValDerivationNode foldUnary(ValDerivationNode node) {
UnaryExpression unaryExp = (UnaryExpression) node.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ public static ValDerivationNode simplify(Expression exp) {
*/
private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current, Expression prevExp) {
// apply propagation and folding
ValDerivationNode prop = ConstantPropagation.propagate(prevExp, current);
ValDerivationNode fold = ConstantFolding.fold(prop);
ValDerivationNode prop = VariablePropagation.propagate(prevExp, current);
ValDerivationNode fold = ExpressionFolding.fold(prop);
ValDerivationNode simplified = simplifyValDerivationNode(fold);
Expression currExp = simplified.getValue();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
import java.util.HashMap;
import java.util.Map;

public class ConstantPropagation {
public class VariablePropagation {

/**
* Performs constant propagation on an expression, by substituting variables with their constant values. Uses the
* VariableResolver to extract variable equalities from the expression first. Returns a derivation node representing
* the propagation steps taken.
* Performs constant and variable propagation on an expression, by substituting variables. Uses the VariableResolver
* to extract variable equalities from the expression first. Returns a derivation node representing the propagation
* steps taken.
*/
public static ValDerivationNode propagate(Expression exp, ValDerivationNode previousOrigin) {
Map<String, Expression> substitutions = VariableResolver.resolve(exp);
Expand All @@ -32,7 +32,7 @@ public static ValDerivationNode propagate(Expression exp, ValDerivationNode prev
}

/**
* Recursively performs constant propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2)
* Recursively performs propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2)
*/
private static ValDerivationNode propagateRecursive(Expression exp, Map<String, Expression> subs,
Map<String, DerivationNode> varOrigins) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ private static void resolveRecursive(Expression exp, Map<String, Expression> map
map.put(var.getName(), right.clone());
} else if (right instanceof Var var && left.isLiteral()) {
map.put(var.getName(), left.clone());
} else if (left instanceof Var leftVar && right instanceof Var rightVar) {
// to substitute internal variable with user-facing variable
if (leftVar.isInternal() && !rightVar.isInternal()) {
map.put(leftVar.getName(), right.clone());
} else if (rightVar.isInternal() && !leftVar.isInternal()) {
map.put(rightVar.getName(), left.clone());
} else if (leftVar.isInternal() && rightVar.isInternal()) {
// to substitute the lower-counter variable with the higher-counter one
boolean isLeftCounterLower = leftVar.getCounter() <= rightVar.getCounter();
Var lowerVar = isLeftCounterLower ? leftVar : rightVar;
Var higherVar = isLeftCounterLower ? rightVar : leftVar;
map.putIfAbsent(lowerVar.getName(), higherVar.clone());
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import java.util.stream.Stream;
import liquidjava.api.CommandLineLauncher;
import liquidjava.diagnostics.Diagnostics;
import liquidjava.diagnostics.errors.*;

import liquidjava.diagnostics.errors.LJError;
import org.junit.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,196 @@ void testShouldUnwrapNestedBooleanInEquality() {
"Boolean in equality should be unwrapped to show the computed comparison");
}

@Test
void testVarToVarPropagationWithInternalVariable() {
// Given: #x_0 == a && #x_0 > 5
// Expected: a > 5 (internal #x_0 substituted with user-facing a)

Expression varX0 = new Var("#x_0");
Expression varA = new Var("a");
Expression x0EqualsA = new BinaryExpression(varX0, "==", varA);
Expression x0Greater5 = new BinaryExpression(varX0, ">", new LiteralInt(5));
Expression fullExpression = new BinaryExpression(x0EqualsA, "&&", x0Greater5);

// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

// Then
assertNotNull(result, "Result should not be null");
assertEquals("a > 5", result.getValue().toString(),
"Internal variable #x_0 should be substituted with user-facing variable a");
}

@Test
void testVarToVarInternalToInternal() {
// Given: #a_1 == #b_2 && #b_2 == 5 && x == #a_1 + 1
// Expected: x == 5 + 1 = x == 6

Expression varA = new Var("#a_1");
Expression varB = new Var("#b_2");
Expression varX = new Var("x");
Expression five = new LiteralInt(5);
Expression aEqualsB = new BinaryExpression(varA, "==", varB);
Expression bEquals5 = new BinaryExpression(varB, "==", five);
Expression aPlus1 = new BinaryExpression(varA, "+", new LiteralInt(1));
Expression xEqualsAPlus1 = new BinaryExpression(varX, "==", aPlus1);
Expression firstAnd = new BinaryExpression(aEqualsB, "&&", bEquals5);
Expression fullExpression = new BinaryExpression(firstAnd, "&&", xEqualsAPlus1);

// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

// Then
assertNotNull(result, "Result should not be null");
assertEquals("x == 6", result.getValue().toString(),
"#a should resolve through #b to 5 across passes, then x == 5 + 1 = x == 6");
}

@Test
void testVarToVarDoesNotAffectUserFacingVariables() {
// Given: x == y && x > 5
// Expected: x == y && x > 5 (user-facing var-to-var should not be propagated)

Expression varX = new Var("x");
Expression varY = new Var("y");
Expression xEqualsY = new BinaryExpression(varX, "==", varY);
Expression xGreater5 = new BinaryExpression(varX, ">", new LiteralInt(5));
Expression fullExpression = new BinaryExpression(xEqualsY, "&&", xGreater5);

// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

// Then
assertNotNull(result, "Result should not be null");
assertEquals("x == y && x > 5", result.getValue().toString(),
"User-facing variable equalities should not trigger var-to-var propagation");
}

@Test
void testVarToVarRemovesRedundantEquality() {
// Given: #ret_1 == #b_0 - 100 && #b_0 == b && b >= -128 && b <= 127
// Expected: #ret_1 == b - 100 && b >= -128 && b <= 127 (#b_0 replaced with b, #b_0 == b removed)

Expression ret1 = new Var("#ret_1");
Expression b0 = new Var("#b_0");
Expression b = new Var("b");
Expression ret1EqB0Minus100 = new BinaryExpression(ret1, "==",
new BinaryExpression(b0, "-", new LiteralInt(100)));
Expression b0EqB = new BinaryExpression(b0, "==", b);
Expression bGeMinus128 = new BinaryExpression(b, ">=", new UnaryExpression("-", new LiteralInt(128)));
Expression bLe127 = new BinaryExpression(b, "<=", new LiteralInt(127));
Expression and1 = new BinaryExpression(ret1EqB0Minus100, "&&", b0EqB);
Expression and2 = new BinaryExpression(bGeMinus128, "&&", bLe127);
Expression fullExpression = new BinaryExpression(and1, "&&", and2);

// When
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

// Then
assertNotNull(result, "Result should not be null");
assertEquals("#ret_1 == b - 100 && b >= -128 && b <= 127", result.getValue().toString(),
"Internal variable #b_0 should be replaced with b and redundant equality removed");
assertNotNull(result.getOrigin(), "Origin should be present showing the var-to-var derivation");
}

@Test
void testInternalToInternalReducesRedundantVariable() {
// Given: #a_3 == #b_7 && #a_3 > 5
// Expected: #b_7 > 5 (#a_3 has lower counter, so #a_3 -> #b_7)

Expression a3 = new Var("#a_3");
Expression b7 = new Var("#b_7");
Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7);
Expression a3Greater5 = new BinaryExpression(a3, ">", new LiteralInt(5));
Expression fullExpression = new BinaryExpression(a3EqualsB7, "&&", a3Greater5);

ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

assertNotNull(result);
assertEquals("#b_7 > 5", result.getValue().toString(),
"#a_3 (lower counter) should be substituted with #b_7 (higher counter)");
}

@Test
void testInternalToInternalChainWithUserFacingVariableUserFacingFirst() {
// Given: #b_7 == x && #a_3 == #b_7 && x > 0
// Expected: x > 0 (#b_7 -> x (user-facing); #a_3 has lower counter so #a_3 -> #b_7)

Expression a3 = new Var("#a_3");
Expression b7 = new Var("#b_7");
Expression x = new Var("x");
Expression b7EqualsX = new BinaryExpression(b7, "==", x);
Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7);
Expression xGreater0 = new BinaryExpression(x, ">", new LiteralInt(0));
Expression and1 = new BinaryExpression(b7EqualsX, "&&", a3EqualsB7);
Expression fullExpression = new BinaryExpression(and1, "&&", xGreater0);

ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

assertNotNull(result);
assertEquals("x > 0", result.getValue().toString(),
"Both internal variables should be eliminated via chain resolution");
}

@Test
void testInternalToInternalChainWithUserFacingVariableInternalFirst() {
// Given: #a_3 == #b_7 && #b_7 == x && x > 0
// Expected: x > 0 (#a_3 has lower counter so #a_3 -> #b_7; #b_7 -> x (user-facing) overwrites)

Expression a3 = new Var("#a_3");
Expression b7 = new Var("#b_7");
Expression x = new Var("x");
Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7);
Expression b7EqualsX = new BinaryExpression(b7, "==", x);
Expression xGreater0 = new BinaryExpression(x, ">", new LiteralInt(0));
Expression and1 = new BinaryExpression(a3EqualsB7, "&&", b7EqualsX);
Expression fullExpression = new BinaryExpression(and1, "&&", xGreater0);

ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

assertNotNull(result);
assertEquals("x > 0", result.getValue().toString(),
"Both internal variables should be eliminated via fixed-point iteration");
}

@Test
void testInternalToInternalBothResolvingToLiteral() {
// Given: #a_3 == #b_7 && #b_7 == 5
// Expected: 5 == 5 && 5 == 5 (#a_3 has lower counter so #a_3 -> #b_7; #b_7 -> 5)

Expression a3 = new Var("#a_3");
Expression b7 = new Var("#b_7");
Expression five = new LiteralInt(5);
Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7);
Expression b7Equals5 = new BinaryExpression(b7, "==", five);
Expression fullExpression = new BinaryExpression(a3EqualsB7, "&&", b7Equals5);

ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

assertNotNull(result);
assertEquals("5 == 5 && 5 == 5", result.getValue().toString(),
"#a_3 -> #b_7 -> 5 and #b_7 -> 5; both equalities collapse to 5 == 5");
}

@Test
void testInternalToInternalNoFurtherResolution() {
// Given: #a_3 == #b_7 && #b_7 + 1 > 0
// Expected: #b_7 + 1 > 0 (#a_3 has lower counter, so #a_3 -> #b_7)

Expression a3 = new Var("#a_3");
Expression b7 = new Var("#b_7");
Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7);
Expression b7Plus1 = new BinaryExpression(b7, "+", new LiteralInt(1));
Expression b7Plus1Greater0 = new BinaryExpression(b7Plus1, ">", new LiteralInt(0));
Expression fullExpression = new BinaryExpression(a3EqualsB7, "&&", b7Plus1Greater0);

ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

assertNotNull(result);
assertEquals("#b_7 + 1 > 0", result.getValue().toString(),
"#a_3 (lower counter) replaced by #b_7 (higher counter); equality collapses to trivial");
}

/**
* Helper method to compare two derivation nodes recursively
*/
Expand Down
Loading