diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCArithmeticSimplification.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCArithmeticSimplification.java new file mode 100644 index 00000000..6b48459e --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCArithmeticSimplification.java @@ -0,0 +1,320 @@ +package liquidjava.rj_language.opt; + +import java.util.ArrayList; +import java.util.List; + +import liquidjava.processor.SimplifiedVCImplication; +import liquidjava.processor.VCImplication; +import liquidjava.rj_language.Predicate; +import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.GroupExpression; +import liquidjava.rj_language.ast.Ite; +import liquidjava.rj_language.ast.LiteralInt; +import liquidjava.rj_language.ast.LiteralReal; +import liquidjava.rj_language.ast.UnaryExpression; + +/** + * Simplifies VCImplication chains by applying arithmetic identities inside refinements + */ +public class VCArithmeticSimplification { + + /** + * Applies the first arithmetic simplification available in a VC chain + */ + public static VCImplication apply(VCImplication implication) { + if (implication == null) + return null; + + return apply(implication, List.of()); + } + + private static VCImplication apply(VCImplication implication, List nonZeroExpressions) { + if (implication == null) + return null; + + Expression expression = implication.getRefinement().getExpression(); + Expression simplified = simplify(expression, nonZeroExpressions); + if (!expression.equals(simplified)) { + VCImplication result = new SimplifiedVCImplication(implication, new Predicate(simplified), implication); + result.setNext(implication.getNext() == null ? null : implication.getNext().clone()); + return result; + } + + List nextNoneZeroExpressions = new ArrayList<>(nonZeroExpressions); + addNonZeroExpression(implication.getRefinement().getExpression(), nextNoneZeroExpressions); + + VCImplication next = apply(implication.getNext(), nextNoneZeroExpressions); + if (implication.getNext() == null || implication.getNext().equals(next)) + return implication; + + VCImplication result = implication.copyWithRefinement(implication.getRefinement().clone()); + result.setNext(next); + return result; + } + + /** + * Simplifies the first arithmetic identity found inside an expression + */ + private static Expression simplify(Expression expression, List nonZeroExpressions) { + if (expression instanceof BinaryExpression binary) + return simplifyBinary(binary, nonZeroExpressions); + if (expression instanceof UnaryExpression unary) + return simplifyUnary(unary, nonZeroExpressions); + if (expression instanceof Ite ite) + return simplifyIte(ite, nonZeroExpressions); + if (expression instanceof GroupExpression group) + return simplifyGroup(group, nonZeroExpressions); + return expression.clone(); + } + + /** + * Simplifies a binary expression by visiting operands before the current node + */ + private static Expression simplifyBinary(BinaryExpression binary, List nonZeroExpressions) { + Expression left = binary.getFirstOperand(); + Expression simplifiedLeft = simplify(left, nonZeroExpressions); + if (!left.equals(simplifiedLeft)) + return new BinaryExpression(simplifiedLeft, binary.getOperator(), binary.getSecondOperand().clone()); + + Expression right = binary.getSecondOperand(); + Expression simplifiedRight = simplify(right, nonZeroExpressions); + if (!right.equals(simplifiedRight)) + return new BinaryExpression(left.clone(), binary.getOperator(), simplifiedRight); + + Expression simplifiedBinary = simplifyLocalBinary(left, right, binary.getOperator(), nonZeroExpressions); + if (simplifiedBinary != null) + return simplifiedBinary; + + return new BinaryExpression(left.clone(), binary.getOperator(), right.clone()); + } + + /** + * Simplifies a unary expression by visiting its operand before the current node + */ + private static Expression simplifyUnary(UnaryExpression unary, List nonZeroExpressions) { + Expression operand = unary.getExpression(); + Expression simplifiedOperand = simplify(operand, nonZeroExpressions); + if (!operand.equals(simplifiedOperand)) + return new UnaryExpression(unary.getOp(), simplifiedOperand); + + // -(-x) -> x + if ("-".equals(unary.getOp()) && isNegation(operand)) + return negatedExpression(operand).clone(); + + return new UnaryExpression(unary.getOp(), operand.clone()); + } + + /** + * Simplifies a ternary expression by visiting condition, then branch, and else branch + */ + private static Expression simplifyIte(Ite ite, List nonZeroExpressions) { + Expression condition = ite.getCondition(); + Expression simplifiedCondition = simplify(condition, nonZeroExpressions); + if (!condition.equals(simplifiedCondition)) + return new Ite(simplifiedCondition, ite.getThen().clone(), ite.getElse().clone()); + + Expression thenExpression = ite.getThen(); + Expression simplifiedThen = simplify(thenExpression, nonZeroExpressions); + if (!thenExpression.equals(simplifiedThen)) + return new Ite(condition.clone(), simplifiedThen, ite.getElse().clone()); + + Expression elseExpression = ite.getElse(); + Expression simplifiedElse = simplify(elseExpression, nonZeroExpressions); + if (!elseExpression.equals(simplifiedElse)) + return new Ite(condition.clone(), thenExpression.clone(), simplifiedElse); + + return new Ite(condition.clone(), thenExpression.clone(), elseExpression.clone()); + } + + /** + * Simplifies an expression wrapped in parentheses while preserving the group node + */ + private static Expression simplifyGroup(GroupExpression group, List nonZeroExpressions) { + Expression expression = group.getExpression(); + Expression simplified = simplify(expression, nonZeroExpressions); + if (!expression.equals(simplified)) + return new GroupExpression(simplified); + return group.clone(); + } + + /** + * Dispatches a local binary arithmetic identity by operator + */ + private static Expression simplifyLocalBinary(Expression left, Expression right, String op, + List nonZeroExpressions) { + return switch (op) { + case "+" -> simplifyAddition(left, right); + case "-" -> simplifySubtraction(left, right); + case "*" -> simplifyMultiplication(left, right); + case "/" -> simplifyDivision(left, right, nonZeroExpressions); + case "%" -> simplifyModulo(left, right, nonZeroExpressions); + default -> null; + }; + } + + /** + * Applies addition identities involving zero and unary negation + */ + private static Expression simplifyAddition(Expression left, Expression right) { + // x + 0 -> x + if (isZero(right)) + return left.clone(); + // 0 + x -> x + if (isZero(left)) + return right.clone(); + // x + (-x) -> 0 + if (isNegation(right) && left.equals(negatedExpression(right))) + return new LiteralInt(0); + // (-x) + x -> 0 + if (isNegation(left) && negatedExpression(left).equals(right)) + return new LiteralInt(0); + // x + (-y) -> x - y + if (isNegation(right)) + return new BinaryExpression(left.clone(), "-", negatedExpression(right).clone()); + return null; + } + + /** + * Applies subtraction identities involving zero, same operands, and unary negation + */ + private static Expression simplifySubtraction(Expression left, Expression right) { + // x - 0 -> x + if (isZero(right)) + return left.clone(); + // 0 - x -> -x + if (isZero(left)) + return new UnaryExpression("-", right.clone()); + // x - x -> 0 + if (left.equals(right)) + return new LiteralInt(0); + // x - (-y) -> x + y + if (isNegation(right)) + return new BinaryExpression(left.clone(), "+", negatedExpression(right).clone()); + return null; + } + + /** + * Applies multiplication identities involving one and zero + */ + private static Expression simplifyMultiplication(Expression left, Expression right) { + // x * 1 -> x + if (isOne(right)) + return left.clone(); + // 1 * x -> x + if (isOne(left)) + return right.clone(); + // x * 0 -> 0 + if (isZero(right)) + return right.clone(); + // 0 * x -> 0 + if (isZero(left)) + return left.clone(); + return null; + } + + /** + * Applies division identities, using prior non-zero premises when needed + */ + private static Expression simplifyDivision(Expression left, Expression right, List nonZeroExpressions) { + // x / 1 -> x + if (isOne(right)) + return left.clone(); + // 0 / x -> 0 (x != 0) + if (isZero(left) && isNonZero(right, nonZeroExpressions)) + return left.clone(); + // x / x -> 1 (x != 0) + if (left.equals(right) && isNonZero(right, nonZeroExpressions)) + return new LiteralInt(1); + return null; + } + + /** + * Applies modulo identities, using prior non-zero premises when needed + */ + private static Expression simplifyModulo(Expression left, Expression right, List nonZeroExpressions) { + // x % 1 -> 0 + if (isOne(right)) + return new LiteralInt(0); + // x % x -> 0 (x != 0) + if (left.equals(right) && isNonZero(right, nonZeroExpressions)) + return new LiteralInt(0); + return null; + } + + /** + * Records direct non-zero premises from equalities and inequalities + */ + private static void addNonZeroExpression(Expression expression, List nonZeroExpressions) { + if (!(expression instanceof BinaryExpression binary)) + return; + + Expression left = binary.getFirstOperand(); + Expression right = binary.getSecondOperand(); + if ("!=".equals(binary.getOperator())) { + // x != 0 -> x is non-zero + if (isZero(right)) + nonZeroExpressions.add(left.clone()); + // 0 != x -> x is non-zero + if (isZero(left)) + nonZeroExpressions.add(right.clone()); + } else if ("==".equals(binary.getOperator())) { + // x == n && n != 0 -> x is non-zero + if (isNumericLiteral(right) && !isZero(right)) + nonZeroExpressions.add(left.clone()); + // n == x && n != 0 -> x is non-zero + if (isNumericLiteral(left) && !isZero(left)) + nonZeroExpressions.add(right.clone()); + } + } + + /** + * Checks whether a previous premise recorded an expression as non-zero + */ + private static boolean isNonZero(Expression expression, List nonZeroExpressions) { + return nonZeroExpressions.stream().anyMatch(e -> e.equals(expression)); + } + + /** + * Checks whether an expression is a numeric zero literal + */ + private static boolean isZero(Expression expression) { + if (expression instanceof LiteralInt literal) + return literal.getValue() == 0; + if (expression instanceof LiteralReal literal) + return literal.getValue() == 0.0; + return false; + } + + /** + * Checks whether an expression is a numeric literal + */ + private static boolean isNumericLiteral(Expression expression) { + return expression instanceof LiteralInt || expression instanceof LiteralReal; + } + + /** + * Checks whether an expression is a numeric one literal + */ + private static boolean isOne(Expression expression) { + if (expression instanceof LiteralInt literal) + return literal.getValue() == 1; + if (expression instanceof LiteralReal literal) + return literal.getValue() == 1.0; + return false; + } + + /** + * Checks whether an expression is unary negation + */ + private static boolean isNegation(Expression expression) { + return expression instanceof UnaryExpression unary && "-".equals(unary.getOp()); + } + + /** + * Returns the operand of a unary negation expression + */ + private static Expression negatedExpression(Expression expression) { + return ((UnaryExpression) expression).getExpression(); + } +} diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCFolding.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCFolding.java index e74b0e46..a8927f2b 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCFolding.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCFolding.java @@ -28,8 +28,7 @@ public static VCImplication apply(VCImplication implication) { Expression expression = implication.getRefinement().getExpression(); Expression folded = fold(expression); if (!expression.equals(folded)) { - VCImplication result = new SimplifiedVCImplication(implication, new Predicate(folded), - implication.getOrigin()); + VCImplication result = new SimplifiedVCImplication(implication, new Predicate(folded), implication); result.setNext(implication.getNext() == null ? null : implication.getNext().clone()); return result; } diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCSimplification.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCSimplification.java index f87b1081..6854b547 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCSimplification.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCSimplification.java @@ -1,5 +1,8 @@ package liquidjava.rj_language.opt; +import java.util.List; +import java.util.function.UnaryOperator; + import liquidjava.processor.VCImplication; /** @@ -7,6 +10,9 @@ */ public class VCSimplification { + private static final List> PASSES = List.of(VCSubstitution::apply, VCFolding::apply, + VCArithmeticSimplification::apply); + /** * Applies all available simplification steps to a VC chain until a fixed point is reached */ @@ -31,11 +37,11 @@ public static VCImplication simplifyOnce(VCImplication implication) { if (implication == null) return null; - // first try to apply substitution, then folding - VCImplication substituted = VCSubstitution.apply(implication); - if (!implication.equals(substituted)) - return substituted; - - return VCFolding.apply(implication); + for (UnaryOperator pass : PASSES) { + VCImplication simplified = pass.apply(implication); + if (!implication.equals(simplified)) + return simplified; + } + return implication; } } diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCArithmeticSimplificationTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCArithmeticSimplificationTest.java new file mode 100644 index 00000000..69bf921e --- /dev/null +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCArithmeticSimplificationTest.java @@ -0,0 +1,131 @@ +package liquidjava.rj_language.opt; + +import static liquidjava.utils.VCTestUtils.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; + +import liquidjava.processor.SimplifiedVCImplication; +import liquidjava.processor.VCImplication; +import org.junit.jupiter.api.Test; + +class VCArithmeticSimplificationTest { + + @Test + void applyReturnsNullForNullImplication() { + assertNull(VCArithmeticSimplification.apply(null)); + } + + @Test + void simplifiesAdditiveIdentities() { + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x + 0 > 0"), + chain(expect("x > 0", "x + 0 > 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("0 + x > 0"), + chain(expect("x > 0", "0 + x > 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x - 0 > 0"), + chain(expect("x > 0", "x - 0 > 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("0 - x > 0"), + chain(expect("-x > 0", "0 - x > 0"))); + } + + @Test + void simplifiesNegatedAdditionAndSubtraction() { + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x + -x == 0"), + chain(expect("0 == 0", "x + -x == 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("-x + x == 0"), + chain(expect("0 == 0", "-x + x == 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x - x == 0"), + chain(expect("0 == 0", "x - x == 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("--x == x"), + chain(expect("x == x", "-(-x) == x"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x + -y == 0"), + chain(expect("x - y == 0", "x + -y == 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x - -y == 0"), + chain(expect("x + y == 0", "x - -y == 0"))); + } + + @Test + void simplifiesMultiplicativeIdentities() { + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x * 1 > 0"), + chain(expect("x > 0", "x * 1 > 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("1 * x > 0"), + chain(expect("x > 0", "1 * x > 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x * 0 == 0"), + chain(expect("0 == 0", "x * 0 == 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("0 * x == 0"), + chain(expect("0 == 0", "0 * x == 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x / 1 > 0"), + chain(expect("x > 0", "x / 1 > 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x % 1 == 0"), + chain(expect("0 == 0", "x % 1 == 0"))); + } + + @Test + void simplifiesGuardedDivisionAndModuloIdentities() { + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x != 0", "0 / x == 0"), + chain(expect("x != 0", "x != 0"), expect("0 == 0", "0 / x == 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x != 0", "x / x == 1"), + chain(expect("x != 0", "x != 0"), expect("1 == 1", "x / x == 1"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("0 != x", "x % x == 0"), + chain(expect("0 != x", "0 != x"), expect("0 == 0", "x % x == 0"))); + } + + @Test + void simplifiesGuardedDivisionAndModuloIdentitiesWhenEqualityImpliesNonZero() { + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x == 1", "0 / x == 0"), + chain(expect("x == 1", "x == 1"), expect("0 == 0", "0 / x == 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("1 == x", "x / x == 1"), + chain(expect("1 == x", "1 == x"), expect("1 == 1", "x / x == 1"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x == 1", "x % x == 0"), + chain(expect("x == 1", "x == 1"), expect("0 == 0", "x % x == 0"))); + } + + @Test + void leavesUnguardedDivisionAndModuloIdentitiesUnchanged() { + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("0 / x == 0"), + chain(expect("0 / x == 0", "0 / x == 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x / x == 1"), + chain(expect("x / x == 1", "x / x == 1"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x % x == 0"), + chain(expect("x % x == 0", "x % x == 0"))); + } + + @Test + void simplifiesOnlyFirstArithmeticIdentity() { + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("x + 0 + 1 > 0"), + chain(expect("x + 1 > 0", "x + 0 + 1 > 0"))); + } + + @Test + void simplifiesTernaryExpressionsInConditionThenElseOrder() { + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("(flag + 0 > 0 ? x + 0 : y + 0) > 0"), + chain(expect("(flag > 0 ? x + 0 : y + 0) > 0", "(flag + 0 > 0 ? x + 0 : y + 0) > 0")), + chain(expect("(flag > 0 ? x : y + 0) > 0", "(flag > 0 ? x + 0 : y + 0) > 0")), + chain(expect("(flag > 0 ? x : y) > 0", "(flag > 0 ? x : y + 0) > 0"))); + } + + @Test + void simplifiesGroupedExpressionsAndLeavesUnchangedGroupsAlone() { + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("(x + 0) * y > 0"), + chain(expect("x * y > 0", "(x + 0) * y > 0"))); + assertSimplificationSteps(VCArithmeticSimplification::apply, vc("(x) > 0"), chain(expect("x > 0", "x > 0"))); + } + + @Test + void recordsOriginWhenSimplifyingLaterImplication() { + VCImplication implication = vc("x > 0", "y + 0 > x"); + + VCImplication result = assertSimplificationSteps(VCArithmeticSimplification::apply, implication, + chain(expect("x > 0", "x > 0"), expect("y > x", "y + 0 > x"))); + + SimplifiedVCImplication simplifiedNext = assertInstanceOf(SimplifiedVCImplication.class, result.getNext()); + assertEquals("y + 0 > x", simplifiedNext.getOrigin().getRefinement().getExpression().toDisplayString()); + } + + @Test + void recordsCurrentImplicationAsOriginWhenSimplifyingExistingSimplifiedImplication() { + VCImplication substituted = VCSubstitution.apply(vc("∀x:int. x == y + 0", "x > 0")); + + assertSimplificationSteps(VCArithmeticSimplification::apply, substituted, chain(expect("y > 0", "y + 0 > 0"))); + } +} diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCFoldingTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCFoldingTest.java index 53f51404..b73bd8ec 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCFoldingTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCFoldingTest.java @@ -26,7 +26,7 @@ void foldsIntegerArithmeticAndComparisons() { VCImplication implication = vc("1 + 2 == 3"); assertSimplificationSteps(VCFolding::apply, implication, chain(expect("3 == 3", "1 + 2 == 3")), - chain(expect("true", "1 + 2 == 3"))); + chain(expect("true", "3 == 3"))); assertSimplificationSteps(VCFolding::apply, vc("4 > 7"), chain(expect("false", "4 > 7"))); } @@ -36,9 +36,9 @@ void foldsRealAndMixedNumericExpressions() { VCImplication mixedArithmetic = vc("2 + 0.5 > 2"); assertSimplificationSteps(VCFolding::apply, realArithmetic, chain(expect("3.5 == 3.5", "1.5 + 2.0 == 3.5")), - chain(expect("true", "1.5 + 2.0 == 3.5"))); + chain(expect("true", "3.5 == 3.5"))); assertSimplificationSteps(VCFolding::apply, mixedArithmetic, chain(expect("2.5 > 2", "2 + 0.5 > 2")), - chain(expect("true", "2 + 0.5 > 2"))); + chain(expect("true", "2.5 > 2"))); } @Test @@ -55,6 +55,27 @@ void leavesRealDivisionAndModuloByZeroUnchanged() { chain(expect("4.0 % 0.0 == 0.0", "4.0 % 0.0 == 0.0"))); } + @Test + void foldsIntegerDivisionTowardZeroForNegativeResults() { + VCImplication implication = vc("(2 - 7) / 2 == -2"); + + assertSimplificationSteps(VCFolding::apply, implication, + chain(expect("(2 - 7) / 2 == -2", "(2 - 7) / 2 == -2")), + chain(expect("-5 / 2 == -2", "(2 - 7) / 2 == -2")), chain(expect("-2 == -2", "-5 / 2 == -2")), + chain(expect("-2 == -2", "-2 == -2")), chain(expect("true", "-2 == -2"))); + } + + @Test + void foldsIntegerModuloWithJavaSignedRemainder() { + VCImplication negativeDividend = vc("-5 % 2 < 0"); + VCImplication negativeDivisor = vc("5 % -2 > 0"); + + assertSimplificationSteps(VCFolding::apply, negativeDividend, chain(expect("-5 % 2 < 0", "-5 % 2 < 0")), + chain(expect("-1 < 0", "-5 % 2 < 0")), chain(expect("true", "-1 < 0"))); + assertSimplificationSteps(VCFolding::apply, negativeDivisor, chain(expect("5 % -2 > 0", "5 % -2 > 0")), + chain(expect("1 > 0", "5 % -2 > 0")), chain(expect("true", "1 > 0"))); + } + @Test void foldsBooleanBinaryExpressions() { assertSimplificationSteps(VCFolding::apply, vc("true && false"), chain(expect("false", "true && false"))); @@ -100,7 +121,7 @@ void foldsIteBranchesBeforeComparingThem() { VCImplication implication = vc("cond ? 1 + 2 : 3"); assertSimplificationSteps(VCFolding::apply, implication, chain(expect("cond ? 3 : 3", "cond ? 1 + 2 : 3")), - chain(expect("3", "cond ? 1 + 2 : 3"))); + chain(expect("3", "cond ? 3 : 3"))); } @Test @@ -127,7 +148,7 @@ void foldsResolvedEnumLiterals() { new Predicate(new BinaryExpression(limit, "==", new LiteralInt(3)))); assertSimplificationSteps(VCFolding::apply, implication, chain(expect("3 == 3", "Config.LIMIT == 3")), - chain(expect("true", "Config.LIMIT == 3"))); + chain(expect("true", "3 == 3"))); } @Test @@ -139,15 +160,15 @@ void foldsResolvedEnumLiteralsInsideLargerExpression() { new Predicate(new BinaryExpression(arithmetic, "==", new LiteralInt(5)))); assertSimplificationSteps(VCFolding::apply, implication, chain(expect("3 + 2 == 5", "Config.LIMIT + 2 == 5")), - chain(expect("5 == 5", "Config.LIMIT + 2 == 5")), chain(expect("true", "Config.LIMIT + 2 == 5"))); + chain(expect("5 == 5", "3 + 2 == 5")), chain(expect("true", "5 == 5"))); } @Test - void preservesOriginFromExistingSimplifiedImplication() { + void recordsCurrentImplicationAsOriginWhenFoldingExistingSimplifiedImplication() { VCImplication substituted = VCSubstitution.apply(vc("∀x:int. x == 1", "x + 1 + 2 > 0")); - assertSimplificationSteps(VCFolding::apply, substituted, chain(expect("2 + 2 > 0", "∀x:int. x + 1 + 2 > 0")), - chain(expect("4 > 0", "∀x:int. x + 1 + 2 > 0")), chain(expect("true", "∀x:int. x + 1 + 2 > 0"))); + assertSimplificationSteps(VCFolding::apply, substituted, chain(expect("2 + 2 > 0", "1 + 1 + 2 > 0")), + chain(expect("4 > 0", "2 + 2 > 0")), chain(expect("true", "4 > 0"))); } @Test @@ -169,13 +190,13 @@ void recordsOriginWhenFoldingLaterImplication() { chain(expect("x > 0", "x > 0"), expect("3 > 0", "1 + 2 > 0"))); SimplifiedVCImplication simplifiedNext = assertInstanceOf(SimplifiedVCImplication.class, result.getNext()); - assertEquals("1 + 2 > 0", simplifiedNext.getOrigin().getRefinement().toString()); + assertEquals("1 + 2 > 0", simplifiedNext.getOrigin().getRefinement().getExpression().toDisplayString()); result = assertSimplificationSteps(VCFolding::apply, result, - chain(expect("x > 0", "x > 0"), expect("true", "1 + 2 > 0"))); + chain(expect("x > 0", "x > 0"), expect("true", "3 > 0"))); simplifiedNext = assertInstanceOf(SimplifiedVCImplication.class, result.getNext()); - assertEquals("1 + 2 > 0", simplifiedNext.getOrigin().getRefinement().toString()); + assertEquals("3 > 0", simplifiedNext.getOrigin().getRefinement().getExpression().toDisplayString()); } } diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCImplicationGenerator.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCImplicationGenerator.java index 0ebe0e42..104a619c 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCImplicationGenerator.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCImplicationGenerator.java @@ -21,7 +21,7 @@ public VCImplicationGenerator() { @Override public VCImplication generate(SourceOfRandomness random, GenerationStatus status) { - return switch (random.nextInt(0, 9)) { + return switch (random.nextInt(0, 11)) { case 0 -> vc(substitution(random, "x"), comparison(random, "x")); case 1 -> vc(reverseSubstitution(random, "x"), comparison(random, "x")); case 2 -> vc(nonSubstitution(random, "x"), substitution(random, "y"), comparison(random, "y")); @@ -31,6 +31,8 @@ public VCImplication generate(SourceOfRandomness random, GenerationStatus status case 6 -> vc(foldableBoolean(random), comparison(random, "x")); case 7 -> vc(foldableIte(random)); case 8 -> vc(adjacentConstants(random) + " " + comparisonOperator(random) + " " + intLiteral(random)); + case 9 -> vc(arithmeticIdentity(random)); + case 10 -> guardedArithmeticIdentity(random); default -> vc(substitution(random, "x"), substitution(random, "y"), foldableComparison(random)); }; } @@ -94,6 +96,34 @@ private static String adjacentConstants(SourceOfRandomness random) { return variable + " " + signed(left) + " " + signed(right); } + private static String arithmeticIdentity(SourceOfRandomness random) { + String var = FREE_VARS[random.nextInt(0, FREE_VARS.length - 1)]; + String other = FREE_VARS[random.nextInt(0, FREE_VARS.length - 1)]; + return switch (random.nextInt(0, 9)) { + case 0 -> var + " + 0 " + comparisonOperator(random) + " " + intLiteral(random); + case 1 -> "0 + " + var + " " + comparisonOperator(random) + " " + intLiteral(random); + case 2 -> var + " - 0 " + comparisonOperator(random) + " " + intLiteral(random); + case 3 -> "0 - " + var + " " + comparisonOperator(random) + " " + intLiteral(random); + case 4 -> var + " - " + var + " == 0"; + case 5 -> var + " * 1 " + comparisonOperator(random) + " " + intLiteral(random); + case 6 -> "1 * " + var + " " + comparisonOperator(random) + " " + intLiteral(random); + case 7 -> var + " * 0 == 0"; + case 8 -> var + " / 1 " + comparisonOperator(random) + " " + intLiteral(random); + default -> var + " + -" + other + " " + comparisonOperator(random) + " " + intLiteral(random); + }; + } + + private static VCImplication guardedArithmeticIdentity(SourceOfRandomness random) { + String var = FREE_VARS[random.nextInt(0, FREE_VARS.length - 1)]; + String guard = random.nextBoolean() ? var + " != 0" : "0 != " + var; + String use = switch (random.nextInt(0, 2)) { + case 0 -> "0 / " + var + " == 0"; + case 1 -> var + " / " + var + " == 1"; + default -> var + " % " + var + " == 0"; + }; + return vc(guard, use); + } + private static String comparisonOperator(SourceOfRandomness random) { return COMPARISON_OPS[random.nextInt(0, COMPARISON_OPS.length - 1)]; } diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationTest.java index a2ce7c82..ffe65220 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationTest.java @@ -23,8 +23,8 @@ void simplifyOnceAppliesSubstitutionBeforeFolding() { VCImplication implication = vc("∀x:int. x == 1 + 2", "x > 2"); assertSimplificationSteps(VCSimplification::simplifyOnce, implication, - chain(expect("1 + 2 > 2", "∀x:int. x > 2")), chain(expect("3 > 2", "∀x:int. x > 2")), - chain(expect("true", "∀x:int. x > 2"))); + chain(expect("1 + 2 > 2", "∀x:int. x > 2")), chain(expect("3 > 2", "1 + 2 > 2")), + chain(expect("true", "3 > 2"))); } @Test @@ -32,8 +32,8 @@ void simplifyOnceDoesNotFoldAfterSubstitutionInSameStep() { VCImplication implication = vc("∀x:int. x == 1 + 2", "x == 3"); assertSimplificationSteps(VCSimplification::simplifyOnce, implication, - chain(expect("1 + 2 == 3", "∀x:int. x == 3")), chain(expect("3 == 3", "∀x:int. x == 3")), - chain(expect("true", "∀x:int. x == 3"))); + chain(expect("1 + 2 == 3", "∀x:int. x == 3")), chain(expect("3 == 3", "1 + 2 == 3")), + chain(expect("true", "3 == 3"))); } @Test @@ -41,7 +41,22 @@ void simplifyOnceAppliesFoldingWhenNoSubstitutionIsAvailable() { VCImplication implication = vc("1 + 2 > 2"); assertSimplificationSteps(VCSimplification::simplifyOnce, implication, chain(expect("3 > 2", "1 + 2 > 2")), - chain(expect("true", "1 + 2 > 2"))); + chain(expect("true", "3 > 2"))); + } + + @Test + void simplifyOnceAppliesFoldingBeforeArithmeticSimplification() { + VCImplication implication = vc("1 + 2 + x + 0 > 0"); + + assertSimplificationSteps(VCSimplification::simplifyOnce, implication, + chain(expect("3 + x + 0 > 0", "1 + 2 + x + 0 > 0"))); + } + + @Test + void simplifyOnceAppliesArithmeticWhenNoSubstitutionOrFoldingIsAvailable() { + VCImplication implication = vc("x + 0 > 0"); + + assertSimplificationSteps(VCSimplification::simplifyOnce, implication, chain(expect("x > 0", "x + 0 > 0"))); } @Test @@ -49,8 +64,8 @@ void simplifyKeepsApplyingStepsUntilFixedPoint() { VCImplication implication = vc("∀x:int. x == 1 + 2", "x + 1 > 3"); assertSimplificationSteps(VCSimplification::simplifyOnce, implication, - chain(expect("1 + 2 + 1 > 3", "∀x:int. x + 1 > 3")), chain(expect("3 + 1 > 3", "∀x:int. x + 1 > 3")), - chain(expect("4 > 3", "∀x:int. x + 1 > 3")), chain(expect("true", "∀x:int. x + 1 > 3"))); + chain(expect("1 + 2 + 1 > 3", "∀x:int. x + 1 > 3")), chain(expect("3 + 1 > 3", "1 + 2 + 1 > 3")), + chain(expect("4 > 3", "3 + 1 > 3")), chain(expect("true", "4 > 3"))); } @Test @@ -59,8 +74,8 @@ void simplifyAppliesMultipleSubstitutionsBeforeReachingFixedPoint() { assertSimplificationSteps(VCSimplification::simplifyOnce, implication, chain(expect("y == 3 + 1", "∀x:int. y == x + 1"), expect("y > 3", "∀x:int. y > x")), - chain(expect("3 + 1 > 3", "∀y:int. y > x")), chain(expect("4 > 3", "∀y:int. y > x")), - chain(expect("true", "∀y:int. y > x"))); + chain(expect("3 + 1 > 3", "∀y:int. y > x")), chain(expect("4 > 3", "3 + 1 > 3")), + chain(expect("true", "4 > 3"))); } @Test @@ -71,8 +86,8 @@ void simplifyAppliesLongSubstitutionChainBeforeReachingFixedPoint() { chain(expect("y == 1 + 1", "∀x:int. y == x + 1"), expect("z == y + 1", "∀z:int. z == y + 1"), expect("z == 3", "z == 3")), chain(expect("z == 1 + 1 + 1", "∀y:int. z == y + 1"), expect("z == 3", "z == 3")), - chain(expect("1 + 1 + 1 == 3", "∀z:int. z == 3")), chain(expect("2 + 1 == 3", "∀z:int. z == 3")), - chain(expect("3 == 3", "∀z:int. z == 3")), chain(expect("true", "∀z:int. z == 3"))); + chain(expect("1 + 1 + 1 == 3", "∀z:int. z == 3")), chain(expect("2 + 1 == 3", "1 + 1 + 1 == 3")), + chain(expect("3 == 3", "2 + 1 == 3")), chain(expect("true", "3 == 3"))); } @Test @@ -81,9 +96,8 @@ void simplifyCombinesSubstitutionAndNestedFoldingAcrossFixedPoint() { assertSimplificationSteps(VCSimplification::simplifyOnce, implication, chain(expect("y == 1 + 2", "∀x:int. y == x + 2"), expect("y - 1 == 2", "y - 1 == 2")), - chain(expect("1 + 2 - 1 == 2", "∀y:int. y - 1 == 2")), - chain(expect("3 - 1 == 2", "∀y:int. y - 1 == 2")), chain(expect("2 == 2", "∀y:int. y - 1 == 2")), - chain(expect("true", "∀y:int. y - 1 == 2"))); + chain(expect("1 + 2 - 1 == 2", "∀y:int. y - 1 == 2")), chain(expect("3 - 1 == 2", "1 + 2 - 1 == 2")), + chain(expect("2 == 2", "3 - 1 == 2")), chain(expect("true", "2 == 2"))); } @Test diff --git a/liquidjava-verifier/src/test/java/liquidjava/utils/VCTestUtils.java b/liquidjava-verifier/src/test/java/liquidjava/utils/VCTestUtils.java index ba82b81e..61e46da4 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/utils/VCTestUtils.java +++ b/liquidjava-verifier/src/test/java/liquidjava/utils/VCTestUtils.java @@ -53,7 +53,7 @@ public static void assertSimplifiedVC(VCImplication implication, ExpectedSimplif ExpectedSimplifiedVCImplication expectedPredicate = expected[i]; assertEquals(Predicate.class, current.getRefinement().getClass(), "Expected simplified refinement at implication " + i + " to be a plain Predicate"); - assertEquals(expectedPredicate.simplified(), current.getRefinement().toString(), + assertEquals(expectedPredicate.simplified(), formatRefinement(current), "Unexpected simplified expression at implication " + i); if (expectedPredicate.origin() != null) assertEquals(expectedPredicate.origin(), formatOrigin(current.getOrigin()), @@ -83,8 +83,12 @@ public static ExpectedSimplifiedVCImplication expect(String simplified, String o private static String formatOrigin(VCImplication origin) { if (!origin.hasBinder()) - return origin.getRefinement().toString(); - return "∀" + origin.getName() + ":" + origin.getType().getQualifiedName() + ". " + origin.getRefinement(); + return formatRefinement(origin); + return "∀" + origin.getName() + ":" + origin.getType().getQualifiedName() + ". " + formatRefinement(origin); + } + + private static String formatRefinement(VCImplication implication) { + return implication.getRefinement().getExpression().toDisplayString(); } public record ExpectedSimplifiedVCImplication(String simplified, String origin) {