33import liquidjava .rj_language .ast .BinaryExpression ;
44import liquidjava .rj_language .ast .Expression ;
55import liquidjava .rj_language .ast .LiteralBoolean ;
6+ import liquidjava .rj_language .ast .UnaryExpression ;
67import liquidjava .rj_language .opt .derivation_node .BinaryDerivationNode ;
78import liquidjava .rj_language .opt .derivation_node .DerivationNode ;
9+ import liquidjava .rj_language .opt .derivation_node .UnaryDerivationNode ;
810import liquidjava .rj_language .opt .derivation_node .ValDerivationNode ;
911
1012public class ExpressionSimplifier {
@@ -15,12 +17,13 @@ public class ExpressionSimplifier {
1517 */
1618 public static ValDerivationNode simplify (Expression exp ) {
1719 ValDerivationNode fixedPoint = simplifyToFixedPoint (null , exp );
18- return simplifyValDerivationNode (fixedPoint );
20+ ValDerivationNode simplified = simplifyValDerivationNode (fixedPoint );
21+ return unwrapBooleanLiterals (simplified );
1922 }
2023
2124 /**
2225 * Recursively applies propagation and folding until the expression stops changing (fixed point) Stops early if the
23- * expression simplifies to 'true' , which means we've simplified too much
26+ * expression simplifies to a boolean literal , which means we've simplified too much
2427 */
2528 private static ValDerivationNode simplifyToFixedPoint (ValDerivationNode current , Expression prevExp ) {
2629 // apply propagation and folding
@@ -34,6 +37,11 @@ private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current,
3437 return current ;
3538 }
3639
40+ // prevent oversimplification
41+ if (current != null && currExp instanceof LiteralBoolean && !(current .getValue () instanceof LiteralBoolean )) {
42+ return current ;
43+ }
44+
3745 // continue simplifying
3846 return simplifyToFixedPoint (simplified , simplified .getValue ());
3947 }
@@ -76,8 +84,12 @@ private static ValDerivationNode simplifyValDerivationNode(ValDerivationNode nod
7684
7785 // return the conjunction with simplified children
7886 Expression newValue = new BinaryExpression (leftSimplified .getValue (), "&&" , rightSimplified .getValue ());
79- DerivationNode newOrigin = new BinaryDerivationNode (leftSimplified , rightSimplified , "&&" );
80- return new ValDerivationNode (newValue , newOrigin );
87+ // only create origin if at least one child has a meaningful origin
88+ if (leftSimplified .getOrigin () != null || rightSimplified .getOrigin () != null ) {
89+ DerivationNode newOrigin = new BinaryDerivationNode (leftSimplified , rightSimplified , "&&" );
90+ return new ValDerivationNode (newValue , newOrigin );
91+ }
92+ return new ValDerivationNode (newValue , null );
8193 }
8294 // no simplification
8395 return node ;
@@ -114,4 +126,61 @@ private static boolean isRedundant(Expression exp) {
114126 }
115127 return false ;
116128 }
129+
130+ /**
131+ * Recursively traverses the derivation tree and replaces boolean literals with the expressions that produced them,
132+ * but only when at least one operand in the derivation is non-boolean. e.g. "x == true" where true came from "1 >
133+ * 0" becomes "x == 1 > 0"
134+ */
135+ private static ValDerivationNode unwrapBooleanLiterals (ValDerivationNode node ) {
136+ Expression value = node .getValue ();
137+ DerivationNode origin = node .getOrigin ();
138+
139+ if (origin == null )
140+ return node ;
141+
142+ // unwrap binary expressions
143+ if (value instanceof BinaryExpression binExp && origin instanceof BinaryDerivationNode binOrigin ) {
144+ ValDerivationNode left = unwrapBooleanLiterals (binOrigin .getLeft ());
145+ ValDerivationNode right = unwrapBooleanLiterals (binOrigin .getRight ());
146+ if (left != binOrigin .getLeft () || right != binOrigin .getRight ()) {
147+ Expression newValue = new BinaryExpression (left .getValue (), binExp .getOperator (), right .getValue ());
148+ return new ValDerivationNode (newValue , new BinaryDerivationNode (left , right , binOrigin .getOp ()));
149+ }
150+ return node ;
151+ }
152+
153+ // unwrap unary expressions
154+ if (value instanceof UnaryExpression unaryExp && origin instanceof UnaryDerivationNode unaryOrigin ) {
155+ ValDerivationNode operand = unwrapBooleanLiterals (unaryOrigin .getOperand ());
156+ if (operand != unaryOrigin .getOperand ()) {
157+ Expression newValue = new UnaryExpression (unaryExp .getOp (), operand .getValue ());
158+ return new ValDerivationNode (newValue , new UnaryDerivationNode (operand , unaryOrigin .getOp ()));
159+ }
160+ return node ;
161+ }
162+
163+ // boolean literal with binary origin: unwrap if at least one child is non-boolean
164+ if (value instanceof LiteralBoolean && origin instanceof BinaryDerivationNode binOrigin ) {
165+ ValDerivationNode left = unwrapBooleanLiterals (binOrigin .getLeft ());
166+ ValDerivationNode right = unwrapBooleanLiterals (binOrigin .getRight ());
167+ if (!(left .getValue () instanceof LiteralBoolean ) || !(right .getValue () instanceof LiteralBoolean )) {
168+ Expression newValue = new BinaryExpression (left .getValue (), binOrigin .getOp (), right .getValue ());
169+ return new ValDerivationNode (newValue , new BinaryDerivationNode (left , right , binOrigin .getOp ()));
170+ }
171+ return node ;
172+ }
173+
174+ // boolean literal with unary origin: unwrap if operand is non-boolean
175+ if (value instanceof LiteralBoolean && origin instanceof UnaryDerivationNode unaryOrigin ) {
176+ ValDerivationNode operand = unwrapBooleanLiterals (unaryOrigin .getOperand ());
177+ if (!(operand .getValue () instanceof LiteralBoolean )) {
178+ Expression newValue = new UnaryExpression (unaryOrigin .getOp (), operand .getValue ());
179+ return new ValDerivationNode (newValue , new UnaryDerivationNode (operand , unaryOrigin .getOp ()));
180+ }
181+ return node ;
182+ }
183+
184+ return node ;
185+ }
117186}
0 commit comments