Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public static <T> void printError(CtElement var, String moreInfo, Predicate expe
// all message
sb.append(sbtitle.toString() + "\n\n");
sb.append("Type expected:" + expectedType.toString() + "\n");
sb.append("Refinement found:" + cSMT.toString() + "\n");
sb.append("Refinement found:\n" + cSMT.simplify().getValue() + "\n");
sb.append(printMap(map));
sb.append("Location: " + var.getPosition() + "\n");
sb.append("______________________________________________________\n");
Expand Down Expand Up @@ -181,7 +181,7 @@ public static void printCostumeError(CtElement element, String msg, ErrorEmitter
sb.append(element + "\n\n");
sb.append("Location: " + element.getPosition() + "\n");
sb.append("______________________________________________________\n");

errorl.addError(s, sb.toString(), element.getPosition(), 1);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ private void createStateSet(CtNewArray<String> e, int set, CtElement element) {
CtLiteral<String> s = (CtLiteral<String>) ce;
String f = s.getValue();
if (Character.isUpperCase(f.charAt(0))) {
ErrorHandler.printCostumeError(s, "State name must start with lowercase in '" + f + "'", errorEmitter);
Comment thread
rcosta358 marked this conversation as resolved.
Outdated
ErrorHandler.printCostumeError(s, "State name must start with lowercase in '" + f + "'",
errorEmitter);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import liquidjava.rj_language.ast.LiteralReal;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.ast.Var;
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
import liquidjava.rj_language.opt.ExpressionSimplifier;
import liquidjava.rj_language.parsing.ParsingException;
import liquidjava.rj_language.parsing.RefinementsParser;
import liquidjava.utils.Utils;
Expand Down Expand Up @@ -212,6 +215,10 @@ public Expression getExpression() {
return exp;
}

public ValDerivationNode simplify() {
return ExpressionSimplifier.simplify(exp.clone());
}

public static Predicate createConjunction(Predicate c1, Predicate c2) {
return new Predicate(new BinaryExpression(c1.getExpression(), Utils.AND, c2.getExpression()));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package liquidjava.rj_language.ast;

import com.microsoft.z3.Expr;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import com.microsoft.z3.Expr;
Comment thread
rcosta358 marked this conversation as resolved.

import liquidjava.processor.context.Context;
import liquidjava.processor.facade.AliasDTO;
import liquidjava.rj_language.ast.typing.TypeInfer;
Expand Down Expand Up @@ -47,6 +49,10 @@ public void setChild(int index, Expression element) {
children.set(index, element);
}

public boolean isLiteral() {
return this instanceof LiteralInt || this instanceof LiteralReal || this instanceof LiteralBoolean;
}

/**
* Substitutes the expression first given expression by the second
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ public String toString() {
return Integer.toString(value);
}

public int getValue() {
return value;
}

@Override
public void getVariableNames(List<String> toAdd) {
// end leaf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ public String toString() {
return Double.toString(value);
}

public double getValue() {
return value;
}

@Override
public void getVariableNames(List<String> toAdd) {
// end leaf
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package liquidjava.rj_language.opt;

import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.GroupExpression;
import liquidjava.rj_language.ast.LiteralBoolean;
import liquidjava.rj_language.ast.LiteralInt;
import liquidjava.rj_language.ast.LiteralReal;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;

public class ConstantFolding {

public static ValDerivationNode fold(ValDerivationNode node) {
Expression exp = node.getValue();
if (exp instanceof BinaryExpression) {
return foldBinary(node);
}
if (exp instanceof UnaryExpression) {
return foldUnary(node);
}
if (exp instanceof GroupExpression) {
GroupExpression group = (GroupExpression) exp;
if (group.getChildren().size() == 1) {
return fold(new ValDerivationNode(group.getChildren().get(0), node.getOrigin()));
}
}
return node;
}

private static ValDerivationNode foldBinary(ValDerivationNode node) {
BinaryExpression binExp = (BinaryExpression) node.getValue();
DerivationNode parent = node.getOrigin();

// fold child nodes
ValDerivationNode leftNode;
ValDerivationNode rightNode;
if (parent instanceof BinaryDerivationNode) {
// has origin (from constant propagation)
BinaryDerivationNode binaryOrigin = (BinaryDerivationNode) parent;
leftNode = fold(binaryOrigin.getLeft());
rightNode = fold(binaryOrigin.getRight());
} else {
// no origin
leftNode = fold(new ValDerivationNode(binExp.getFirstOperand(), null));
rightNode = fold(new ValDerivationNode(binExp.getSecondOperand(), null));
}

Expression left = leftNode.getValue();
Expression right = rightNode.getValue();
String op = binExp.getOperator();
binExp.setChild(0, left);
binExp.setChild(1, right);

// int and int
if (left instanceof LiteralInt && right instanceof LiteralInt) {
int l = ((LiteralInt) left).getValue();
int r = ((LiteralInt) right).getValue();
Expression res = switch (op) {
case "+" -> new LiteralInt(l + r);
case "-" -> new LiteralInt(l - r);
case "*" -> new LiteralInt(l * r);
case "/" -> r != 0 ? new LiteralInt(l / r) : null;
case "%" -> r != 0 ? new LiteralInt(l % r) : null;
case "<" -> new LiteralBoolean(l < r);
case "<=" -> new LiteralBoolean(l <= r);
case ">" -> new LiteralBoolean(l > r);
case ">=" -> new LiteralBoolean(l >= r);
case "==" -> new LiteralBoolean(l == r);
case "!=" -> new LiteralBoolean(l != r);
default -> null;
};
if (res != null)
return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
}
// real and real
else if (left instanceof LiteralReal && right instanceof LiteralReal) {
double l = ((LiteralReal) left).getValue();
double r = ((LiteralReal) right).getValue();
Expression res = switch (op) {
case "+" -> new LiteralReal(l + r);
case "-" -> new LiteralReal(l - r);
case "*" -> new LiteralReal(l * r);
case "/" -> r != 0.0 ? new LiteralReal(l / r) : null;
case "%" -> r != 0.0 ? new LiteralReal(l % r) : null;
case "<" -> new LiteralBoolean(l < r);
case "<=" -> new LiteralBoolean(l <= r);
case ">" -> new LiteralBoolean(l > r);
case ">=" -> new LiteralBoolean(l >= r);
case "==" -> new LiteralBoolean(l == r);
case "!=" -> new LiteralBoolean(l != r);
default -> null;
};
if (res != null)
return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
}

// mixed int and real
else if ((left instanceof LiteralInt && right instanceof LiteralReal)
|| (left instanceof LiteralReal && right instanceof LiteralInt)) {
double l = left instanceof LiteralInt ? ((LiteralInt) left).getValue() : ((LiteralReal) left).getValue();
double r = right instanceof LiteralInt ? ((LiteralInt) right).getValue() : ((LiteralReal) right).getValue();
Expression res = switch (op) {
case "+" -> new LiteralReal(l + r);
case "-" -> new LiteralReal(l - r);
case "*" -> new LiteralReal(l * r);
case "/" -> r != 0.0 ? new LiteralReal(l / r) : null;
case "%" -> r != 0.0 ? new LiteralReal(l % r) : null;
case "<" -> new LiteralBoolean(l < r);
case "<=" -> new LiteralBoolean(l <= r);
case ">" -> new LiteralBoolean(l > r);
case ">=" -> new LiteralBoolean(l >= r);
case "==" -> new LiteralBoolean(l == r);
case "!=" -> new LiteralBoolean(l != r);
default -> null;
};
if (res != null)
return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
}
// bool and bool
else if (left instanceof LiteralBoolean && right instanceof LiteralBoolean) {
boolean l = ((LiteralBoolean) left).isBooleanTrue();
boolean r = ((LiteralBoolean) right).isBooleanTrue();
Expression res = switch (op) {
case "&&" -> new LiteralBoolean(l && r);
case "||" -> new LiteralBoolean(l || r);
case "-->" -> new LiteralBoolean(!l || r);
case "==" -> new LiteralBoolean(l == r);
case "!=" -> new LiteralBoolean(l != r);
default -> null;
};
if (res != null)
return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
}

// no folding
DerivationNode origin = (leftNode.getOrigin() != null || rightNode.getOrigin() != null)
? new BinaryDerivationNode(leftNode, rightNode, op) : null;
return new ValDerivationNode(binExp, origin);
}

private static ValDerivationNode foldUnary(ValDerivationNode node) {
UnaryExpression unaryExp = (UnaryExpression) node.getValue();
DerivationNode parent = node.getOrigin();

// fold child node
ValDerivationNode operandNode;
if (parent instanceof UnaryDerivationNode) {
// has origin (from constant propagation)
UnaryDerivationNode unaryOrigin = (UnaryDerivationNode) parent;
operandNode = fold(unaryOrigin.getOperand());
} else {
// no origin
operandNode = fold(new ValDerivationNode(unaryExp.getChildren().get(0), null));
}
Expression operand = operandNode.getValue();
String operator = unaryExp.getOp();
unaryExp.setChild(0, operand);

// unary not
if ("!".equals(operator) && operand instanceof LiteralBoolean) {
// !true => false, !false => true
boolean value = ((LiteralBoolean) operand).isBooleanTrue();
Expression res = new LiteralBoolean(!value);
return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator));
}
// unary minus
if ("-".equals(operator)) {
// -(x) => -x
if (operand instanceof LiteralInt) {
Expression res = new LiteralInt(-((LiteralInt) operand).getValue());
return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator));
}
if (operand instanceof LiteralReal) {
Expression res = new LiteralReal(-((LiteralReal) operand).getValue());
return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator));
}
}

// no folding
DerivationNode origin = operandNode.getOrigin() != null ? new UnaryDerivationNode(operandNode, operator) : null;
return new ValDerivationNode(unaryExp, origin);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package liquidjava.rj_language.opt;

import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.ast.Var;
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;

import java.util.Map;

public class ConstantPropagation {

public static ValDerivationNode propagate(Expression exp) {
Map<String, Expression> substitutions = VariableResolver.resolve(exp);
return propagateRecursive(exp, substitutions);
}

private static ValDerivationNode propagateRecursive(Expression exp, Map<String, Expression> subs) {

// substitute variable
if (exp instanceof Var) {
Var var = (Var) exp;
String name = var.getName();
Expression value = subs.get(name);
// substitution
if (value != null)
return new ValDerivationNode(value.clone(), new VarDerivationNode(name));

// no substitution
return new ValDerivationNode(var, null);
}

// lift unary origin
if (exp instanceof UnaryExpression) {
UnaryExpression unary = (UnaryExpression) exp;
ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs);
unary.setChild(0, operand.getValue());

DerivationNode origin = operand.getOrigin() != null ? new UnaryDerivationNode(operand, unary.getOp())
: null;
return new ValDerivationNode(unary, origin);
}

// lift binary origin
if (exp instanceof BinaryExpression) {
BinaryExpression binary = (BinaryExpression) exp;
ValDerivationNode left = propagateRecursive(binary.getFirstOperand(), subs);
ValDerivationNode right = propagateRecursive(binary.getSecondOperand(), subs);
binary.setChild(0, left.getValue());
binary.setChild(1, right.getValue());

DerivationNode origin = (left.getOrigin() != null || right.getOrigin() != null)
? new BinaryDerivationNode(left, right, binary.getOperator()) : null;
return new ValDerivationNode(binary, origin);
}

// recursively propagate children
if (exp.hasChildren()) {
Expression propagated = exp.clone();
for (int i = 0; i < exp.getChildren().size(); i++) {
ValDerivationNode child = propagateRecursive(exp.getChildren().get(i), subs);
propagated.setChild(i, child.getValue());
}
return new ValDerivationNode(propagated, null);
}

// no propagation
return new ValDerivationNode(exp, null);
}
}
Loading