diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 5d867bf0ffb..5e79f73502e 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -156,6 +156,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
 			hi = simplifyConstantConjunction(hop, hi, i); //e.g., a & !a -> FALSE
 			hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
 			hi = simplifyReverseSequence(hop, hi, i); //e.g., rev(seq(1,n)) -> seq(n,1)
+			hi = simplifyReverseSequenceStep(hop, hi, i); //e.g., rev(seq(0,9,2)) -> seq(8,0,-2)
 			if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
 				hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y
 			hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
@@ -209,6 +210,69 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
 		hop.setVisited();
 	}
 	
+	/**
+	 * Rewrites rev(seq(from, to, incr)) with literal arguments into the equivalent
+	 * sequence seq(last, from, -incr), where last is the final value the original
+	 * sequence generates (not necessarily 'to', e.g., seq(0, 9, 2) ends at 8).
+	 *
+	 * @param parent parent hop whose input at position pos is hi
+	 * @param hi candidate hop, rewritten only if it is a REV over a SEQ data generator
+	 * @param pos position of hi in the inputs of parent
+	 * @return the modified seq hop if the rewrite applied, otherwise hi unchanged
+	 */
+	private static Hop simplifyReverseSequenceStep(Hop parent, Hop hi, int pos) {
+		if (!HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
+			|| !(hi.getInput(0) instanceof DataGenOp)
+			|| ((DataGenOp) hi.getInput(0)).getOp() != OpOpDG.SEQ
+			|| hi.getInput(0).getParent().size() != 1) { // seq is modified in place, so rev must be its only consumer
+			return hi;
+		}
+
+		DataGenOp seq = (DataGenOp) hi.getInput(0);
+		int fromPos = seq.getParamIndex(Statement.SEQ_FROM);
+		int toPos = seq.getParamIndex(Statement.SEQ_TO);
+		int incrPos = seq.getParamIndex(Statement.SEQ_INCR);
+		Hop from = seq.getInput().get(fromPos);
+		Hop to = seq.getInput().get(toPos);
+		Hop incr = seq.getInput().get(incrPos);
+
+		// the last generated value is only computable at compile time for literal arguments
+		if (!(from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp)) {
+			return hi;
+		}
+
+		double fromVal = ((LiteralOp) from).getDoubleValue();
+		double toVal = ((LiteralOp) to).getDoubleValue();
+		double incrVal = ((LiteralOp) incr).getDoubleValue();
+
+		// skip (near-)zero increments and increments that point away from 'to'
+		boolean validDirection = (incrVal > 0 && fromVal <= toVal) || (incrVal < 0 && fromVal >= toVal);
+		if (Math.abs(incrVal) < 1e-10 || !validDirection) {
+			return hi;
+		}
+
+		// number of increments between 'from' and the last value; kept as double because an
+		// int cast would overflow for sequences with more than Integer.MAX_VALUE values
+		double numSteps = Math.floor(Math.abs((toVal - fromVal) / incrVal));
+		double lastVal = fromVal + numSteps * incrVal;
+		if (Double.isInfinite(lastVal) || Double.isNaN(lastVal)) {
+			return hi;
+		}
+
+		// replace the parameters through HopRewriteUtils so that child and parent references
+		// stay in sync; sizes are refreshed once after all three parameters are consistent
+		HopRewriteUtils.replaceChildReference(seq, from, new LiteralOp(lastVal), fromPos, false);
+		HopRewriteUtils.replaceChildReference(seq, to, new LiteralOp(fromVal), toPos, false);
+		HopRewriteUtils.replaceChildReference(seq, incr, new LiteralOp(-incrVal), incrPos, false);
+		seq.refreshSizeInformation();
+
+		// the reversed seq takes the place of rev(seq) below parent
+		HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
+		HopRewriteUtils.cleanupUnreferenced(hi, seq);
+		LOG.debug("Applied simplifyReverseSequenceStep (line " + seq.getBeginLine() + ").");
+		return seq;
+	}
+
 	private static Hop removeUnnecessaryVectorizeOperation(Hop hi)
 	{
 		//applies to all binary matrix operations, if one input is unnecessarily vectorized
@@ -1853,6 +1917,27 @@ private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
 				LOG.debug("Applied removeUnecessaryReorgOperation.");
 			}
 		}
+		// second case: t(X) %*% v -> t(t(v) %*% X); transposes the vector v and the
+		// vector result instead of the matrix X
+		else if (HopRewriteUtils.isMatrixMultiply(hi) // %*%, not the element-wise OpOp2.MULT
+			&& HopRewriteUtils.isTransposeOperation(hi.getInput(0))
+			&& hi.getInput(1).getDim2() == 1 // v is known to be a column vector
+			&& hi.getInput(0).getInput(0).getDim2() > 1) { // X is not a vector, so t(v) %*% X cannot match again
+			Hop left = hi.getInput(0);
+			Hop X = left.getInput(0);
+			Hop v = hi.getInput(1);
+
+			Hop transposeV = HopRewriteUtils.createTranspose(v);
+			Hop newMult = HopRewriteUtils.createMatrixMultiply(transposeV, X);
+			Hop finalTranspose = HopRewriteUtils.createTranspose(newMult);
+
+			// swap the new construct in for hi and clean up the replaced multiply and its t(X)
+			HopRewriteUtils.replaceChildReference(parent, hi, finalTranspose, pos);
+			HopRewriteUtils.cleanupUnreferenced(hi, left);
+			hi = finalTranspose;
+
+			LOG.debug("Applied removeUnnecessaryReorgOperation.");
+		}
 		
 		return hi;
 	}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java
new file mode 100644
index 00000000000..719bde2c6d2
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Runs the reverse-sequence DML script with and without algebraic simplification and
+ * checks the resulting sum; with rewrites on, no "rev" heavy hitter may be reported.
+ */
+public class RewriteSimplifyReverseSequenceStepTest extends AutomatedTestBase {
+	private static final String TEST_NAME1 = "RewriteSimplifyReverseSequenceStep";
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyReverseSequenceStepTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"});
+		addTestConfiguration(TEST_NAME1, config);
+	}
+
+	@Test
+	public void testRewriteReverseSeqStep() {
+		testRewriteReverseSeq(TEST_NAME1, true);
+	}
+
+	@Test
+	public void testNoRewriteReverseSeqStep() {
+		testRewriteReverseSeq(TEST_NAME1, false);
+	}
+
+	/**
+	 * Executes the script with algebraic simplification set to the given value, compares the
+	 * written scalar against the closed-form sum, and restores the optimizer flag afterwards.
+	 */
+	private void testRewriteReverseSeq(String testname, boolean rewrites) {
+		final boolean savedFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		final int rows = 10;
+
+		try {
+			loadTestConfiguration(getTestConfiguration(testname));
+			fullDMLScriptName = SCRIPT_DIR + TEST_DIR + testname + ".dml";
+			programArgs = new String[]{"-stats", "-args", String.valueOf(rows), output("Scalar")};
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+			runTest(true, false, null, -1);
+
+			// Reference result: one closed-form sum per sequence of the DML script
+			double expected = calculateSum(0, rows-1, 1) // A1 = rev(seq(0, rows-1, 1))
+				+ calculateSum(0, rows, 2) // A2 = rev(seq(0, rows, 2))
+				+ calculateSum(2, rows, 2) // A3 = rev(seq(2, rows, 2))
+				+ calculateSum(0, 100, 5) // A4 = rev(seq(0, 100, 5))
+				+ calculateSum(15, 5, -0.5); // A5 = rev(seq(15, 5, -0.5))
+
+			MatrixValue.CellIndex cell = new MatrixValue.CellIndex(1, 1);
+			double actual = readDMLScalarFromOutputDir("Scalar").get(cell).doubleValue();
+			Assert.assertEquals("Incorrect sum computed", expected, actual, 1e-10);
+
+			if (!rewrites) {
+				return; // the heavy-hitter check below only applies with rewrites enabled
+			}
+			Assert.assertFalse("Rewrite should have removed REV operation!",
+				heavyHittersContainsString("rev"));
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = savedFlag;
+		}
+	}
+
+	/**
+	 * Closed-form sum of the arithmetic sequence from, from+incr, ... that does not pass
+	 * 'to'; returns 0 if the increment does not lead from 'from' towards 'to'.
+	 */
+	private double calculateSum(double from, double to, double incr) {
+		boolean ascending = incr > 0 && from <= to;
+		boolean descending = incr < 0 && from >= to;
+		if (!ascending && !descending) {
+			return 0;
+		}
+
+		int count = (int)Math.floor(Math.abs((to - from) / incr)) + 1;
+		double last = from + (count - 1) * incr;
+		return count * (from + last) / 2;
+	}
+}
diff --git a/src/test/scripts/functions/io/binary/dedupSerializedBlock.out b/src/test/scripts/functions/io/binary/dedupSerializedBlock.out
new file mode 100644
index 00000000000..070604a6765
Binary files /dev/null and b/src/test/scripts/functions/io/binary/dedupSerializedBlock.out differ
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml
new file mode 100644
index 00000000000..e8f3314c265
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+rows = as.integer($1)
+
+# Test sequences for rev(seq(from, to, incr)): A1-A4 use positive increments, A5 a negative fractional one
+A1 = rev(seq(0, rows-1, 1)) # Should become seq(rows-1, 0, -1)
+A2 = rev(seq(0, rows, 2)) # Should become seq(rows, 0, -2) for even rows (otherwise it starts at rows-1)
+A3 = rev(seq(2, rows, 2)) # Should become seq(lastVal, 2, -2) where lastVal is the last value in the sequence
+A4 = rev(seq(0, 100, 5)) # Should become seq(100, 0, -5)
+A5 = rev(seq(15, 5, -0.5)) # Should become seq(5, 15, 0.5)
+
+# Sum all sequences into one scalar, so a single output value covers all five cases
+R = sum(A1) + sum(A2) + sum(A3) + sum(A4) + sum(A5)
+
+# Write the scalar to the path passed as second script argument
+write(R, $2)
\ No newline at end of file