Skip to content

Commit 7f0e0af

Browse files
authored
Internal Variable Substitution (#181)
1 parent 1272e50 commit 7f0e0af

File tree

7 files changed

+227
-13
lines changed

7 files changed

+227
-13
lines changed

liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Var.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,15 @@ public boolean equals(Object obj) {
7070
return name.equals(other.name);
7171
}
7272
}
73+
74+
public boolean isInternal() {
75+
return name.startsWith("#");
76+
}
77+
78+
public int getCounter() {
79+
if (!isInternal())
80+
throw new IllegalStateException("Cannot get counter of non-internal variable");
81+
int lastUnderscore = name.lastIndexOf('_');
82+
return Integer.parseInt(name.substring(lastUnderscore + 1));
83+
}
7384
}

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantFolding.java renamed to liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionFolding.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
1313
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
1414

15-
public class ConstantFolding {
15+
public class ExpressionFolding {
1616

1717
/**
18-
* Performs constant folding on a derivation node by evaluating nodes with constant values. Returns a new derivation
19-
* node representing the folding steps taken
18+
* Performs expression folding on a derivation node by evaluating nodes when possible. Returns a new derivation node
19+
* representing the folding steps taken
2020
*/
2121
public static ValDerivationNode fold(ValDerivationNode node) {
2222
Expression exp = node.getValue();
@@ -35,7 +35,7 @@ public static ValDerivationNode fold(ValDerivationNode node) {
3535
}
3636

3737
/**
38-
* Folds a binary expression node if both children are constant values (e.g. 1 + 2 => 3)
38+
* Folds a binary expression node (e.g. 1 + 2 => 3)
3939
*/
4040
private static ValDerivationNode foldBinary(ValDerivationNode node) {
4141
BinaryExpression binExp = (BinaryExpression) node.getValue();
@@ -148,7 +148,7 @@ else if (left instanceof LiteralBoolean && right instanceof LiteralBoolean) {
148148
}
149149

150150
/**
151-
* Folds a unary expression node if the child (operand) is a constant value (e.g. !true => false)
151+
* Folds a unary expression node (e.g. !true => false)
152152
*/
153153
private static ValDerivationNode foldUnary(ValDerivationNode node) {
154154
UnaryExpression unaryExp = (UnaryExpression) node.getValue();

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ public static ValDerivationNode simplify(Expression exp) {
2727
*/
2828
private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current, Expression prevExp) {
2929
// apply propagation and folding
30-
ValDerivationNode prop = ConstantPropagation.propagate(prevExp, current);
31-
ValDerivationNode fold = ConstantFolding.fold(prop);
30+
ValDerivationNode prop = VariablePropagation.propagate(prevExp, current);
31+
ValDerivationNode fold = ExpressionFolding.fold(prop);
3232
ValDerivationNode simplified = simplifyValDerivationNode(fold);
3333
Expression currExp = simplified.getValue();
3434

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java renamed to liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
import java.util.HashMap;
1414
import java.util.Map;
1515

16-
public class ConstantPropagation {
16+
public class VariablePropagation {
1717

1818
/**
19-
* Performs constant propagation on an expression, by substituting variables with their constant values. Uses the
20-
* VariableResolver to extract variable equalities from the expression first. Returns a derivation node representing
21-
* the propagation steps taken.
19+
* Performs constant and variable propagation on an expression, by substituting variables. Uses the VariableResolver
20+
* to extract variable equalities from the expression first. Returns a derivation node representing the propagation
21+
* steps taken.
2222
*/
2323
public static ValDerivationNode propagate(Expression exp, ValDerivationNode previousOrigin) {
2424
Map<String, Expression> substitutions = VariableResolver.resolve(exp);
@@ -32,7 +32,7 @@ public static ValDerivationNode propagate(Expression exp, ValDerivationNode prev
3232
}
3333

3434
/**
35-
* Recursively performs constant propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2)
35+
* Recursively performs propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2)
3636
*/
3737
private static ValDerivationNode propagateRecursive(Expression exp, Map<String, Expression> subs,
3838
Map<String, DerivationNode> varOrigins) {

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,19 @@ private static void resolveRecursive(Expression exp, Map<String, Expression> map
5252
map.put(var.getName(), right.clone());
5353
} else if (right instanceof Var var && left.isLiteral()) {
5454
map.put(var.getName(), left.clone());
55+
} else if (left instanceof Var leftVar && right instanceof Var rightVar) {
56+
// to substitute internal variable with user-facing variable
57+
if (leftVar.isInternal() && !rightVar.isInternal()) {
58+
map.put(leftVar.getName(), right.clone());
59+
} else if (rightVar.isInternal() && !leftVar.isInternal()) {
60+
map.put(rightVar.getName(), left.clone());
61+
} else if (leftVar.isInternal() && rightVar.isInternal()) {
62+
// to substitute the lower-counter variable with the higher-counter one
63+
boolean isLeftCounterLower = leftVar.getCounter() <= rightVar.getCounter();
64+
Var lowerVar = isLeftCounterLower ? leftVar : rightVar;
65+
Var higherVar = isLeftCounterLower ? rightVar : leftVar;
66+
map.putIfAbsent(lowerVar.getName(), higherVar.clone());
67+
}
5568
}
5669
}
5770
}

liquidjava-verifier/src/test/java/liquidjava/api/tests/TestExamples.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import java.util.stream.Stream;
1212
import liquidjava.api.CommandLineLauncher;
1313
import liquidjava.diagnostics.Diagnostics;
14+
import liquidjava.diagnostics.errors.*;
1415

15-
import liquidjava.diagnostics.errors.LJError;
1616
import org.junit.Test;
1717
import org.junit.jupiter.params.ParameterizedTest;
1818
import org.junit.jupiter.params.provider.MethodSource;

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,196 @@ void testShouldUnwrapNestedBooleanInEquality() {
687687
"Boolean in equality should be unwrapped to show the computed comparison");
688688
}
689689

690+
@Test
691+
void testVarToVarPropagationWithInternalVariable() {
692+
// Given: #x_0 == a && #x_0 > 5
693+
// Expected: a > 5 (internal #x_0 substituted with user-facing a)
694+
695+
Expression varX0 = new Var("#x_0");
696+
Expression varA = new Var("a");
697+
Expression x0EqualsA = new BinaryExpression(varX0, "==", varA);
698+
Expression x0Greater5 = new BinaryExpression(varX0, ">", new LiteralInt(5));
699+
Expression fullExpression = new BinaryExpression(x0EqualsA, "&&", x0Greater5);
700+
701+
// When
702+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
703+
704+
// Then
705+
assertNotNull(result, "Result should not be null");
706+
assertEquals("a > 5", result.getValue().toString(),
707+
"Internal variable #x_0 should be substituted with user-facing variable a");
708+
}
709+
710+
@Test
711+
void testVarToVarInternalToInternal() {
712+
// Given: #a_1 == #b_2 && #b_2 == 5 && x == #a_1 + 1
713+
// Expected: x == 5 + 1 = x == 6
714+
715+
Expression varA = new Var("#a_1");
716+
Expression varB = new Var("#b_2");
717+
Expression varX = new Var("x");
718+
Expression five = new LiteralInt(5);
719+
Expression aEqualsB = new BinaryExpression(varA, "==", varB);
720+
Expression bEquals5 = new BinaryExpression(varB, "==", five);
721+
Expression aPlus1 = new BinaryExpression(varA, "+", new LiteralInt(1));
722+
Expression xEqualsAPlus1 = new BinaryExpression(varX, "==", aPlus1);
723+
Expression firstAnd = new BinaryExpression(aEqualsB, "&&", bEquals5);
724+
Expression fullExpression = new BinaryExpression(firstAnd, "&&", xEqualsAPlus1);
725+
726+
// When
727+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
728+
729+
// Then
730+
assertNotNull(result, "Result should not be null");
731+
assertEquals("x == 6", result.getValue().toString(),
732+
"#a should resolve through #b to 5 across passes, then x == 5 + 1 = x == 6");
733+
}
734+
735+
@Test
736+
void testVarToVarDoesNotAffectUserFacingVariables() {
737+
// Given: x == y && x > 5
738+
// Expected: x == y && x > 5 (user-facing var-to-var should not be propagated)
739+
740+
Expression varX = new Var("x");
741+
Expression varY = new Var("y");
742+
Expression xEqualsY = new BinaryExpression(varX, "==", varY);
743+
Expression xGreater5 = new BinaryExpression(varX, ">", new LiteralInt(5));
744+
Expression fullExpression = new BinaryExpression(xEqualsY, "&&", xGreater5);
745+
746+
// When
747+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
748+
749+
// Then
750+
assertNotNull(result, "Result should not be null");
751+
assertEquals("x == y && x > 5", result.getValue().toString(),
752+
"User-facing variable equalities should not trigger var-to-var propagation");
753+
}
754+
755+
@Test
756+
void testVarToVarRemovesRedundantEquality() {
757+
// Given: #ret_1 == #b_0 - 100 && #b_0 == b && b >= -128 && b <= 127
758+
// Expected: #ret_1 == b - 100 && b >= -128 && b <= 127 (#b_0 replaced with b, #b_0 == b removed)
759+
760+
Expression ret1 = new Var("#ret_1");
761+
Expression b0 = new Var("#b_0");
762+
Expression b = new Var("b");
763+
Expression ret1EqB0Minus100 = new BinaryExpression(ret1, "==",
764+
new BinaryExpression(b0, "-", new LiteralInt(100)));
765+
Expression b0EqB = new BinaryExpression(b0, "==", b);
766+
Expression bGeMinus128 = new BinaryExpression(b, ">=", new UnaryExpression("-", new LiteralInt(128)));
767+
Expression bLe127 = new BinaryExpression(b, "<=", new LiteralInt(127));
768+
Expression and1 = new BinaryExpression(ret1EqB0Minus100, "&&", b0EqB);
769+
Expression and2 = new BinaryExpression(bGeMinus128, "&&", bLe127);
770+
Expression fullExpression = new BinaryExpression(and1, "&&", and2);
771+
772+
// When
773+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
774+
775+
// Then
776+
assertNotNull(result, "Result should not be null");
777+
assertEquals("#ret_1 == b - 100 && b >= -128 && b <= 127", result.getValue().toString(),
778+
"Internal variable #b_0 should be replaced with b and redundant equality removed");
779+
assertNotNull(result.getOrigin(), "Origin should be present showing the var-to-var derivation");
780+
}
781+
782+
@Test
783+
void testInternalToInternalReducesRedundantVariable() {
784+
// Given: #a_3 == #b_7 && #a_3 > 5
785+
// Expected: #b_7 > 5 (#a_3 has lower counter, so #a_3 -> #b_7)
786+
787+
Expression a3 = new Var("#a_3");
788+
Expression b7 = new Var("#b_7");
789+
Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7);
790+
Expression a3Greater5 = new BinaryExpression(a3, ">", new LiteralInt(5));
791+
Expression fullExpression = new BinaryExpression(a3EqualsB7, "&&", a3Greater5);
792+
793+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
794+
795+
assertNotNull(result);
796+
assertEquals("#b_7 > 5", result.getValue().toString(),
797+
"#a_3 (lower counter) should be substituted with #b_7 (higher counter)");
798+
}
799+
800+
@Test
801+
void testInternalToInternalChainWithUserFacingVariableUserFacingFirst() {
802+
// Given: #b_7 == x && #a_3 == #b_7 && x > 0
803+
// Expected: x > 0 (#b_7 -> x (user-facing); #a_3 has lower counter so #a_3 -> #b_7)
804+
805+
Expression a3 = new Var("#a_3");
806+
Expression b7 = new Var("#b_7");
807+
Expression x = new Var("x");
808+
Expression b7EqualsX = new BinaryExpression(b7, "==", x);
809+
Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7);
810+
Expression xGreater0 = new BinaryExpression(x, ">", new LiteralInt(0));
811+
Expression and1 = new BinaryExpression(b7EqualsX, "&&", a3EqualsB7);
812+
Expression fullExpression = new BinaryExpression(and1, "&&", xGreater0);
813+
814+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
815+
816+
assertNotNull(result);
817+
assertEquals("x > 0", result.getValue().toString(),
818+
"Both internal variables should be eliminated via chain resolution");
819+
}
820+
821+
@Test
822+
void testInternalToInternalChainWithUserFacingVariableInternalFirst() {
823+
// Given: #a_3 == #b_7 && #b_7 == x && x > 0
824+
// Expected: x > 0 (#a_3 has lower counter so #a_3 -> #b_7; #b_7 -> x (user-facing) overwrites)
825+
826+
Expression a3 = new Var("#a_3");
827+
Expression b7 = new Var("#b_7");
828+
Expression x = new Var("x");
829+
Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7);
830+
Expression b7EqualsX = new BinaryExpression(b7, "==", x);
831+
Expression xGreater0 = new BinaryExpression(x, ">", new LiteralInt(0));
832+
Expression and1 = new BinaryExpression(a3EqualsB7, "&&", b7EqualsX);
833+
Expression fullExpression = new BinaryExpression(and1, "&&", xGreater0);
834+
835+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
836+
837+
assertNotNull(result);
838+
assertEquals("x > 0", result.getValue().toString(),
839+
"Both internal variables should be eliminated via fixed-point iteration");
840+
}
841+
842+
@Test
843+
void testInternalToInternalBothResolvingToLiteral() {
844+
// Given: #a_3 == #b_7 && #b_7 == 5
845+
// Expected: 5 == 5 && 5 == 5 (#a_3 has lower counter so #a_3 -> #b_7; #b_7 -> 5)
846+
847+
Expression a3 = new Var("#a_3");
848+
Expression b7 = new Var("#b_7");
849+
Expression five = new LiteralInt(5);
850+
Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7);
851+
Expression b7Equals5 = new BinaryExpression(b7, "==", five);
852+
Expression fullExpression = new BinaryExpression(a3EqualsB7, "&&", b7Equals5);
853+
854+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
855+
856+
assertNotNull(result);
857+
assertEquals("5 == 5 && 5 == 5", result.getValue().toString(),
858+
"#a_3 -> #b_7 -> 5 and #b_7 -> 5; both equalities collapse to 5 == 5");
859+
}
860+
861+
@Test
862+
void testInternalToInternalNoFurtherResolution() {
863+
// Given: #a_3 == #b_7 && #b_7 + 1 > 0
864+
// Expected: #b_7 + 1 > 0 (#a_3 has lower counter, so #a_3 -> #b_7)
865+
866+
Expression a3 = new Var("#a_3");
867+
Expression b7 = new Var("#b_7");
868+
Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7);
869+
Expression b7Plus1 = new BinaryExpression(b7, "+", new LiteralInt(1));
870+
Expression b7Plus1Greater0 = new BinaryExpression(b7Plus1, ">", new LiteralInt(0));
871+
Expression fullExpression = new BinaryExpression(a3EqualsB7, "&&", b7Plus1Greater0);
872+
873+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
874+
875+
assertNotNull(result);
876+
assertEquals("#b_7 + 1 > 0", result.getValue().toString(),
877+
"#a_3 (lower counter) replaced by #b_7 (higher counter); equality collapses to trivial");
878+
}
879+
690880
/**
691881
* Helper method to compare two derivation nodes recursively
692882
*/

0 commit comments

Comments
 (0)