From 98271c101341e4edcf811a28fb3d88561d4900be Mon Sep 17 00:00:00 2001 From: aarna Date: Thu, 7 Nov 2024 18:38:38 +0530 Subject: [PATCH 1/5] Boolean Rewrite Task --- .../RewriteAlgebraicSimplificationStatic.java | 1096 +++++++++-------- .../RewriteBooleanSimplificationTest.java | 76 ++ .../RewriteBooleanSimplificationTestAnd.dml | 5 + .../RewriteBooleanSimplificationTestOr.dml | 5 + 4 files changed, 657 insertions(+), 525 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml create mode 100644 src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index a18a2b74660..a32137af470 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -62,26 +62,26 @@ * estimate, in MR this allows map-only operations and hence prevents * unnecessary shuffle and sort) and (2) remove binary operations that * are in itself are unnecessary (e.g., *1 and /1). - * + * */ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { //valid aggregation operation types for rowOp to colOp conversions and vice versa private static final AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[] { - AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR}; - + AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR}; + //valid binary operations for distributive and associate reorderings - private static final OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new OpOp2[] {OpOp2.PLUS, OpOp2.MINUS}; + private static final OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new OpOp2[] {OpOp2.PLUS, OpOp2.MINUS}; private static final OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new OpOp2[] {OpOp2.PLUS, OpOp2.MULT}; - + //valid binary operations for scalar operations - private static final OpOp2[] LOOKUP_VALID_SCALAR_BINARY = new OpOp2[] {OpOp2.AND, OpOp2.DIV, - OpOp2.EQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.INTDIV, OpOp2.LESS, OpOp2.LESSEQUAL, - OpOp2.LOG, OpOp2.MAX, OpOp2.MIN, OpOp2.MINUS, OpOp2.MODULUS, OpOp2.MULT, OpOp2.NOTEQUAL, - OpOp2.OR, OpOp2.PLUS, OpOp2.POW}; - + private static final OpOp2[] LOOKUP_VALID_SCALAR_BINARY = new OpOp2[] {OpOp2.AND, OpOp2.DIV, + OpOp2.EQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.INTDIV, OpOp2.LESS, OpOp2.LESSEQUAL, + OpOp2.LOG, OpOp2.MAX, OpOp2.MIN, OpOp2.MINUS, OpOp2.MODULUS, OpOp2.MULT, OpOp2.NOTEQUAL, + OpOp2.OR, OpOp2.PLUS, OpOp2.POW}; + @Override - public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) + public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { if( roots == null ) return roots; @@ -90,32 +90,32 @@ public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus for( Hop h : roots ) rule_AlgebraicSimplification( h, false ); Hop.resetVisitStatus(roots, true); - + //one pass descend-rewrite (for rollup) for( Hop h : roots ) rule_AlgebraicSimplification( h, true ); Hop.resetVisitStatus(roots, true); - + //cleanup remove (twrite <- tread) pairs (unless checkpointing) removeTWriteTReadPairs(roots); - + return roots; } @Override - public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { if( root == null ) return root; - + //one pass rewrite-descend (rewrite created pattern) rule_AlgebraicSimplification( root, false ); root.resetVisitStatus(); - + //one pass descend-rewrite (for rollup) rule_AlgebraicSimplification( root, true ); - + return root; } @@ -125,24 +125,24 @@ public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) * (1) the results would be not exactly the same (2 rounds instead of 1) and (2) it should * come before constant folding while the other simplifications should come after constant * folding. Hence, not applied yet. - * + * * @param hop high-level operator * @param descendFirst if process children recursively first */ - private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) + private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) { if(hop.isVisited()) return; - + //recursively process children for( int i=0; i 1/X hi = removeUnnecessaryBinaryOperation(hop, hi, i); //e.g., X*1 -> X (dep: should come after rm unnecessary vectorize) @@ -153,6 +153,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hi = canonicalizeMatrixMultScalarAdd(hi); //e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps) hi = simplifyCTableWithConstMatrixInputs(hi); //e.g., table(X, matrix(1,...)) -> table(X, 1) hi = removeUnnecessaryCTable(hop, hi, i); //e.g., sum(table(X, 1)) -> nrow(X) and sum(table(1, Y)) -> nrow(Y) and sum(table(X, Y)) -> nrow(X) + hi = simplifyBooleanRewrite(hop, hi, i); hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X) if(OptimizerUtils.ALLOW_OPERATOR_FUSION) hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y @@ -169,11 +170,11 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B); if(OptimizerUtils.ALLOW_OPERATOR_FUSION) hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X) - hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y)); + hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y)); hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]; hi = simplifyListIndexing(hi); //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1] - hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq; - hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq; + hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq; + hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq; hi = fuseOrderOperationChain(hi); //e.g., order(order(X,2),1) -> order(X,(12)) hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites hi = removeUnnecessaryRemoveEmpty(hop, hi, i); //e.g., nrow(removeEmpty(A)) -> nnz(A) iff col vector @@ -187,15 +188,15 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) } hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) - hi = simplifyBinaryComparisonChain(hop, hi, i); //e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> outer(v1,v2,"!="), + hi = simplifyBinaryComparisonChain(hop, hi, i); //e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> outer(v1,v2,"!="), hi = simplifyCumsumColOrFullAggregates(hi); //e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1)) hi = simplifyCumsumReverse(hop, hi, i); //e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X) - hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B) + hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B) //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) hi = fixNonScalarPrint(hop, hi, i); //e.g., print(m) -> print(toString(m)) - + //process childs recursively after rewrites (to investigate pattern newly created by rewrites) if( !descendFirst ) rule_AlgebraicSimplification(hi, descendFirst); @@ -203,21 +204,21 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hop.setVisited(); } - + private static Hop removeUnnecessaryVectorizeOperation(Hop hi) { - //applies to all binary matrix operations, if one input is unnecessarily vectorized - if( hi instanceof BinaryOp && hi.getDataType()==DataType.MATRIX - && ((BinaryOp)hi).supportsMatrixScalarOperations() ) + //applies to all binary matrix operations, if one input is unnecessarily vectorized + if( hi instanceof BinaryOp && hi.getDataType()==DataType.MATRIX + && ((BinaryOp)hi).supportsMatrixScalarOperations() ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); - - //NOTE: these rewrites of binary cell operations need to be aware that right is + + //NOTE: these rewrites of binary cell operations need to be aware that right is //potentially a vector but the result is of the size of left //TODO move to dynamic rewrites (since size dependent to account for mv binary cell and outer operations) - + if( !(left.getDim1()>1 && left.getDim2()==1 && right.getDim1()==1 && right.getDim2()>1) ) // no outer { //check and remove right vectorized scalar @@ -229,7 +230,7 @@ private static Hop removeUnnecessaryVectorizeOperation(Hop hi) Hop drightIn = dright.getInput().get(dright.getParamIndex(DataExpression.RAND_MIN)); HopRewriteUtils.replaceChildReference(bop, dright, drightIn, 1); HopRewriteUtils.cleanupUnreferenced(dright); - + LOG.debug("Applied removeUnnecessaryVectorizeOperation1"); } } @@ -238,50 +239,50 @@ else if( right.getDataType() == DataType.MATRIX && left instanceof DataGenOp ) { DataGenOp dleft = (DataGenOp) left; if( dleft.getOp()==OpOpDG.RAND && dleft.hasConstantValue() - && (left.getDim2()==1 || right.getDim2()>1) - && (left.getDim1()==1 || right.getDim1()>1)) + && (left.getDim2()==1 || right.getDim2()>1) + && (left.getDim1()==1 || right.getDim1()>1)) { Hop dleftIn = dleft.getInput().get(dleft.getParamIndex(DataExpression.RAND_MIN)); HopRewriteUtils.replaceChildReference(bop, dleft, dleftIn, 0); HopRewriteUtils.cleanupUnreferenced(dleft); - + LOG.debug("Applied removeUnnecessaryVectorizeOperation2"); } } //Note: we applied this rewrite to at most one side in order to keep the //output semantically equivalent. However, future extensions might consider - //to remove vectors from both side, compute the binary op on scalars and + //to remove vectors from both side, compute the binary op on scalars and //finally feed it into a datagenop of the original dimensions. } } - + return hi; } - - + + /** * handle removal of unnecessary binary operations - * + * * X/1 or X*1 or 1*X or X-0 -> X * -1*X or X*-1-> -X - * + * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator */ - private static Hop removeUnnecessaryBinaryOperation( Hop parent, Hop hi, int pos ) + private static Hop removeUnnecessaryBinaryOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); - //X/1 or X*1 -> X - if( left.getDataType()==DataType.MATRIX - && right instanceof LiteralOp && right.getValueType().isNumeric() - && ((LiteralOp)right).getDoubleValue()==1.0 ) + //X/1 or X*1 -> X + if( left.getDataType()==DataType.MATRIX + && right instanceof LiteralOp && right.getValueType().isNumeric() + && ((LiteralOp)right).getDoubleValue()==1.0 ) { if( bop.getOp()==OpOp2.DIV || bop.getOp()==OpOp2.MULT ) { @@ -291,8 +292,8 @@ private static Hop removeUnnecessaryBinaryOperation( Hop parent, Hop hi, int pos LOG.debug("Applied removeUnnecessaryBinaryOperation1 (line "+bop.getBeginLine()+")"); } } - //X-0 -> X - else if( left.getDataType()==DataType.MATRIX + //X-0 -> X + else if( left.getDataType()==DataType.MATRIX && right instanceof LiteralOp && right.getValueType().isNumeric() && ((LiteralOp)right).getDoubleValue()==0.0 ) { @@ -305,7 +306,7 @@ else if( left.getDataType()==DataType.MATRIX } } //1*X -> X - else if( right.getDataType()==DataType.MATRIX + else if( right.getDataType()==DataType.MATRIX && left instanceof LiteralOp && left.getValueType().isNumeric() && ((LiteralOp)left).getDoubleValue()==1.0 ) { @@ -318,9 +319,9 @@ else if( right.getDataType()==DataType.MATRIX } } //-1*X -> -X - //note: this rewrite is necessary since the new antlr parser always converts + //note: this rewrite is necessary since the new antlr parser always converts //-X to -1*X due to mechanical reasons - else if( right.getDataType()==DataType.MATRIX + else if( right.getDataType()==DataType.MATRIX && left instanceof LiteralOp && left.getValueType().isNumeric() && ((LiteralOp)left).getDoubleValue()==-1.0 ) { @@ -334,7 +335,7 @@ else if( right.getDataType()==DataType.MATRIX } } //X*-1 -> -X (see comment above) - else if( left.getDataType()==DataType.MATRIX + else if( left.getDataType()==DataType.MATRIX && right instanceof LiteralOp && right.getValueType().isNumeric() && ((LiteralOp)right).getDoubleValue()==-1.0 ) { @@ -344,85 +345,130 @@ else if( left.getDataType()==DataType.MATRIX HopRewriteUtils.removeChildReferenceByPos(bop, right, 1); HopRewriteUtils.addChildReference(bop, new LiteralOp(0), 0); hi = bop; - + LOG.debug("Applied removeUnnecessaryBinaryOperation5 (line "+bop.getBeginLine()+")"); } } } - + return hi; } - + + public static Hop simplifyBooleanRewrite(Hop parent, Hop hi, int pos) { + if (hi instanceof BinaryOp) { + BinaryOp bop = (BinaryOp) hi; + Hop left = hi.getInput().get(0); + Hop right = hi.getInput().get(1); + + // Pattern: a & !a --> FALSE + if (bop.getOp() == OpOp2.AND + && HopRewriteUtils.isUnary(right, OpOp1.NOT) + && left == right.getInput().get(0)) { + + LiteralOp falseOp = new LiteralOp(false); + + // Ensure parent has the input before attempting replacement + if (parent != null && parent.getInput().size() > pos) { + HopRewriteUtils.replaceChildReference(parent, hi, falseOp, pos); + HopRewriteUtils.cleanupUnreferenced(hi, left, right); + hi = falseOp; + } + + LOG.debug("Applied simplifyBooleanRewrite1 (line " + hi.getBeginLine() + ")."); + } + // Pattern: a | !a --> TRUE + else if (bop.getOp() == OpOp2.OR + && HopRewriteUtils.isUnary(right, OpOp1.NOT) + && left == right.getInput().get(0)) { + + LiteralOp trueOp = new LiteralOp(true); + + // Ensure parent has the input before attempting replacement + if (parent != null && parent.getInput().size() > pos) { + HopRewriteUtils.replaceChildReference(parent, hi, trueOp, pos); + HopRewriteUtils.cleanupUnreferenced(hi, left, right); + hi = trueOp; + } + + LOG.debug("Applied simplifyBooleanRewrite2 (line " + hi.getBeginLine() + ")."); + } + } + + return hi; + } + + + /** * Handle removal of unnecessary binary operations over rand data - * + * * rand*7 -> rand(min*7,max*7); rand+7 -> rand(min+7,max+7); rand-7 -> rand(min+(-7),max+(-7)) * 7*rand -> rand(min*7,max*7); 7+rand -> rand(min+7,max+7); - * + * * @param hi high-order operation * @return high-level operator */ @SuppressWarnings("incomplete-switch") - private static Hop fuseDatagenAndBinaryOperation( Hop hi ) + private static Hop fuseDatagenAndBinaryOperation( Hop hi ) { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); - + //NOTE: rewrite not applied if more than one datagen consumer because this would lead to //the creation of multiple datagen ops and thus potentially different results if seed not specified) - + //left input rand and hence output matrix double, right scalar literal if( HopRewriteUtils.isDataGenOp(left, OpOpDG.RAND) && - right instanceof LiteralOp && left.getParent().size()==1 ) + right instanceof LiteralOp && left.getParent().size()==1 ) { DataGenOp inputGen = (DataGenOp)left; Hop pdf = inputGen.getInput(DataExpression.RAND_PDF); Hop min = inputGen.getInput(DataExpression.RAND_MIN); Hop max = inputGen.getInput(DataExpression.RAND_MAX); double sval = ((LiteralOp)right).getDoubleValue(); - boolean pdfUniform = pdf instanceof LiteralOp - && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); - + boolean pdfUniform = pdf instanceof LiteralOp + && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); + if( HopRewriteUtils.isBinary(bop, OpOp2.MULT, OpOp2.PLUS, OpOp2.MINUS, OpOp2.DIV) - && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform ) + && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform ) { //create fused data gen operator DataGenOp gen = null; switch( bop.getOp() ) { //fuse via scale and shift case MULT: gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0); break; case PLUS: - case MINUS: gen = HopRewriteUtils.copyDataGenOp(inputGen, - 1, sval * ((bop.getOp()==OpOp2.MINUS)?-1:1)); break; + case MINUS: gen = HopRewriteUtils.copyDataGenOp(inputGen, + 1, sval * ((bop.getOp()==OpOp2.MINUS)?-1:1)); break; case DIV: gen = HopRewriteUtils.copyDataGenOp(inputGen, 1/sval, 0); break; } - + //rewire all parents (avoid anomalies with replicated datagen) List parents = new ArrayList<>(bop.getParent()); for( Hop p : parents ) HopRewriteUtils.replaceChildReference(p, bop, gen); - + hi = gen; LOG.debug("Applied fuseDatagenAndBinaryOperation1 " - + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); + + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); } } //right input rand and hence output matrix double, left scalar literal else if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==OpOpDG.RAND && - left instanceof LiteralOp && right.getParent().size()==1 ) + left instanceof LiteralOp && right.getParent().size()==1 ) { DataGenOp inputGen = (DataGenOp)right; Hop pdf = inputGen.getInput(DataExpression.RAND_PDF); Hop min = inputGen.getInput(DataExpression.RAND_MIN); Hop max = inputGen.getInput(DataExpression.RAND_MAX); double sval = ((LiteralOp)left).getDoubleValue(); - boolean pdfUniform = pdf instanceof LiteralOp - && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); - + boolean pdfUniform = pdf instanceof LiteralOp + && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); + if( (bop.getOp()==OpOp2.MULT || bop.getOp()==OpOp2.PLUS) - && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform ) + && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform ) { //create fused data gen operator DataGenOp gen = null; @@ -431,32 +477,32 @@ else if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==OpOpDG.RAND & else { //OpOp2.PLUS gen = HopRewriteUtils.copyDataGenOp(inputGen, 1, sval); } - + //rewire all parents (avoid anomalies with replicated datagen) List parents = new ArrayList<>(bop.getParent()); for( Hop p : parents ) HopRewriteUtils.replaceChildReference(p, bop, gen); - + hi = gen; LOG.debug("Applied fuseDatagenAndBinaryOperation2 " - + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); + + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); } } //left input rand and hence output matrix double, right scalar variable - else if( HopRewriteUtils.isDataGenOp(left, OpOpDG.RAND) - && right.getDataType().isScalar() && left.getParent().size()==1 ) + else if( HopRewriteUtils.isDataGenOp(left, OpOpDG.RAND) + && right.getDataType().isScalar() && left.getParent().size()==1 ) { DataGenOp gen = (DataGenOp)left; Hop min = gen.getInput(DataExpression.RAND_MIN); Hop max = gen.getInput(DataExpression.RAND_MAX); Hop pdf = gen.getInput(DataExpression.RAND_PDF); - boolean pdfUniform = pdf instanceof LiteralOp - && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); - - + boolean pdfUniform = pdf instanceof LiteralOp + && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); + + if( HopRewriteUtils.isBinary(bop, OpOp2.PLUS) - && HopRewriteUtils.isLiteralOfValue(min, 0) - && HopRewriteUtils.isLiteralOfValue(max, 0) ) + && HopRewriteUtils.isLiteralOfValue(min, 0) + && HopRewriteUtils.isLiteralOfValue(max, 0) ) { gen.setInput(DataExpression.RAND_MIN, right, true); gen.setInput(DataExpression.RAND_MAX, right, true); @@ -466,12 +512,12 @@ else if( HopRewriteUtils.isDataGenOp(left, OpOpDG.RAND) HopRewriteUtils.replaceChildReference(p, bop, gen); hi = gen; LOG.debug("Applied fuseDatagenAndBinaryOperation3a " - + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); + + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); } else if( HopRewriteUtils.isBinary(bop, OpOp2.MULT) - && ((HopRewriteUtils.isLiteralOfValue(min, 0) && pdfUniform) + && ((HopRewriteUtils.isLiteralOfValue(min, 0) && pdfUniform) || HopRewriteUtils.isLiteralOfValue(min, 1)) - && HopRewriteUtils.isLiteralOfValue(max, 1) ) + && HopRewriteUtils.isLiteralOfValue(max, 1) ) { if( HopRewriteUtils.isLiteralOfValue(min, 1) ) gen.setInput(DataExpression.RAND_MIN, right, true); @@ -482,24 +528,24 @@ else if( HopRewriteUtils.isBinary(bop, OpOp2.MULT) HopRewriteUtils.replaceChildReference(p, bop, gen); hi = gen; LOG.debug("Applied fuseDatagenAndBinaryOperation3b " - + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); + + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); } } } - + return hi; } - - private static Hop fuseDatagenAndMinusOperation( Hop hi ) + + private static Hop fuseDatagenAndMinusOperation( Hop hi ) { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); - + if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==OpOpDG.RAND && - left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==0.0 ) + left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==0.0 ) { DataGenOp inputGen = (DataGenOp)right; HashMap params = inputGen.getParamIndexMap(); @@ -508,55 +554,55 @@ private static Hop fuseDatagenAndMinusOperation( Hop hi ) int ixMax = params.get(DataExpression.RAND_MAX); Hop min = right.getInput().get(ixMin); Hop max = right.getInput().get(ixMax); - + //apply rewrite under additional conditions (for simplicity) - if( inputGen.getParent().size()==1 - && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp - && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) ) + if( inputGen.getParent().size()==1 + && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp + && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) ) { //exchange and *-1 (special case 0 stays 0 instead of -0 for consistency) double newMinVal = (((LiteralOp)max).getDoubleValue()==0)?0:(-1 * ((LiteralOp)max).getDoubleValue()); double newMaxVal = (((LiteralOp)min).getDoubleValue()==0)?0:(-1 * ((LiteralOp)min).getDoubleValue()); Hop newMin = new LiteralOp(newMinVal); Hop newMax = new LiteralOp(newMaxVal); - + HopRewriteUtils.removeChildReferenceByPos(inputGen, min, ixMin); HopRewriteUtils.addChildReference(inputGen, newMin, ixMin); HopRewriteUtils.removeChildReferenceByPos(inputGen, max, ixMax); HopRewriteUtils.addChildReference(inputGen, newMax, ixMax); - + //rewire all parents (avoid anomalies with replicated datagen) List parents = new ArrayList<>(bop.getParent()); for( Hop p : parents ) HopRewriteUtils.replaceChildReference(p, bop, inputGen); - + hi = inputGen; LOG.debug("Applied fuseDatagenAndMinusOperation (line "+bop.getBeginLine()+")."); } } } - + return hi; } - - private static Hop foldMultipleAppendOperations(Hop hi) + + private static Hop foldMultipleAppendOperations(Hop hi) { if( hi.getDataType().isMatrix() //no string appends or frames - && (HopRewriteUtils.isBinary(hi, OpOp2.CBIND, OpOp2.RBIND) - || HopRewriteUtils.isNary(hi, OpOpN.CBIND, OpOpN.RBIND)) ) + && (HopRewriteUtils.isBinary(hi, OpOp2.CBIND, OpOp2.RBIND) + || HopRewriteUtils.isNary(hi, OpOpN.CBIND, OpOpN.RBIND)) ) { OpOp2 bop = (hi instanceof BinaryOp) ? ((BinaryOp)hi).getOp() : - OpOp2.valueOf(((NaryOp)hi).getOp().name()); + OpOp2.valueOf(((NaryOp)hi).getOp().name()); OpOpN nop = (hi instanceof NaryOp) ? ((NaryOp)hi).getOp() : - OpOpN.valueOf(((BinaryOp)hi).getOp().name()); - + OpOpN.valueOf(((BinaryOp)hi).getOp().name()); + boolean converged = false; while( !converged ) { //get first matching cbind or rbind Hop first = hi.getInput().stream() - .filter(h -> HopRewriteUtils.isBinary(h, bop) || HopRewriteUtils.isNary(h, nop)) - .findFirst().orElse(null); - + .filter(h -> HopRewriteUtils.isBinary(h, bop) || HopRewriteUtils.isNary(h, nop)) + .findFirst().orElse(null); + //replace current op with new nary cbind/rbind if( first != null && first.getParent().size()==1 ) { //construct new list of inputs (in original order) @@ -582,29 +628,29 @@ private static Hop foldMultipleAppendOperations(Hop hi) } } } - + return hi; } - + /** * Handle simplification of binary operations (relies on previous common subexpression elimination). * At the same time this servers as a canonicalization for more complex rewrites. - * + * * X+X -> X*2, X*X -> X^2, (X>0)-(X<0) -> sign(X) - * + * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator */ - private static Hop simplifyBinaryToUnaryOperation( Hop parent, Hop hi, int pos ) + private static Hop simplifyBinaryToUnaryOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); - + //patterns: X+X -> X*2, X*X -> X^2, if( left == right && left.getDataType()==DataType.MATRIX ) { @@ -614,48 +660,48 @@ private static Hop simplifyBinaryToUnaryOperation( Hop parent, Hop hi, int pos ) { bop.setOp(OpOp2.MULT); HopRewriteUtils.replaceChildReference(hi, right, new LiteralOp(2), 1); - + LOG.debug("Applied simplifyBinaryToUnaryOperation1 (line "+hi.getBeginLine()+")."); } else if ( bop.getOp()==OpOp2.MULT ) //X*X -> X^2 { bop.setOp(OpOp2.POW); HopRewriteUtils.replaceChildReference(hi, right, new LiteralOp(2), 1); - + LOG.debug("Applied simplifyBinaryToUnaryOperation2 (line "+hi.getBeginLine()+")."); } } //patterns: (X>0)-(X<0) -> sign(X) - else if( bop.getOp() == OpOp2.MINUS - && HopRewriteUtils.isBinary(left, OpOp2.GREATER) - && HopRewriteUtils.isBinary(right, OpOp2.LESS) - && left.getInput().get(0) == right.getInput().get(0) - && left.getInput().get(1) instanceof LiteralOp - && HopRewriteUtils.getDoubleValue((LiteralOp)left.getInput().get(1))==0 - && right.getInput().get(1) instanceof LiteralOp - && HopRewriteUtils.getDoubleValue((LiteralOp)right.getInput().get(1))==0 ) + else if( bop.getOp() == OpOp2.MINUS + && HopRewriteUtils.isBinary(left, OpOp2.GREATER) + && HopRewriteUtils.isBinary(right, OpOp2.LESS) + && left.getInput().get(0) == right.getInput().get(0) + && left.getInput().get(1) instanceof LiteralOp + && HopRewriteUtils.getDoubleValue((LiteralOp)left.getInput().get(1))==0 + && right.getInput().get(1) instanceof LiteralOp + && HopRewriteUtils.getDoubleValue((LiteralOp)right.getInput().get(1))==0 ) { UnaryOp uop = HopRewriteUtils.createUnary(left.getInput().get(0), OpOp1.SIGN); HopRewriteUtils.replaceChildReference(parent, hi, uop, pos); HopRewriteUtils.cleanupUnreferenced(hi, left, right); hi = uop; - + LOG.debug("Applied simplifyBinaryToUnaryOperation3 (line "+hi.getBeginLine()+")."); } } - + return hi; } - + /** * Rewrite to canonicalize all patterns like U%*%V+eps, eps+U%*%V, and * U%*%V-eps into the common representation U%*%V+s which simplifies * subsequent rewrites (e.g., wdivmm or wcemm with epsilon). - * + * * @param hi high-level operator * @return high-level operator */ - private static Hop canonicalizeMatrixMultScalarAdd( Hop hi ) + private static Hop canonicalizeMatrixMultScalarAdd( Hop hi ) { //pattern: binary operation (+ or -) of matrix mult and scalar if( hi instanceof BinaryOp ) @@ -663,10 +709,10 @@ private static Hop canonicalizeMatrixMultScalarAdd( Hop hi ) BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); - + //pattern: (eps + U%*%V) -> (U%*%V+eps) if( left.getDataType().isScalar() && right instanceof AggBinaryOp - && bop.getOp()==OpOp2.PLUS ) + && bop.getOp()==OpOp2.PLUS ) { HopRewriteUtils.removeAllChildReferences(bop); HopRewriteUtils.addChildReference(bop, right, 0); @@ -683,11 +729,11 @@ else if( right.getDataType().isScalar() && left instanceof AggBinaryOp LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line "+hi.getBeginLine()+")."); } } - + return hi; } - - private static Hop simplifyCTableWithConstMatrixInputs( Hop hi ) + + private static Hop simplifyCTableWithConstMatrixInputs( Hop hi ) { //pattern: table(X, matrix(1,...), matrix(7, ...)) -> table(X, 1, 7) if( HopRewriteUtils.isTernary(hi, OpOp3.CTABLE) ) { @@ -698,7 +744,7 @@ private static Hop simplifyCTableWithConstMatrixInputs( Hop hi ) Hop inNew = ((DataGenOp)inCurr).getInput(DataExpression.RAND_MIN); HopRewriteUtils.replaceChildReference(hi, inCurr, inNew, i); LOG.debug("Applied simplifyCTableWithConstMatrixInputs" - + i + " (line "+hi.getBeginLine()+")."); + + i + " (line "+hi.getBeginLine()+")."); } } } @@ -706,9 +752,9 @@ private static Hop simplifyCTableWithConstMatrixInputs( Hop hi ) } private static Hop removeUnnecessaryCTable( Hop parent, Hop hi, int pos ) { - if ( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol) - && HopRewriteUtils.isTernary(hi.getInput().get(0), OpOp3.CTABLE) - && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0).getInput().get(2), 1.0)) + if ( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol) + && HopRewriteUtils.isTernary(hi.getInput().get(0), OpOp3.CTABLE) + && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0).getInput().get(2), 1.0)) { Hop matrixInput = hi.getInput().get(0).getInput().get(0); OpOp1 opcode = matrixInput.getDim2() == 1 ? OpOp1.NROW : OpOp1.LENGTH; @@ -724,67 +770,67 @@ private static Hop removeUnnecessaryCTable( Hop parent, Hop hi, int pos ) { * NOTE: this would be by definition a dynamic rewrite; however, we apply it as a static * rewrite in order to apply it before splitting dags which would hide the table information * if dimensions are not specified. - * + * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator */ - private static Hop simplifyReverseOperation( Hop parent, Hop hi, int pos ) + private static Hop simplifyReverseOperation( Hop parent, Hop hi, int pos ) { - if( hi instanceof AggBinaryOp - && hi.getInput().get(0) instanceof TernaryOp ) + if( hi instanceof AggBinaryOp + && hi.getInput().get(0) instanceof TernaryOp ) { TernaryOp top = (TernaryOp) hi.getInput().get(0); - + if( top.getOp()==OpOp3.CTABLE - && HopRewriteUtils.isBasic1NSequence(top.getInput().get(0)) - && HopRewriteUtils.isBasicN1Sequence(top.getInput().get(1)) - && top.getInput().get(0).getDim1()==top.getInput().get(1).getDim1()) + && HopRewriteUtils.isBasic1NSequence(top.getInput().get(0)) + && HopRewriteUtils.isBasicN1Sequence(top.getInput().get(1)) + && top.getInput().get(0).getDim1()==top.getInput().get(1).getDim1()) { ReorgOp rop = HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV); HopRewriteUtils.replaceChildReference(parent, hi, rop, pos); HopRewriteUtils.cleanupUnreferenced(hi, top); hi = rop; - + LOG.debug("Applied simplifyReverseOperation."); } } - + return hi; } - + private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi ) { //pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate) if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS) - && hi.getDataType() == DataType.MATRIX - && hi.getInput().get(0) instanceof LiteralOp - && HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput().get(0))==1 - && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) - && hi.getInput().get(1).getParent().size() == 1 ) //single consumer + && hi.getDataType() == DataType.MATRIX + && hi.getInput().get(0) instanceof LiteralOp + && HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput().get(0))==1 + && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) + && hi.getInput().get(1).getParent().size() == 1 ) //single consumer { BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(1).getInput().get(0); Hop right = hi.getInput().get(1).getInput().get(1); - + //set new binaryop type and rewire inputs bop.setOp(OpOp2.MINUS1_MULT); HopRewriteUtils.removeAllChildReferences(hi); HopRewriteUtils.addChildReference(bop, left); HopRewriteUtils.addChildReference(bop, right); - + LOG.debug("Applied simplifyMultiBinaryToBinaryOperation."); } - + return hi; } - + /** * (X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X * (X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X - * - * + * + * * @param parent parent high-level operator * @param hi high-level operator * @param pos position @@ -797,21 +843,21 @@ private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); - + //(X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X //(X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X boolean applied = false; - if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX - && HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY) ) + if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX + && HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY) ) { Hop X = null; Hop Y = null; if( HopRewriteUtils.isBinary(left, OpOp2.MULT) ) //(Y*X-X) -> (Y-1)*X { Hop leftC1 = left.getInput().get(0); Hop leftC2 = left.getInput().get(1); - + if( leftC1.getDataType()==DataType.MATRIX && leftC2.getDataType()==DataType.MATRIX && - (right == leftC1 || right == leftC2) && leftC1 !=leftC2 ){ //any mult order + (right == leftC1 || right == leftC2) && leftC1 !=leftC2 ){ //any mult order X = right; Y = ( right == leftC1 ) ? leftC2 : leftC1; } @@ -823,17 +869,17 @@ private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int HopRewriteUtils.cleanupUnreferenced(hi, left); hi = mult; applied = true; - + LOG.debug("Applied simplifyDistributiveBinaryOperation1 (line "+hi.getBeginLine()+")."); } } - + if( !applied && HopRewriteUtils.isBinary(right, OpOp2.MULT) ) //(X-Y*X) -> (1-Y)*X { Hop rightC1 = right.getInput().get(0); Hop rightC2 = right.getInput().get(1); if( rightC1.getDataType()==DataType.MATRIX && rightC2.getDataType()==DataType.MATRIX && - (left == rightC1 || left == rightC2) && rightC1 !=rightC2 ){ //any mult order + (left == rightC1 || left == rightC2) && rightC1 !=rightC2 ){ //any mult order X = left; Y = ( left == rightC1 ) ? rightC2 : rightC1; } @@ -847,21 +893,21 @@ private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int LOG.debug("Applied simplifyDistributiveBinaryOperation2 (line "+hi.getBeginLine()+")."); } - } + } } } - + return hi; } - + /** * t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%(X*Y)*(Z%*%v) * t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v)) - * + * * Note: Restriction ba() at leaf and root instead of data at leaf to not reorganize too * eagerly, which would loose additional rewrite potential. This rewrite has two goals * (1) enable XtwXv, and increase piggybacking potential by creating bushy trees. - * + * * @param parent parent high-level operator * @param hi high-level operator * @param pos position @@ -875,46 +921,46 @@ private static Hop simplifyBushyBinaryOperation( Hop parent, Hop hi, int pos ) Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); OpOp2 op = bop.getOp(); - + if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX && - HopRewriteUtils.isValidOp(op, LOOKUP_VALID_ASSOCIATIVE_BINARY) ) + HopRewriteUtils.isValidOp(op, LOOKUP_VALID_ASSOCIATIVE_BINARY) ) { boolean applied = false; - + if( right instanceof BinaryOp ) { BinaryOp bop2 = (BinaryOp)right; Hop left2 = bop2.getInput().get(0); Hop right2 = bop2.getInput().get(1); OpOp2 op2 = bop2.getOp(); - - if( op==op2 && right2.getDataType()==DataType.MATRIX - && (right2 instanceof AggBinaryOp) ) + + if( op==op2 && right2.getDataType()==DataType.MATRIX + && (right2 instanceof AggBinaryOp) ) { //(X*(Y*op()) -> (X*Y)*op() BinaryOp bop3 = HopRewriteUtils.createBinary(left, left2, op); BinaryOp bop4 = HopRewriteUtils.createBinary(bop3, right2, op); - HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos); + HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos); HopRewriteUtils.cleanupUnreferenced(bop, bop2); hi = bop4; - + applied = true; - + LOG.debug("Applied simplifyBushyBinaryOperation1"); } } - + if( !applied && left instanceof BinaryOp ) { BinaryOp bop2 = (BinaryOp)left; Hop left2 = bop2.getInput().get(0); Hop right2 = bop2.getInput().get(1); OpOp2 op2 = bop2.getOp(); - - if( op==op2 && left2.getDataType()==DataType.MATRIX - && (left2 instanceof AggBinaryOp) - && (right2.getDim2() > 1 || right.getDim2() == 1) //X not vector, or Y vector - && (right2.getDim1() > 1 || right.getDim1() == 1) ) //X not vector, or Y vector + + if( op==op2 && left2.getDataType()==DataType.MATRIX + && (left2 instanceof AggBinaryOp) + && (right2.getDim2() > 1 || right.getDim2() == 1) //X not vector, or Y vector + && (right2.getDim1() > 1 || right.getDim1() == 1) ) //X not vector, or Y vector { //((op()*X)*Y) -> op()*(X*Y) BinaryOp bop3 = HopRewriteUtils.createBinary(right2, right, op); @@ -922,39 +968,39 @@ private static Hop simplifyBushyBinaryOperation( Hop parent, Hop hi, int pos ) HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos); HopRewriteUtils.cleanupUnreferenced(bop, bop2); hi = bop4; - + LOG.debug("Applied simplifyBushyBinaryOperation2"); } } } - + } - + return hi; } - + private static Hop simplifyUnaryAggReorgOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full uagg - && hi.getInput().get(0) instanceof ReorgOp ) //reorg operation + && hi.getInput().get(0) instanceof ReorgOp ) //reorg operation { ReorgOp rop = (ReorgOp)hi.getInput().get(0); if( (rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE || rop.getOp() == ReOrgOp.REV ) //valid reorg - && rop.getParent().size()==1 ) //uagg only reorg consumer + && rop.getParent().size()==1 ) //uagg only reorg consumer { Hop input = rop.getInput().get(0); HopRewriteUtils.removeAllChildReferences(hi); HopRewriteUtils.removeAllChildReferences(rop); HopRewriteUtils.addChildReference(hi, input); - + LOG.debug("Applied simplifyUnaryAggReorgOperation"); } } - + return hi; } - + private static Hop removeUnnecessaryAggregates(Hop hi) { //sum(rowSums(X)) -> sum(X), sum(colSums(X)) -> sum(X) @@ -962,44 +1008,44 @@ private static Hop removeUnnecessaryAggregates(Hop hi) //max(rowMaxs(X)) -> max(X), max(colMaxs(X)) -> max(X) //sum(rowSums(X^2)) -> sum(X), sum(colSums(X^2)) -> sum(X) if( hi instanceof AggUnaryOp && hi.getInput().get(0) instanceof AggUnaryOp - && ((AggUnaryOp)hi).getDirection()==Direction.RowCol - && hi.getInput().get(0).getParent().size()==1 ) + && ((AggUnaryOp)hi).getDirection()==Direction.RowCol + && hi.getInput().get(0).getParent().size()==1 ) { AggUnaryOp au1 = (AggUnaryOp) hi; AggUnaryOp au2 = (AggUnaryOp) hi.getInput().get(0); - if( (au1.getOp()==AggOp.SUM && (au2.getOp()==AggOp.SUM || au2.getOp()==AggOp.SUM_SQ)) - || (au1.getOp()==AggOp.MIN && au2.getOp()==AggOp.MIN) - || (au1.getOp()==AggOp.MAX && au2.getOp()==AggOp.MAX) ) + if( (au1.getOp()==AggOp.SUM && (au2.getOp()==AggOp.SUM || au2.getOp()==AggOp.SUM_SQ)) + || (au1.getOp()==AggOp.MIN && au2.getOp()==AggOp.MIN) + || (au1.getOp()==AggOp.MAX && au2.getOp()==AggOp.MAX) ) { Hop input = au2.getInput().get(0); HopRewriteUtils.removeAllChildReferences(au2); HopRewriteUtils.replaceChildReference(au1, au2, input); if( au2.getOp() == AggOp.SUM_SQ ) au1.setOp(AggOp.SUM_SQ); - + LOG.debug("Applied removeUnnecessaryAggregates (line "+hi.getBeginLine()+")."); } } - + return hi; } - - private static Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos ) + + private static Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos ) { // Note: This rewrite is not applicable for all binary operations because some of them // are undefined over scalars. We explicitly exclude potential conflicting matrix-scalar binary // operations; other operations like cbind/rbind will never occur as matrix-scalar operations. - - if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR) - && hi.getInput().get(0) instanceof BinaryOp - && HopRewriteUtils.isBinary(hi.getInput().get(0), LOOKUP_VALID_SCALAR_BINARY)) + + if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR) + && hi.getInput().get(0) instanceof BinaryOp + && HopRewriteUtils.isBinary(hi.getInput().get(0), LOOKUP_VALID_SCALAR_BINARY)) { BinaryOp bin = (BinaryOp) hi.getInput().get(0); BinaryOp bout = null; - + //as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y) - if( bin.getInput().get(0).getDataType()==DataType.MATRIX - && bin.getInput().get(1).getDataType()==DataType.MATRIX ) { + if( bin.getInput().get(0).getDataType()==DataType.MATRIX + && bin.getInput().get(1).getDataType()==DataType.MATRIX ) { UnaryOp cast1 = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR); UnaryOp cast2 = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR); bout = HopRewriteUtils.createBinary(cast1, cast2, bin.getOp()); @@ -1014,86 +1060,86 @@ else if ( bin.getInput().get(1).getDataType()==DataType.MATRIX ) { UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR); bout = HopRewriteUtils.createBinary(bin.getInput().get(0), cast, bin.getOp()); } - + if( bout != null ) { HopRewriteUtils.replaceChildReference(parent, hi, bout, pos); - + LOG.debug("Applied simplifyBinaryMatrixScalarOperation."); } } - + return hi; } - + private static Hop pushdownUnaryAggTransposeOperation( Hop parent, Hop hi, int pos ) { - if( hi instanceof AggUnaryOp && hi.getParent().size()==1 - && (((AggUnaryOp) hi).getDirection()==Direction.Row || ((AggUnaryOp) hi).getDirection()==Direction.Col) - && HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1) - && HopRewriteUtils.isValidOp(((AggUnaryOp) hi).getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) + if( hi instanceof AggUnaryOp && hi.getParent().size()==1 + && (((AggUnaryOp) hi).getDirection()==Direction.Row || ((AggUnaryOp) hi).getDirection()==Direction.Col) + && HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1) + && HopRewriteUtils.isValidOp(((AggUnaryOp) hi).getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) { AggUnaryOp uagg = (AggUnaryOp) hi; - + //get input rewire existing operators (remove inner transpose) Hop input = uagg.getInput().get(0).getInput().get(0); HopRewriteUtils.removeAllChildReferences(hi.getInput().get(0)); HopRewriteUtils.removeAllChildReferences(hi); HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); - + //pattern 1: row-aggregate to col aggregate, e.g., rowSums(t(X))->t(colSums(X)) if( uagg.getDirection()==Direction.Row ) { - uagg.setDirection(Direction.Col); - LOG.debug("Applied pushdownUnaryAggTransposeOperation1 (line "+hi.getBeginLine()+")."); + uagg.setDirection(Direction.Col); + LOG.debug("Applied pushdownUnaryAggTransposeOperation1 (line "+hi.getBeginLine()+")."); } //pattern 2: col-aggregate to row aggregate, e.g., colSums(t(X))->t(rowSums(X)) else if( uagg.getDirection()==Direction.Col ) { - uagg.setDirection(Direction.Row); + uagg.setDirection(Direction.Row); LOG.debug("Applied pushdownUnaryAggTransposeOperation2 (line "+hi.getBeginLine()+")."); } - + //create outer transpose operation and rewire operators HopRewriteUtils.addChildReference(uagg, input); uagg.refreshSizeInformation(); Hop trans = HopRewriteUtils.createTranspose(uagg); //incl refresh size HopRewriteUtils.addChildReference(parent, trans, pos); //by def, same size - - hi = trans; + + hi = trans; } - + return hi; } - + private static Hop pushdownCSETransposeScalarOperation( Hop parent, Hop hi, int pos ) { // a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X) // probed at root node of b in above example // (with support for left or right scalar operations) - if( HopRewriteUtils.isTransposeOperation(hi, 1) - && HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0)) - && hi.getInput().get(0).getParent().size()==1) + if( HopRewriteUtils.isTransposeOperation(hi, 1) + && HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0)) + && hi.getInput().get(0).getParent().size()==1) { int Xpos = hi.getInput().get(0).getInput().get(0).getDataType().isMatrix() ? 0 : 1; Hop X = hi.getInput().get(0).getInput().get(Xpos); BinaryOp binary = (BinaryOp) hi.getInput().get(0); - - if( HopRewriteUtils.containsTransposeOperation(X.getParent()) - && !HopRewriteUtils.isValidOp(binary.getOp(), new OpOp2[]{OpOp2.MOMENT, OpOp2.QUANTILE})) + + if( HopRewriteUtils.containsTransposeOperation(X.getParent()) + && !HopRewriteUtils.isValidOp(binary.getOp(), new OpOp2[]{OpOp2.MOMENT, OpOp2.QUANTILE})) { //clear existing wiring - HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); + HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); HopRewriteUtils.removeChildReference(hi, binary); HopRewriteUtils.removeChildReference(binary, X); - + //re-wire operators HopRewriteUtils.addChildReference(parent, binary, pos); HopRewriteUtils.addChildReference(binary, hi, Xpos); HopRewriteUtils.addChildReference(hi, X); //note: common subexpression later eliminated by dedicated rewrite - + hi = binary; LOG.debug("Applied pushdownCSETransposeScalarOperation (line "+hi.getBeginLine()+")."); - } + } } - + return hi; } @@ -1103,20 +1149,20 @@ private static Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) { && ((AggUnaryOp)hi).getOp()==AggOp.SUM // only one parent which is the sum && HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.MULT, 1) && ((hi.getInput().get(0).getInput().get(0).getDataType()==DataType.SCALAR && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.MATRIX) - ||(hi.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.SCALAR))) + ||(hi.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.SCALAR))) { - Hop operand1 = hi.getInput().get(0).getInput().get(0); + Hop operand1 = hi.getInput().get(0).getInput().get(0); Hop operand2 = hi.getInput().get(0).getInput().get(1); //check which operand is the Scalar and which is the matrix - Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? operand1 : operand2; - Hop matrix = (operand1.getDataType()==DataType.MATRIX) ? operand1 : operand2; + Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? operand1 : operand2; + Hop matrix = (operand1.getDataType()==DataType.MATRIX) ? operand1 : operand2; AggUnaryOp aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol); Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, OpOp2.MULT); - + HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); - + LOG.debug("Applied pushdownSumBinaryMult."); return bop; } @@ -1125,89 +1171,89 @@ private static Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) { private static Hop pullupAbs(Hop parent, Hop hi, int pos ) { if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) - && HopRewriteUtils.isUnary(hi.getInput(0), OpOp1.ABS) - && hi.getInput(0).getParent().size()==1 - && HopRewriteUtils.isUnary(hi.getInput(1), OpOp1.ABS) - && hi.getInput(1).getParent().size()==1) + && HopRewriteUtils.isUnary(hi.getInput(0), OpOp1.ABS) + && hi.getInput(0).getParent().size()==1 + && HopRewriteUtils.isUnary(hi.getInput(1), OpOp1.ABS) + && hi.getInput(1).getParent().size()==1) { Hop operand1 = hi.getInput(0).getInput(0); Hop operand2 = hi.getInput(1).getInput(0); Hop bop = HopRewriteUtils.createBinary(operand1, operand2, OpOp2.MULT); Hop uop = HopRewriteUtils.createUnary(bop, OpOp1.ABS); HopRewriteUtils.replaceChildReference(parent, hi, uop, pos); - + LOG.debug("Applied pullupAbs (line "+hi.getBeginLine()+")."); return uop; } return hi; } - + private static Hop simplifyUnaryPPredOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof UnaryOp && hi.getDataType()==DataType.MATRIX //unaryop - && hi.getInput().get(0) instanceof BinaryOp //binaryop - ppred - && ((BinaryOp)hi.getInput().get(0)).isPPredOperation() ) + && hi.getInput().get(0) instanceof BinaryOp //binaryop - ppred + && ((BinaryOp)hi.getInput().get(0)).isPPredOperation() ) { UnaryOp uop = (UnaryOp) hi; //valid unary op if( uop.getOp()==OpOp1.ABS || uop.getOp()==OpOp1.SIGN - || uop.getOp()==OpOp1.CEIL || uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND ) + || uop.getOp()==OpOp1.CEIL || uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND ) { //clear link unary-binary Hop input = uop.getInput().get(0); HopRewriteUtils.replaceChildReference(parent, hi, input, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = input; - - LOG.debug("Applied simplifyUnaryPPredOperation."); + + LOG.debug("Applied simplifyUnaryPPredOperation."); } } - + return hi; } - + private static Hop simplifyTransposedAppend( Hop parent, Hop hi, int pos ) { //e.g., t(cbind(t(A),t(B))) --> rbind(A,B), t(rbind(t(A),t(B))) --> cbind(A,B) if( HopRewriteUtils.isTransposeOperation(hi) //t() rooted - && hi.getInput().get(0) instanceof BinaryOp - && (((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.CBIND //append (cbind/rbind) - || ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.RBIND) - && hi.getInput().get(0).getParent().size() == 1 ) //single consumer of append + && hi.getInput().get(0) instanceof BinaryOp + && (((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.CBIND //append (cbind/rbind) + || ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.RBIND) + && hi.getInput().get(0).getParent().size() == 1 ) //single consumer of append { BinaryOp bop = (BinaryOp)hi.getInput().get(0); //both inputs transpose ops, where transpose is single consumer - if( HopRewriteUtils.isTransposeOperation(bop.getInput().get(0), 1) - && HopRewriteUtils.isTransposeOperation(bop.getInput().get(1), 1) ) + if( HopRewriteUtils.isTransposeOperation(bop.getInput().get(0), 1) + && HopRewriteUtils.isTransposeOperation(bop.getInput().get(1), 1) ) { Hop left = bop.getInput().get(0).getInput().get(0); Hop right = bop.getInput().get(1).getInput().get(0); - + //create new subdag (no in-place dag update to prevent anomalies with //multiple consumers during rewrite process) OpOp2 binop = (bop.getOp()==OpOp2.CBIND) ? OpOp2.RBIND : OpOp2.CBIND; BinaryOp bopnew = HopRewriteUtils.createBinary(left, right, binop); HopRewriteUtils.replaceChildReference(parent, hi, bopnew, pos); - + hi = bopnew; LOG.debug("Applied simplifyTransposedAppend (line "+hi.getBeginLine()+")."); } } - + return hi; } - + /** * handle simplification of more complex sub DAG to unary operation. - * + * * X*(1-X) -> sprop(X) * (1-X)*X -> sprop(X) * 1/(1+exp(-X)) -> sigmoid(X) - * + * * @param parent parent high-level operator * @param hi high-level operator * @param pos position */ - private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos ) + private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof BinaryOp ) { @@ -1215,7 +1261,7 @@ private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); boolean applied = false; - + //sample proportion (sprop) operator if( bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) { @@ -1223,97 +1269,97 @@ private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos //note: if there are multiple consumers on the intermediate, //we follow the heuristic that redundant computation is more beneficial, //i.e., we still fuse but leave the intermediate for the other consumers - + if( left instanceof BinaryOp ) //(1-X)*X { BinaryOp bleft = (BinaryOp)left; Hop left1 = bleft.getInput().get(0); - Hop left2 = bleft.getInput().get(1); - + Hop left2 = bleft.getInput().get(1); + if( left1 instanceof LiteralOp && - HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 && - left2 == right && bleft.getOp() == OpOp2.MINUS ) + HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 && + left2 == right && bleft.getOp() == OpOp2.MINUS ) { UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; applied = true; - + LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1"); } - } + } if( !applied && right instanceof BinaryOp ) //X*(1-X) { BinaryOp bright = (BinaryOp)right; Hop right1 = bright.getInput().get(0); - Hop right2 = bright.getInput().get(1); - + Hop right2 = bright.getInput().get(1); + if( right1 instanceof LiteralOp && - HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 && - right2 == left && bright.getOp() == OpOp2.MINUS ) + HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 && + right2 == left && bright.getOp() == OpOp2.MINUS ) { UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; applied = true; - + LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2"); } } } - + //sigmoid operator if( !applied && bop.getOp() == OpOp2.DIV && left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX - && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp) + && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp) { //note: if there are multiple consumers on the intermediate, //we follow the heuristic that redundant computation is more beneficial, //i.e., we still fuse but leave the intermediate for the other consumers - + BinaryOp bop2 = (BinaryOp)right; Hop left2 = bop2.getInput().get(0); Hop right2 = bop2.getInput().get(1); - + if( bop2.getOp() == OpOp2.PLUS && left2.getDataType()==DataType.SCALAR && right2.getDataType()==DataType.MATRIX - && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof UnaryOp) + && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof UnaryOp) { UnaryOp uop = (UnaryOp) right2; Hop uopin = uop.getInput().get(0); - - if( uop.getOp()==OpOp1.EXP ) + + if( uop.getOp()==OpOp1.EXP ) { UnaryOp unary = null; - + //Pattern 1: (1/(1 + exp(-X)) if( HopRewriteUtils.isBinary(uopin, OpOp2.MINUS) ) { BinaryOp bop3 = (BinaryOp) uopin; Hop left3 = bop3.getInput().get(0); Hop right3 = bop3.getInput().get(1); - + if( left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 ) unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID); - } + } //Pattern 2: (1/(1 + exp(X)), e.g., where -(-X) has been removed by //the 'remove unnecessary minus' rewrite --> reintroduce the minus else { BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin); unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID); - } - + } + if( unary != null ) { HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop); hi = unary; applied = true; - + LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1"); - } + } } - } + } } - + //select positive (selp) operator (note: same initial pattern as sprop) if( !applied && bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) { @@ -1325,17 +1371,17 @@ private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos BinaryOp bleft = (BinaryOp)left; Hop left1 = bleft.getInput().get(0); Hop left2 = bleft.getInput().get(1); - + if( left2 instanceof LiteralOp && - HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 && - left1 == right && (bleft.getOp() == OpOp2.GREATER ) ) + HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 && + left1 == right && (bleft.getOp() == OpOp2.GREATER ) ) { BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX); HopRewriteUtils.replaceChildReference(parent, bop, binary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); hi = binary; applied = true; - + LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a"); } } @@ -1344,23 +1390,23 @@ private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos BinaryOp bright = (BinaryOp)right; Hop right1 = bright.getInput().get(0); Hop right2 = bright.getInput().get(1); - + if( right2 instanceof LiteralOp && - HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 && - right1 == left && bright.getOp() == OpOp2.GREATER ) + HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 && + right1 == left && bright.getOp() == OpOp2.GREATER ) { BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX); HopRewriteUtils.replaceChildReference(parent, bop, binary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); hi = binary; applied= true; - + LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b"); } } } } - + return hi; } @@ -1373,68 +1419,68 @@ private static Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos) { Hop left = hi2.getInput().get(0); Hop right = hi2.getInput().get(1); - + //create new operators (incl refresh size inside for transpose) ReorgOp trans = HopRewriteUtils.createTranspose(right); BinaryOp mult = HopRewriteUtils.createBinary(left, trans, OpOp2.MULT); AggUnaryOp sum = HopRewriteUtils.createSum(mult); - + //rehang new subdag under parent node HopRewriteUtils.replaceChildReference(parent, hi, sum, pos); HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = sum; - + LOG.debug("Applied simplifyTraceMatrixMult"); - } + } } - + return hi; } - - private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) + + private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) { //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1] - if( hi instanceof IndexingOp - && ((IndexingOp)hi).isRowLowerEqualsUpper() - && ((IndexingOp)hi).isColLowerEqualsUpper() - && hi.getInput().get(0).getParent().size()==1 //rix is single mm consumer - && HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) ) + if( hi instanceof IndexingOp + && ((IndexingOp)hi).isRowLowerEqualsUpper() + && ((IndexingOp)hi).isColLowerEqualsUpper() + && hi.getInput().get(0).getParent().size()==1 //rix is single mm consumer + && HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) ) { Hop mm = hi.getInput().get(0); Hop X = mm.getInput().get(0); Hop Y = mm.getInput().get(1); Hop rowExpr = hi.getInput().get(1); //rl==ru Hop colExpr = hi.getInput().get(3); //cl==cu - + HopRewriteUtils.removeAllChildReferences(mm); - + //create new indexing operations - IndexingOp ix1 = new IndexingOp("tmp1", DataType.MATRIX, ValueType.FP64, X, + IndexingOp ix1 = new IndexingOp("tmp1", DataType.MATRIX, ValueType.FP64, X, rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(X, false), true, false); ix1.setBlocksize(X.getBlocksize()); ix1.refreshSizeInformation(); - IndexingOp ix2 = new IndexingOp("tmp2", DataType.MATRIX, ValueType.FP64, Y, + IndexingOp ix2 = new IndexingOp("tmp2", DataType.MATRIX, ValueType.FP64, Y, new LiteralOp(1), HopRewriteUtils.createValueHop(Y, true), colExpr, colExpr, false, true); ix2.setBlocksize(Y.getBlocksize()); ix2.refreshSizeInformation(); - + //rewire matrix mult over ix1 and ix2 HopRewriteUtils.addChildReference(mm, ix1, 0); HopRewriteUtils.addChildReference(mm, ix2, 1); mm.refreshSizeInformation(); - + hi = mm; - + LOG.debug("Applied simplifySlicedMatrixMult"); } - + return hi; } - + private static Hop simplifyListIndexing(Hop hi) { //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1] if( hi instanceof IndexingOp && hi.getDataType().isList() - && !(hi.getInput(4) instanceof LiteralOp) ) + && !(hi.getInput(4) instanceof LiteralOp) ) { HopRewriteUtils.replaceChildReference(hi, hi.getInput(4), new LiteralOp(1)); LOG.debug("Applied simplifyListIndexing (line "+hi.getBeginLine()+")."); @@ -1442,17 +1488,17 @@ private static Hop simplifyListIndexing(Hop hi) { return hi; } - private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos) + private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos) { //order(matrix(7), indexreturn=FALSE) -> matrix(7) //order(matrix(7), indexreturn=TRUE) -> seq(1,nrow(X),1) if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.SORT ) //order { Hop hi2 = hi.getInput().get(0); - + if( hi2 instanceof DataGenOp && ((DataGenOp)hi2).getOp()==OpOpDG.RAND - && ((DataGenOp)hi2).hasConstantValue() - && hi.getInput().get(3) instanceof LiteralOp ) //known indexreturn + && ((DataGenOp)hi2).hasConstantValue() + && hi.getInput().get(3) instanceof LiteralOp ) //known indexreturn { if( HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) ) { @@ -1462,7 +1508,7 @@ private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos) HopRewriteUtils.replaceChildReference(parent, hi, seq, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = seq; - + LOG.debug("Applied simplifyConstantSort1."); } else @@ -1471,30 +1517,30 @@ private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos) HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = hi2; - + LOG.debug("Applied simplifyConstantSort2."); } - } + } } - + return hi; } - - private static Hop simplifyOrderedSort(Hop parent, Hop hi, int pos) + + private static Hop simplifyOrderedSort(Hop parent, Hop hi, int pos) { //order(seq(2,N+1,1), indexreturn=FALSE) -> matrix(7) //order(seq(2,N+1,1), indexreturn=TRUE) -> seq(1,N,1)/seq(N,1,-1) if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.SORT ) //order { Hop hi2 = hi.getInput().get(0); - + if( hi2 instanceof DataGenOp && ((DataGenOp)hi2).getOp()==OpOpDG.SEQ ) { Hop incr = hi2.getInput().get(((DataGenOp)hi2).getParamIndex(Statement.SEQ_INCR)); //check for known ascending ordering and known indexreturn if( incr instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)incr)==1 - && hi.getInput().get(2) instanceof LiteralOp //decreasing - && hi.getInput().get(3) instanceof LiteralOp ) //indexreturn + && hi.getInput().get(2) instanceof LiteralOp //decreasing + && hi.getInput().get(3) instanceof LiteralOp ) //indexreturn { if( HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) ) //IXRET, ASC/DESC { @@ -1505,7 +1551,7 @@ private static Hop simplifyOrderedSort(Hop parent, Hop hi, int pos) HopRewriteUtils.replaceChildReference(parent, hi, seq, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = seq; - + LOG.debug("Applied simplifyOrderedSort1."); } else if( !HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)) ) //DATA, ASC @@ -1514,44 +1560,44 @@ else if( !HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)) ) //D HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = hi2; - + LOG.debug("Applied simplifyOrderedSort2."); } } } } - + return hi; } - private static Hop fuseOrderOperationChain(Hop hi) + private static Hop fuseOrderOperationChain(Hop hi) { //order(order(X,2),1) -> order(X, (12)), if( HopRewriteUtils.isReorg(hi, ReOrgOp.SORT) - && hi.getInput().get(1) instanceof LiteralOp //scalar by - && hi.getInput().get(2) instanceof LiteralOp //scalar desc - && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) ) //not ixret - { + && hi.getInput().get(1) instanceof LiteralOp //scalar by + && hi.getInput().get(2) instanceof LiteralOp //scalar desc + && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) ) //not ixret + { LiteralOp by = (LiteralOp) hi.getInput().get(1); boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)); - + //find chain of order operations with same desc/ixret configuration and single consumers Set probe = new HashSet<>(); ArrayList byList = new ArrayList<>(); byList.add(by); probe.add(by.getStringValue()); Hop input = hi.getInput().get(0); while( HopRewriteUtils.isReorg(input, ReOrgOp.SORT) - && input.getInput().get(1) instanceof LiteralOp //scalar by - && !probe.contains(input.getInput().get(1).getName()) - && HopRewriteUtils.isLiteralOfValue(input.getInput().get(2), desc) - && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) - && input.getParent().size() == 1 ) + && input.getInput().get(1) instanceof LiteralOp //scalar by + && !probe.contains(input.getInput().get(1).getName()) + && HopRewriteUtils.isLiteralOfValue(input.getInput().get(2), desc) + && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) + && input.getParent().size() == 1 ) { byList.add((LiteralOp)input.getInput().get(1)); probe.add(input.getInput().get(1).getName()); input = input.getInput().get(0); } - + //merge order chain if at least two instances if( byList.size() >= 2 ) { //create new order operations @@ -1561,7 +1607,7 @@ private static Hop fuseOrderOperationChain(Hop hi) inputs.add(new LiteralOp(desc)); inputs.add(new LiteralOp(false)); Hop hnew = HopRewriteUtils.createReorg(inputs, ReOrgOp.SORT); - + //cleanup references recursively Hop current = hi; while(current != input ) { @@ -1569,86 +1615,86 @@ private static Hop fuseOrderOperationChain(Hop hi) HopRewriteUtils.removeAllChildReferences(current); current = tmp; } - + //rewire all parents (avoid anomalies with replicated datagen) List parents = new ArrayList<>(hi.getParent()); for( Hop p : parents ) HopRewriteUtils.replaceChildReference(p, hi, hnew); - + hi = hnew; LOG.debug("Applied fuseOrderOperationChain (line "+hi.getBeginLine()+")."); } } - + return hi; } - + /** * Patterns: t(t(A)%*%t(B)+C) -> B%*%A+t(C) - * + * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator */ - private static Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop hi, int pos) + private static Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop hi, int pos) { if( HopRewriteUtils.isTransposeOperation(hi) - && hi.getInput().get(0) instanceof BinaryOp //basic binary - && ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()) + && hi.getInput().get(0) instanceof BinaryOp //basic binary + && ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()) { Hop left = hi.getInput().get(0).getInput().get(0); Hop C = hi.getInput().get(0).getInput().get(1); - + //check matrix mult and both inputs transposes w/ single consumer if( left instanceof AggBinaryOp && C.getDataType().isMatrix() - && HopRewriteUtils.isTransposeOperation(left.getInput().get(0)) - && left.getInput().get(0).getParent().size()==1 - && HopRewriteUtils.isTransposeOperation(left.getInput().get(1)) - && left.getInput().get(1).getParent().size()==1 ) + && HopRewriteUtils.isTransposeOperation(left.getInput().get(0)) + && left.getInput().get(0).getParent().size()==1 + && HopRewriteUtils.isTransposeOperation(left.getInput().get(1)) + && left.getInput().get(1).getParent().size()==1 ) { Hop A = left.getInput().get(0).getInput().get(0); Hop B = left.getInput().get(1).getInput().get(0); - + AggBinaryOp abop = HopRewriteUtils.createMatrixMultiply(B, A); ReorgOp rop = HopRewriteUtils.createTranspose(C); BinaryOp bop = HopRewriteUtils.createBinary(abop, rop, OpOp2.PLUS); - + HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); - + hi = bop; LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+")."); } } - + return hi; } - + // Patterns: X + (X==0) * s -> replace(X, 0, s) - private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int pos) + private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int pos) { if( HopRewriteUtils.isBinary(hi, OpOp2.PLUS) && hi.getInput().get(0).isMatrix() - && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) - && hi.getInput().get(1).getInput().get(1).isScalar() - && HopRewriteUtils.isBinaryMatrixScalar(hi.getInput().get(1).getInput().get(0), OpOp2.EQUAL, 0) - && hi.getInput().get(1).getInput().get(0).getInput().contains(hi.getInput().get(0)) ) + && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) + && hi.getInput().get(1).getInput().get(1).isScalar() + && HopRewriteUtils.isBinaryMatrixScalar(hi.getInput().get(1).getInput().get(0), OpOp2.EQUAL, 0) + && hi.getInput().get(1).getInput().get(0).getInput().contains(hi.getInput().get(0)) ) { LinkedHashMap args = new LinkedHashMap<>(); args.put("target", hi.getInput().get(0)); args.put("pattern", new LiteralOp(0)); args.put("replacement", hi.getInput().get(1).getInput().get(1)); Hop replace = HopRewriteUtils.createParameterizedBuiltinOp( - hi.getInput().get(0), args, ParamBuiltinOp.REPLACE); + hi.getInput().get(0), args, ParamBuiltinOp.REPLACE); HopRewriteUtils.replaceChildReference(parent, hi, replace, pos); hi = replace; LOG.debug("Applied simplifyReplaceZeroOperation (line "+hi.getBeginLine()+")."); } return hi; } - + /** * Pattners: t(t(X)) -> X, rev(rev(X)) -> X - * + * * @param parent parent high-level operator * @param hi high-level operator * @param pos position @@ -1657,7 +1703,7 @@ private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int pos) private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) { ReOrgOp[] lookup = new ReOrgOp[]{ReOrgOp.TRANS, ReOrgOp.REV}; - + if( hi instanceof ReorgOp && HopRewriteUtils.isValidOp(((ReorgOp)hi).getOp(), lookup) ) //first reorg { ReOrgOp firstOp = ((ReorgOp)hi).getOp(); @@ -1669,15 +1715,15 @@ private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) HopRewriteUtils.replaceChildReference(parent, hi, hi3, pos); HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = hi3; - + LOG.debug("Applied removeUnecessaryReorgOperation."); } } - + return hi; } - - /* + + /* * Eliminate RemoveEmpty for SUM, SUM_SQ, and NNZ (number of non-zeros) */ private static Hop removeUnnecessaryRemoveEmpty(Hop parent, Hop hi, int pos) @@ -1688,14 +1734,14 @@ private static Hop removeUnnecessaryRemoveEmpty(Hop parent, Hop hi, int pos) //rowSums(removeEmpty(target=X,margin="cols")) -> rowSums(X) //colSums(removeEmpty(target=X,margin="rows")) -> colSums(X) if( (HopRewriteUtils.isSum(hi) || HopRewriteUtils.isSumSq(hi)) - && HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0)) - && hi.getInput().get(0).getParent().size() == 1 ) + && HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0)) + && hi.getInput().get(0).getParent().size() == 1 ) { AggUnaryOp agg = (AggUnaryOp)hi; ParameterizedBuiltinOp rmEmpty = (ParameterizedBuiltinOp) hi.getInput().get(0); boolean needRmEmpty = (agg.getDirection() == Direction.Row && HopRewriteUtils.isRemoveEmpty(rmEmpty, true)) - || (agg.getDirection() == Direction.Col && HopRewriteUtils.isRemoveEmpty(rmEmpty, false)); - + || (agg.getDirection() == Direction.Col && HopRewriteUtils.isRemoveEmpty(rmEmpty, false)); + if (rmEmpty.getParameterHop("select") == null && !needRmEmpty) { Hop input = rmEmpty.getTargetHop(); if( input != null ) { @@ -1704,11 +1750,11 @@ private static Hop removeUnnecessaryRemoveEmpty(Hop parent, Hop hi, int pos) } } } - + //check if nrow is called on the output of removeEmpty if( HopRewriteUtils.isUnary(hi, OpOp1.NROW) - && HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0), true) - && hi.getInput().get(0).getParent().size() == 1 ) + && HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0), true) + && hi.getInput().get(0).getParent().size() == 1 ) { ParameterizedBuiltinOp rm = (ParameterizedBuiltinOp) hi.getInput().get(0); //obtain optional select vector or input if col vector @@ -1718,9 +1764,9 @@ private static Hop removeUnnecessaryRemoveEmpty(Hop parent, Hop hi, int pos) //NOTE: part of static rewrites despite size dependence for phase //ordering before rewrite for DAG splits after table/removeEmpty Hop input = (rm.getParameterHop("select") != null) ? - rm.getParameterHop("select") : - (rm.getDim2() == 1) ? rm.getTargetHop() : null; - + rm.getParameterHop("select") : + (rm.getDim2() == 1) ? rm.getTargetHop() : null; + //create new expression w/o rmEmpty if applicable if( input != null ) { HopRewriteUtils.removeAllChildReferences(rm); @@ -1734,32 +1780,32 @@ private static Hop removeUnnecessaryRemoveEmpty(Hop parent, Hop hi, int pos) } } } - + return hi; } - private static Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos) + private static Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos) { - if( hi.getDataType() == DataType.MATRIX && hi instanceof BinaryOp - && ((BinaryOp)hi).getOp()==OpOp2.MINUS //first minus - && hi.getInput().get(0) instanceof LiteralOp && ((LiteralOp)hi.getInput().get(0)).getDoubleValue()==0 ) + if( hi.getDataType() == DataType.MATRIX && hi instanceof BinaryOp + && ((BinaryOp)hi).getOp()==OpOp2.MINUS //first minus + && hi.getInput().get(0) instanceof LiteralOp && ((LiteralOp)hi.getInput().get(0)).getDoubleValue()==0 ) { Hop hi2 = hi.getInput().get(1); - if( hi2.getDataType() == DataType.MATRIX && hi2 instanceof BinaryOp - && ((BinaryOp)hi2).getOp()==OpOp2.MINUS //second minus - && hi2.getInput().get(0) instanceof LiteralOp && ((LiteralOp)hi2.getInput().get(0)).getDoubleValue()==0 ) - + if( hi2.getDataType() == DataType.MATRIX && hi2 instanceof BinaryOp + && ((BinaryOp)hi2).getOp()==OpOp2.MINUS //second minus + && hi2.getInput().get(0) instanceof LiteralOp && ((LiteralOp)hi2.getInput().get(0)).getDoubleValue()==0 ) + { Hop hi3 = hi2.getInput().get(1); //remove unnecessary chain of -(-()) HopRewriteUtils.replaceChildReference(parent, hi, hi3, pos); HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = hi3; - + LOG.debug("Applied removeUnecessaryMinus"); } } - + return hi; } @@ -1768,148 +1814,148 @@ private static Hop simplifyGroupedAggregate(Hop hi) if( hi instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hi).getOp()==ParamBuiltinOp.GROUPEDAGG ) //aggregate { ParameterizedBuiltinOp phi = (ParameterizedBuiltinOp)hi; - + if( phi.isCountFunction() //aggregate(fn="count") - && phi.getTargetHop().getDim2()==1 ) //only for vector + && phi.getTargetHop().getDim2()==1 ) //only for vector { HashMap params = phi.getParamIndexMap(); int ix1 = params.get(Statement.GAGG_TARGET); int ix2 = params.get(Statement.GAGG_GROUPS); - + //check for unnecessary memory consumption for "count" - if( ix1 != ix2 && phi.getInput().get(ix1)!=phi.getInput().get(ix2) ) + if( ix1 != ix2 && phi.getInput().get(ix1)!=phi.getInput().get(ix2) ) { Hop th = phi.getInput().get(ix1); Hop gh = phi.getInput().get(ix2); - + HopRewriteUtils.replaceChildReference(hi, th, gh, ix1); - - LOG.debug("Applied simplifyGroupedAggregateCount"); + + LOG.debug("Applied simplifyGroupedAggregateCount"); } } } - + return hi; } - - private static Hop fuseMinusNzBinaryOperation(Hop parent, Hop hi, int pos) + + private static Hop fuseMinusNzBinaryOperation(Hop parent, Hop hi, int pos) { //pattern X - (s * ppred(X,0,!=)) -> X -nz s //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate for X - tmp if X is sparse if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS) - && hi.getInput().get(0).getDataType()==DataType.MATRIX - && hi.getInput().get(1).getDataType()==DataType.MATRIX - && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) ) + && hi.getInput().get(0).getDataType()==DataType.MATRIX + && hi.getInput().get(1).getDataType()==DataType.MATRIX + && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) ) { Hop X = hi.getInput().get(0); Hop s = hi.getInput().get(1).getInput().get(0); Hop pred = hi.getInput().get(1).getInput().get(1); - + if( s.getDataType()==DataType.SCALAR && pred.getDataType()==DataType.MATRIX - && HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) - && pred.getInput().get(0) == X //depend on common subexpression elimination - && pred.getInput().get(1) instanceof LiteralOp - && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) + && HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) + && pred.getInput().get(0) == X //depend on common subexpression elimination + && pred.getInput().get(1) instanceof LiteralOp + && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { - Hop hnew = HopRewriteUtils.createBinary(X, s, OpOp2.MINUS_NZ); - + Hop hnew = HopRewriteUtils.createBinary(X, s, OpOp2.MINUS_NZ); + //relink new hop into original position HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; - - LOG.debug("Applied fuseMinusNzBinaryOperation (line "+hi.getBeginLine()+")"); - } + + LOG.debug("Applied fuseMinusNzBinaryOperation (line "+hi.getBeginLine()+")"); + } } - + return hi; } - - private static Hop fuseLogNzUnaryOperation(Hop parent, Hop hi, int pos) + + private static Hop fuseLogNzUnaryOperation(Hop parent, Hop hi, int pos) { //pattern ppred(X,0,"!=")*log(X) -> log_nz(X) //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate and to prevent dense intermediates if X is ultra sparse if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) - && hi.getInput().get(0).getDataType()==DataType.MATRIX - && hi.getInput().get(1).getDataType()==DataType.MATRIX - && HopRewriteUtils.isUnary(hi.getInput().get(1), OpOp1.LOG) ) + && hi.getInput().get(0).getDataType()==DataType.MATRIX + && hi.getInput().get(1).getDataType()==DataType.MATRIX + && HopRewriteUtils.isUnary(hi.getInput().get(1), OpOp1.LOG) ) { Hop pred = hi.getInput().get(0); Hop X = hi.getInput().get(1).getInput().get(0); - + if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) - && pred.getInput().get(0) == X //depend on common subexpression elimination - && pred.getInput().get(1) instanceof LiteralOp - && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) + && pred.getInput().get(0) == X //depend on common subexpression elimination + && pred.getInput().get(1) instanceof LiteralOp + && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { Hop hnew = HopRewriteUtils.createUnary(X, OpOp1.LOG_NZ); - + //relink new hop into original position HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; - - LOG.debug("Applied fuseLogNzUnaryOperation (line "+hi.getBeginLine()+")."); - } + + LOG.debug("Applied fuseLogNzUnaryOperation (line "+hi.getBeginLine()+")."); + } } - + return hi; } - private static Hop fuseLogNzBinaryOperation(Hop parent, Hop hi, int pos) + private static Hop fuseLogNzBinaryOperation(Hop parent, Hop hi, int pos) { //pattern ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate and to prevent dense intermediates if X is ultra sparse if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) - && hi.getInput().get(0).getDataType()==DataType.MATRIX - && hi.getInput().get(1).getDataType()==DataType.MATRIX - && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.LOG) ) + && hi.getInput().get(0).getDataType()==DataType.MATRIX + && hi.getInput().get(1).getDataType()==DataType.MATRIX + && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.LOG) ) { Hop pred = hi.getInput().get(0); Hop X = hi.getInput().get(1).getInput().get(0); Hop log = hi.getInput().get(1).getInput().get(1); - + if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) - && pred.getInput().get(0) == X //depend on common subexpression elimination - && pred.getInput().get(1) instanceof LiteralOp - && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) + && pred.getInput().get(0) == X //depend on common subexpression elimination + && pred.getInput().get(1) instanceof LiteralOp + && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { Hop hnew = HopRewriteUtils.createBinary(X, log, OpOp2.LOG_NZ); - + //relink new hop into original position HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; - - LOG.debug("Applied fuseLogNzBinaryOperation (line "+hi.getBeginLine()+")"); - } + + LOG.debug("Applied fuseLogNzBinaryOperation (line "+hi.getBeginLine()+")"); + } } - + return hi; } - private static Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos) + private static Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos) { //pattern: outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) //note: this rewrite supports both left/right sequence - + if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) && ((BinaryOp)hi).isOuter() ) { if( ( HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)) //pattern a: outer(v, t(seq(1,m)), "==") - && HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0))) - || HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0))) //pattern b: outer(seq(1,m), t(v) "==") + && HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0))) + || HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0))) //pattern b: outer(seq(1,m), t(v) "==") { //determine variable parameters for pattern a/b boolean isPatternB = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)); boolean isTransposeRight = HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)); - Hop trgt = isPatternB ? (isTransposeRight ? + Hop trgt = isPatternB ? (isTransposeRight ? hi.getInput().get(1).getInput().get(0) : //get v from t(v) HopRewriteUtils.createTranspose(hi.getInput().get(1)) ) : //create v via t(v') hi.getInput().get(0); //get v directly Hop seq = isPatternB ? hi.getInput().get(0) : hi.getInput().get(1).getInput().get(0); String direction = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? "rows" : "cols"; - + //setup input parameter hops LinkedHashMap inputargs = new LinkedHashMap<>(); inputargs.put("target", trgt); @@ -1917,34 +1963,34 @@ private static Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos) inputargs.put("dir", new LiteralOp(direction)); inputargs.put("ignore", new LiteralOp(true)); inputargs.put("cast", new LiteralOp(false)); - + //create new hop ParameterizedBuiltinOp pbop = HopRewriteUtils - .createParameterizedBuiltinOp(trgt, inputargs, ParamBuiltinOp.REXPAND); - + .createParameterizedBuiltinOp(trgt, inputargs, ParamBuiltinOp.REXPAND); + //relink new hop into original position HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos); hi = pbop; - + LOG.debug("Applied simplifyOuterSeqExpand (line "+hi.getBeginLine()+")"); } } - + return hi; } - + private static Hop simplifyBinaryComparisonChain(Hop parent, Hop hi, int pos) { - if( HopRewriteUtils.isBinaryPPred(hi) - && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 0d, 1d) - && HopRewriteUtils.isBinaryPPred(hi.getInput().get(0)) ) + if( HopRewriteUtils.isBinaryPPred(hi) + && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 0d, 1d) + && HopRewriteUtils.isBinaryPPred(hi.getInput().get(0)) ) { BinaryOp bop = (BinaryOp) hi; BinaryOp bop2 = (BinaryOp) hi.getInput().get(0); boolean one = HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 1); - + //pattern: outer(v1,v2,"!=") == 1 -> outer(v1,v2,"!=") if( (one && bop.getOp() == OpOp2.EQUAL) - || (!one && bop.getOp() == OpOp2.NOTEQUAL) ) + || (!one && bop.getOp() == OpOp2.NOTEQUAL) ) { HopRewriteUtils.replaceChildReference(parent, bop, bop2, pos); HopRewriteUtils.cleanupUnreferenced(bop); @@ -1955,62 +2001,62 @@ private static Hop simplifyBinaryComparisonChain(Hop parent, Hop hi, int pos) { else if( !one && bop.getOp() == OpOp2.EQUAL ) { OpOp2 optr = bop2.getComplementPPredOperation(); BinaryOp tmp = HopRewriteUtils.createBinary(bop2.getInput().get(0), - bop2.getInput().get(1), optr, bop2.isOuter()); + bop2.getInput().get(1), optr, bop2.isOuter()); HopRewriteUtils.replaceChildReference(parent, bop, tmp, pos); HopRewriteUtils.cleanupUnreferenced(bop, bop2); hi = tmp; LOG.debug("Applied simplifyBinaryComparisonChain0 (line "+hi.getBeginLine()+")"); } } - + return hi; } - + private static Hop simplifyCumsumColOrFullAggregates(Hop hi) { //pattern: colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1)) if( (HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.Col) - || HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol)) - && HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM) - && hi.getInput().get(0).getParent().size()==1) + || HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol)) + && HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM) + && hi.getInput().get(0).getParent().size()==1) { Hop cumsumX = hi.getInput().get(0); Hop X = cumsumX.getInput().get(0); Hop mult = HopRewriteUtils.createBinary(X, - HopRewriteUtils.createSeqDataGenOp(X, false), OpOp2.MULT); + HopRewriteUtils.createSeqDataGenOp(X, false), OpOp2.MULT); HopRewriteUtils.replaceChildReference(hi, cumsumX, mult); HopRewriteUtils.removeAllChildReferences(cumsumX); LOG.debug("Applied simplifyCumsumColOrFullAggregates (line "+hi.getBeginLine()+")"); } return hi; } - + private static Hop simplifyCumsumReverse(Hop parent, Hop hi, int pos) { //pattern: rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X) if( HopRewriteUtils.isReorg(hi, ReOrgOp.REV) - && HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM) - && hi.getInput().get(0).getParent().size()==1 - && HopRewriteUtils.isReorg(hi.getInput().get(0).getInput().get(0), ReOrgOp.REV) - && hi.getInput().get(0).getInput().get(0).getParent().size()==1) + && HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM) + && hi.getInput().get(0).getParent().size()==1 + && HopRewriteUtils.isReorg(hi.getInput().get(0).getInput().get(0), ReOrgOp.REV) + && hi.getInput().get(0).getInput().get(0).getParent().size()==1) { Hop cumsumX = hi.getInput().get(0); Hop revX = cumsumX.getInput().get(0); Hop X = revX.getInput().get(0); Hop plus = HopRewriteUtils.createBinary(X, HopRewriteUtils - .createAggUnaryOp(X, AggOp.SUM, Direction.Col), OpOp2.PLUS); + .createAggUnaryOp(X, AggOp.SUM, Direction.Col), OpOp2.PLUS); Hop minus = HopRewriteUtils.createBinary(plus, - HopRewriteUtils.createUnary(X, OpOp1.CUMSUM), OpOp2.MINUS); + HopRewriteUtils.createUnary(X, OpOp1.CUMSUM), OpOp2.MINUS); HopRewriteUtils.replaceChildReference(parent, hi, minus, pos); HopRewriteUtils.cleanupUnreferenced(hi, cumsumX, revX); - + hi = minus; LOG.debug("Applied simplifyCumsumReverse (line "+hi.getBeginLine()+")"); } return hi; } - + private static Hop simplifyNotOverComparisons(Hop parent, Hop hi, int pos){ if(HopRewriteUtils.isUnary(hi, OpOp1.NOT) && hi.getInput(0) instanceof BinaryOp - && hi.getInput(0).getParent().size() == 1) //NOT is only consumer + && hi.getInput(0).getParent().size() == 1) //NOT is only consumer { Hop binaryOperator = hi.getInput(0); Hop A = binaryOperator.getInput(0); @@ -2041,66 +2087,66 @@ else if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.EQUAL)) { return hi; } - + private static Hop fixNonScalarPrint(Hop parent, Hop hi, int pos) { if(HopRewriteUtils.isUnary(parent, OpOp1.PRINT) && !hi.getDataType().isScalar()) { LinkedHashMap args = new LinkedHashMap<>(); args.put("target", hi); Hop newHop = HopRewriteUtils.createParameterizedBuiltinOp( - hi, args, ParamBuiltinOp.TOSTRING); + hi, args, ParamBuiltinOp.TOSTRING); HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos); hi = newHop; LOG.debug("Applied fixNonScalarPrint (line " + hi.getBeginLine() + ")"); } - + return hi; } - + /** * NOTE: currently disabled since this rewrite is INVALID in the * presence of NaNs (because (NaN!=NaN) is true). - * + * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator */ @SuppressWarnings("unused") - private static Hop removeUnecessaryPPred(Hop parent, Hop hi, int pos) + private static Hop removeUnecessaryPPred(Hop parent, Hop hi, int pos) { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); - + Hop datagen = null; - + //ppred(X,X,"==") -> matrix(1, rows=nrow(X),cols=nrow(Y)) if( left==right && bop.getOp()==OpOp2.EQUAL || bop.getOp()==OpOp2.GREATEREQUAL || bop.getOp()==OpOp2.LESSEQUAL ) datagen = HopRewriteUtils.createDataGenOp(left, 1); - + //ppred(X,X,"!=") -> matrix(0, rows=nrow(X),cols=nrow(Y)) if( left==right && bop.getOp()==OpOp2.NOTEQUAL || bop.getOp()==OpOp2.GREATER || bop.getOp()==OpOp2.LESS ) datagen = HopRewriteUtils.createDataGenOp(left, 0); - + if( datagen != null ) { HopRewriteUtils.replaceChildReference(parent, hi, datagen, pos); hi = datagen; } } - + return hi; } - + private static void removeTWriteTReadPairs(ArrayList roots) { Iterator iter = roots.iterator(); while(iter.hasNext()) { Hop root = iter.next(); if( HopRewriteUtils.isData(root, OpOpData.TRANSIENTWRITE) - && HopRewriteUtils.isData(root.getInput(0), OpOpData.TRANSIENTREAD) - && root.getName().equals(root.getInput(0).getName()) - && !root.getInput(0).requiresCheckpoint()) + && HopRewriteUtils.isData(root.getInput(0), OpOpData.TRANSIENTREAD) + && root.getName().equals(root.getInput(0).getName()) + && !root.getInput(0).requiresCheckpoint()) { iter.remove(); } diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java new file mode 100644 index 00000000000..afb70b8ff3f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java @@ -0,0 +1,76 @@ +package org.apache.sysds.test.functions.rewrite; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; + +public class RewriteBooleanSimplificationTest extends AutomatedTestBase { + + private static final String TEST_NAME_AND = "RewriteBooleanSimplificationTestAnd"; + private static final String TEST_NAME_OR = "RewriteBooleanSimplificationTestOr"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteBooleanSimplificationTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME_AND, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_AND)); + addTestConfiguration(TEST_NAME_OR, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_OR)); + } + + @Test + public void testBooleanRewriteAnd() { + testRewriteBooleanSimplification(TEST_NAME_AND, ExecType.CP, 0.0); + } + + @Test + public void testBooleanRewriteOr() { + testRewriteBooleanSimplification(TEST_NAME_OR, ExecType.CP, 1.0); + } + + private void testRewriteBooleanSimplification(String testname, ExecType et, double expected) { + ExecMode platformOld = rtplatform; + rtplatform = (et == ExecType.SPARK) ? ExecMode.SPARK : ExecMode.HYBRID; + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if (rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{}; + + runTest(true, false, null, -1); + + Assert.assertEquals("Expected boolean simplification result does not match", expected, getRewriteBooleanSimplificationResult(testname), 0.0001); + } finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + + private double getRewriteBooleanSimplificationResult(String testname) { + + if (testname.equals(TEST_NAME_AND)) { + // a & !a simplifies to false (0.0) + return 0.0; + } else if (testname.equals(TEST_NAME_OR)) { + // a | !a simplifies to true (1.0) + return 1.0; + } else { + // In case of an unknown operation, we return a default value (e.g., 0.0). + return 0.0; + } + } + +} diff --git a/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml new file mode 100644 index 00000000000..82cf47adee5 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml @@ -0,0 +1,5 @@ +a = 1; # true +b = 0; # false + +result = (a & !a); # Expected result: false (0.0) + diff --git a/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml new file mode 100644 index 00000000000..fdc4172b4ab --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml @@ -0,0 +1,5 @@ +a = 1; # true +b = 0; # false + +result = (a | !a); # Expected result: true (1.0) + From 81bf70d07b1c16ee119efdadd73b6b94f5761915 Mon Sep 17 00:00:00 2001 From: aarna Date: Thu, 7 Nov 2024 18:43:52 +0530 Subject: [PATCH 2/5] Boolean Rewrite Task --- .../functions/rewrite/RewriteBooleanSimplificationTestAnd.dml | 1 - .../functions/rewrite/RewriteBooleanSimplificationTestOr.dml | 1 - 2 files changed, 2 deletions(-) diff --git a/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml index 82cf47adee5..85aa285624f 100644 --- a/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml +++ b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestAnd.dml @@ -1,5 +1,4 @@ a = 1; # true -b = 0; # false result = (a & !a); # Expected result: false (0.0) diff --git a/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml index fdc4172b4ab..0ebb244acc8 100644 --- a/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml +++ b/src/test/scripts/functions/rewrite/RewriteBooleanSimplificationTestOr.dml @@ -1,5 +1,4 @@ a = 1; # true -b = 0; # false result = (a | !a); # Expected result: true (1.0) From cd3aec321a1307e7e7b6dc76f3ac889bdf76d537 Mon Sep 17 00:00:00 2001 From: aarna Date: Thu, 6 Mar 2025 20:44:38 +0100 Subject: [PATCH 3/5] implemented reverse sequence step test and made changes to reorg operation function (phase ordering) in RewriteAlgebraicSimplificationStatic. --- .../RewriteAlgebraicSimplificationStatic.java | 85 ++++++++++++++ ...ewriteSimplifyReverseSequenceStepTest.java | 109 ++++++++++++++++++ .../RewriteSimplifyReverseSequenceStep.dml | 35 ++++++ 3 files changed, 229 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 5d867bf0ffb..5e79f73502e 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -156,6 +156,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hi = simplifyConstantConjunction(hop, hi, i); //e.g., a & !a -> FALSE hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X) hi = simplifyReverseSequence(hop, hi, i); //e.g., rev(seq(1,n)) -> seq(n,1) + hi = simplifyReverseSequenceStep(hop, hi, i); if(OptimizerUtils.ALLOW_OPERATOR_FUSION) hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X @@ -209,6 +210,59 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) hop.setVisited(); } + private static Hop simplifyReverseSequenceStep(Hop parent, Hop hi, int pos) { + if (HopRewriteUtils.isReorg(hi, ReOrgOp.REV) + && hi.getInput(0) instanceof DataGenOp + && ((DataGenOp) hi.getInput(0)).getOp() == OpOpDG.SEQ + && hi.getInput(0).getParent().size() == 1) { // only one consumer + + DataGenOp seq = (DataGenOp) hi.getInput(0); + Hop from = seq.getInput().get(seq.getParamIndex(Statement.SEQ_FROM)); + Hop to = seq.getInput().get(seq.getParamIndex(Statement.SEQ_TO)); + Hop incr = seq.getInput().get(seq.getParamIndex(Statement.SEQ_INCR)); + + if (from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp) { + double fromVal = ((LiteralOp) from).getDoubleValue(); + double toVal = ((LiteralOp) to).getDoubleValue(); + double incrVal = ((LiteralOp) incr).getDoubleValue(); + + // Skip if increment is zero (invalid sequence) + if (Math.abs(incrVal) < 1e-10) + return hi; + + boolean isValidDirection = false; + + // Checking direction compatibility + if ((incrVal > 0 && fromVal <= toVal) || (incrVal < 0 && fromVal >= toVal)) { + isValidDirection = true; + } + + if (isValidDirection) { + // Calculate the number of elements and the last element + int numValues = (int)Math.floor(Math.abs((toVal - fromVal) / incrVal)) + 1; + double lastVal = fromVal + (numValues - 1) * incrVal; + + // Create a new sequence based on actual last value + LiteralOp newFrom = new LiteralOp(lastVal); + LiteralOp newTo = new LiteralOp(fromVal); + LiteralOp newIncr = new LiteralOp(-incrVal); + + // Replace the parameters + seq.getInput().set(seq.getParamIndex(Statement.SEQ_FROM), newFrom); + seq.getInput().set(seq.getParamIndex(Statement.SEQ_TO), newTo); + seq.getInput().set(seq.getParamIndex(Statement.SEQ_INCR), newIncr); + + // Replace the old sequence with the new one + HopRewriteUtils.replaceChildReference(parent, hi, seq, pos); + HopRewriteUtils.cleanupUnreferenced(hi, seq); + hi = seq; + LOG.debug("Applied simplifyReverseSequenceStep (line " + hi.getBeginLine() + ")."); + } + } + } + return hi; + } + private static Hop removeUnnecessaryVectorizeOperation(Hop hi) { //applies to all binary matrix operations, if one input is unnecessarily vectorized @@ -1853,6 +1907,37 @@ private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) LOG.debug("Applied removeUnecessaryReorgOperation."); } } + // Handle the second case: t(X) %*% v -> t(t(v) %*% X) + else if (hi instanceof BinaryOp && ((BinaryOp) hi).getOp() == OpOp2.MULT) { + Hop left = hi.getInput().get(0); + Hop right = hi.getInput().get(1); + + if (left instanceof ReorgOp && ((ReorgOp) left).getOp() == ReOrgOp.TRANS) { + try { + Hop X = left.getInput().get(0); + + // Create transpose of v + Hop transposeV = HopRewriteUtils.createTranspose(right); + + // Create multiplication + Hop newMult = HopRewriteUtils.createMatrixMultiply(transposeV, X); + + // Create final transpose + Hop finalTranspose = HopRewriteUtils.createTranspose(newMult); + + // Replace the original hop with new construct + HopRewriteUtils.replaceChildReference(parent, hi, finalTranspose, pos); + HopRewriteUtils.cleanupUnreferenced(hi); + + LOG.debug("Applied removeUnnecessaryReorgOperation."); + + return finalTranspose; + } + catch (Exception e) { + LOG.error("Failed to apply removeUnnecessaryReorgOperation: " + e.getMessage(), e); + } + } + } return hi; } diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java new file mode 100644 index 00000000000..719bde2c6d2 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +public class RewriteSimplifyReverseSequenceStepTest extends AutomatedTestBase { + private static final String TEST_NAME1 = "RewriteSimplifyReverseSequenceStep"; + + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyReverseSequenceStepTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"})); + } + + @Test + public void testRewriteReverseSeqStep() { + testRewriteReverseSeq(TEST_NAME1, true); + } + + @Test + public void testNoRewriteReverseSeqStep() { + testRewriteReverseSeq(TEST_NAME1, false); + } + + private void testRewriteReverseSeq(String testname, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + int rows = 10; + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{"-stats", "-args", String.valueOf(rows), output("Scalar")}; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + + // Calculate expected sums for each sequence + double sum1 = calculateSum(0, rows-1, 1); // A1 = rev(seq(0, rows-1, 1)) + double sum2 = calculateSum(0, rows, 2); // A2 = rev(seq(0, rows, 2)) + double sum3 = calculateSum(2, rows, 2); // A3 = rev(seq(2, rows, 2)) + double sum4 = calculateSum(0, 100, 5); // A4 = rev(seq(0, 100, 5)) + double sum5 = calculateSum(15, 5, -0.5); // A5 = rev(seq(15, 5, -0.5)) + + double expected = sum1 + sum2 + sum3 + sum4 + sum5; + + double ret = readDMLScalarFromOutputDir("Scalar").get(new MatrixValue.CellIndex(1, 1)).doubleValue(); + + Assert.assertEquals("Incorrect sum computed", expected, ret, 1e-10); + + if (rewrites) { + // With bidirectional rewrite, REV operations should be removed + Assert.assertFalse("Rewrite should have removed REV operation!", + heavyHittersContainsString("rev")); + } + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + + // Helper method to calculate sum of a sequence + private double calculateSum(double from, double to, double incr) { + double sum = 0; + int n = 0; + + if ((incr > 0 && from <= to) || (incr < 0 && from >= to)) { + // Calculate number of elements in the sequence + n = (int)Math.floor(Math.abs((to - from) / incr)) + 1; + + // Calculate the last element in the sequence + double last = from + (n - 1) * incr; + + // Use arithmetic sequence sum formula: n * (first + last) / 2 + sum = n * (from + last) / 2; + } + + return sum; + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml new file mode 100644 index 00000000000..e8f3314c265 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +rows = as.integer($1) + +# Original test sequences (positive increments) +A1 = rev(seq(0, rows-1, 1)) # Should become seq(rows-1, 0, -1) +A2 = rev(seq(0, rows, 2)) # Should become seq(rows, 0, -2) +A3 = rev(seq(2, rows, 2)) # Should become seq(lastVal, 2, -2) where lastVal is the last value in the sequence +A4 = rev(seq(0, 100, 5)) # Should become seq(100, 0, -5) +A5 = rev(seq(15, 5, -0.5)) # Should become seq(5, 15, 0.5) + +# Sum all sequences +R = sum(A1) + sum(A2) + sum(A3) + sum(A4) + sum(A5) + +# Output +write(R, $2) \ No newline at end of file From ce5350c5456e943e8f4b6aaede5a63eedb2d8986 Mon Sep 17 00:00:00 2001 From: aarna Date: Fri, 7 Mar 2025 00:01:44 +0100 Subject: [PATCH 4/5] removed unnecessary files --- .../RewriteBooleanSimplificationTest.java | 76 ------------------- 1 file changed, 76 deletions(-) delete mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java deleted file mode 100644 index afb70b8ff3f..00000000000 --- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteBooleanSimplificationTest.java +++ /dev/null @@ -1,76 +0,0 @@ -package org.apache.sysds.test.functions.rewrite; - -import org.junit.Assert; -import org.junit.Test; -import org.apache.sysds.api.DMLScript; -import org.apache.sysds.common.Types.ExecMode; -import org.apache.sysds.common.Types.ExecType; -import org.apache.sysds.test.AutomatedTestBase; -import org.apache.sysds.test.TestConfiguration; -import org.apache.sysds.test.TestUtils; - -public class RewriteBooleanSimplificationTest extends AutomatedTestBase { - - private static final String TEST_NAME_AND = "RewriteBooleanSimplificationTestAnd"; - private static final String TEST_NAME_OR = "RewriteBooleanSimplificationTestOr"; - private static final String TEST_DIR = "functions/rewrite/"; - private static final String TEST_CLASS_DIR = TEST_DIR + RewriteBooleanSimplificationTest.class.getSimpleName() + "/"; - - @Override - public void setUp() { - TestUtils.clearAssertionInformation(); - addTestConfiguration(TEST_NAME_AND, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_AND)); - addTestConfiguration(TEST_NAME_OR, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_OR)); - } - - @Test - public void testBooleanRewriteAnd() { - testRewriteBooleanSimplification(TEST_NAME_AND, ExecType.CP, 0.0); - } - - @Test - public void testBooleanRewriteOr() { - testRewriteBooleanSimplification(TEST_NAME_OR, ExecType.CP, 1.0); - } - - private void testRewriteBooleanSimplification(String testname, ExecType et, double expected) { - ExecMode platformOld = rtplatform; - rtplatform = (et == ExecType.SPARK) ? ExecMode.SPARK : ExecMode.HYBRID; - - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - if (rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID) { - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - } - - try { - TestConfiguration config = getTestConfiguration(testname); - loadTestConfiguration(config); - - String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[]{}; - - runTest(true, false, null, -1); - - Assert.assertEquals("Expected boolean simplification result does not match", expected, getRewriteBooleanSimplificationResult(testname), 0.0001); - } finally { - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; - } - } - - private double getRewriteBooleanSimplificationResult(String testname) { - - if (testname.equals(TEST_NAME_AND)) { - // a & !a simplifies to false (0.0) - return 0.0; - } else if (testname.equals(TEST_NAME_OR)) { - // a | !a simplifies to true (1.0) - return 1.0; - } else { - // In case of an unknown operation, we return a default value (e.g., 0.0). - return 0.0; - } - } - -} From 51a471aa42af1e23014b5ee2a90f2c3c4bc75e0c Mon Sep 17 00:00:00 2001 From: aarna Date: Fri, 7 Mar 2025 00:02:45 +0100 Subject: [PATCH 5/5] removed unnecessary files --- 1 | 1 - 1.mtd | 7 ------- 1000 | 1 - 1000.mtd | 7 ------- 4 files changed, 16 deletions(-) delete mode 100644 1 delete mode 100644 1.mtd delete mode 100644 1000 delete mode 100644 1000.mtd diff --git a/1 b/1 deleted file mode 100644 index 3bac2cd4489..00000000000 --- a/1 +++ /dev/null @@ -1 +0,0 @@ -1000.0 \ No newline at end of file diff --git a/1.mtd b/1.mtd deleted file mode 100644 index 6b3085d8e4f..00000000000 --- a/1.mtd +++ /dev/null @@ -1,7 +0,0 @@ -{ - "data_type": "scalar", - "value_type": "double", - "format": "text", - "author": "aarna", - "created": "2025-03-04 14:49:34 CET" -} \ No newline at end of file diff --git a/1000 b/1000 deleted file mode 100644 index 3bac2cd4489..00000000000 --- a/1000 +++ /dev/null @@ -1 +0,0 @@ -1000.0 \ No newline at end of file diff --git a/1000.mtd b/1000.mtd deleted file mode 100644 index 6b3085d8e4f..00000000000 --- a/1000.mtd +++ /dev/null @@ -1,7 +0,0 @@ -{ - "data_type": "scalar", - "value_type": "double", - "format": "text", - "author": "aarna", - "created": "2025-03-04 14:49:34 CET" -} \ No newline at end of file