From ddd7a9669ebd48307321166eaec4ac68c0222594 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Mon, 24 Feb 2025 14:47:20 +0100 Subject: [PATCH 01/16] added testing for removeUnnecessaryVectorizeOperation --- ...moveUnnecessaryVectorizeOperationTest.java | 106 ++++++++++++++++++ ...writeRemoveUnnecessaryVectorizeOperation.R | 45 ++++++++ ...iteRemoveUnnecessaryVectorizeOperation.dml | 37 ++++++ 3 files changed, 188 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.R create mode 100644 src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperationTest.java new file mode 100644 index 00000000000..d231b0ff934 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperationTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteRemoveUnnecessaryVectorizeOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteRemoveUnnecessaryVectorizeOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteRemoveUnnecessaryVectorizeOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 100; + private static final int cols = 100; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testRemoveUnnecessaryVectorizeOperationLeftNoRewrite() { + testRewriteRemoveUnnecessaryVectorizeOperation(1, false); + } + + @Test + public void testRemoveUnnecessaryVectorizeOperationLeftRewrite() { + testRewriteRemoveUnnecessaryVectorizeOperation(1, true); + } + + @Test + public void testRemoveUnnecessaryVectorizeOperationRightNoRewrite() { + testRewriteRemoveUnnecessaryVectorizeOperation(2, false); + } + + @Test + public void testRemoveUnnecessaryVectorizeOperationRightRewrite() { + testRewriteRemoveUnnecessaryVectorizeOperation(2, true); + } + + private void testRewriteRemoveUnnecessaryVectorizeOperation(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, 1, 2, 1.00d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.RANDOM.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.RANDOM.toString())); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.R b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.R new file mode 100644 index 00000000000..dd409e29fcb --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.R @@ -0,0 +1,45 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix and type +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) +type = as.integer(args[2]) + +Y = matrix(1, nrow(X), ncol(X)) + +# Perform operations +if(type==1){ + R = Y/X # Left vectorized scalar +} else if (type==2){ + R = X/Y # Right vectorized scalar +} + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.dml b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.dml new file mode 100644 index 00000000000..93f3ddb8ff0 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrix X +X = read($1) +Y = matrix(1,nrow(X),ncol(X)) + +type = $2 + +# Perform operations +if(type==1){ + R = Y/X # Left vectorized scalar +} +else if(type==2){ + R = X/Y # Right vectorized scalar +} + +# Write the result matrix R +write(R, $3) From d386407c5a7af9e39324a419570a2c73f8a2a387 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Mon, 24 Feb 2025 16:14:12 +0100 Subject: [PATCH 02/16] added testing for removeUnnecessaryBinaryOperation --- ...eRemoveUnnecessaryBinaryOperationTest.java | 168 ++++++++++++++++++ .../RewriteRemoveUnnecessaryBinaryOperation.R | 51 ++++++ ...ewriteRemoveUnnecessaryBinaryOperation.dml | 47 +++++ 3 files changed, 266 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.R create mode 100644 src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperationTest.java new file mode 100644 index 00000000000..b505e846aaf --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperationTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteRemoveUnnecessaryBinaryOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteRemoveUnnecessaryBinaryOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteRemoveUnnecessaryBinaryOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationDivNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(1, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationDivRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(1, true); // X/1 + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMultRightNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(2, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMultRightRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(2, true); // X*1 + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMultLeftNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(3, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMultLeftRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(3, true); // 1*X + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMinusNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(4, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMinusRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(4, true); // X-0 + } + + @Test + public void testRemoveUnnecessaryBinaryOperationNegMultLeftNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(5, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationNegMultLeftRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(5, true); // -1*X + } + + @Test + public void testRemoveUnnecessaryBinaryOperationNegMultRightNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(6, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationNegMultRightRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(6, true); // X*-1 + } + + private void testRewriteRemoveUnnecessaryBinaryOperation(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(ID == 1) { + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.DIV.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.DIV.toString())); + } + else if(ID == 2 || ID == 3) { + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.MULT.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT.toString())); + } + else if(ID == 4) { + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.MINUS.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.MINUS.toString())); + } + else if(ID == 5 || ID == 6) { + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.MINUS.toString()) && + !heavyHittersContainsString(Opcodes.MULT.toString())); + else + Assert.assertTrue(!heavyHittersContainsString(Opcodes.MINUS.toString()) && + heavyHittersContainsString(Opcodes.MULT.toString())); + } + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.R b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.R new file mode 100644 index 00000000000..79ffbe6072f --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.R @@ -0,0 +1,51 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix and operation type +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) +type = as.integer(args[2]) + +# Perform operations +if(type==1){ + R = X/1 +} else if(type==2){ + R = X*1 +} else if(type==3){ + R = 1*X +} else if(type==4){ + R = X-0 +} else if(type==5){ + R = -1*X +} else if(type==6){ + R = X * -1 +} + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.dml b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.dml new file mode 100644 index 00000000000..6ffe4e543ac --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.dml @@ -0,0 +1,47 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrix X and operation type +X = read($1) +type = $2 + +# Perform operations +if(type==1){ + R = X/1 +} +else if(type==2){ + R = X*1 +} +else if(type==3){ + R = 1*X +} +else if(type==4){ + R = X-0 +} +else if(type==5){ + R = -1*X +} +else if(type==6){ + R = X * -1 +} + +# Write the result matrix R +write(R, $3) From 80c4f894184baf73b35e670c2bcde5cf57432bc6 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Mon, 24 Feb 2025 17:14:42 +0100 Subject: [PATCH 03/16] added testing for simplifyBinaryToUnaryOperation --- ...iteSimplifyBinaryToUnaryOperationTest.java | 132 ++++++++++++++++++ .../RewriteSimplifyBinaryToUnaryOperation.R | 45 ++++++ .../RewriteSimplifyBinaryToUnaryOperation.dml | 38 +++++ 3 files changed, 215 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryToUnaryOperationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryToUnaryOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryToUnaryOperationTest.java new file mode 100644 index 00000000000..f91afbbc464 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryToUnaryOperationTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyBinaryToUnaryOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyBinaryToUnaryOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyBinaryToUnaryOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyBinaryToUnaryOperationAddNoRewrite() { + testRewriteSimplifyBinaryToUnaryOperation(1, false); + } + + @Test + public void testSimplifyBinaryToUnaryOperationAddRewrite() { + testRewriteSimplifyBinaryToUnaryOperation(1, true); // X+X -> X*2 + } + + @Test + public void testSimplifyBinaryToUnaryOperationMultNoRewrite() { + testRewriteSimplifyBinaryToUnaryOperation(2, false); + } + + @Test + public void testSimplifyBinaryToUnaryOperationMultRewrite() { + testRewriteSimplifyBinaryToUnaryOperation(2, true); // X*X -> X² + } + + @Test + public void testSimplifyBinaryToUnaryOperationSignNoRewrite() { + testRewriteSimplifyBinaryToUnaryOperation(3, false); + } + + @Test + public void testSimplifyBinaryToUnaryOperationSignRewrite() { + testRewriteSimplifyBinaryToUnaryOperation(3, true); // (X>0)-(X<0) -> sign(X) + } + + private void testRewriteSimplifyBinaryToUnaryOperation(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(ID == 1) { + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT2.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.PLUS.toString())); + } + else if(ID == 2) { + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.POW2.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT.toString())); + } + else if(ID == 3) { + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.SIGN.toString())); + else + Assert.assertTrue(heavyHittersContainsAllString(Opcodes.GREATER.toString(), Opcodes.LESS.toString(), + Opcodes.MINUS.toString())); + } + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.R new file mode 100644 index 00000000000..91398bc3b1e --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.R @@ -0,0 +1,45 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix and operation type +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) +type = as.integer(args[2]) + +# Perform operations +if(type==1){ + R = X+X +} else if(type==2){ + R = X*X +} else if(type==3){ + R = (X>0) - (X<0) +} + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.dml new file mode 100644 index 00000000000..f60899ff77c --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.dml @@ -0,0 +1,38 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrix X and operation type +X = read($1) +type = $2 + +# Perform operations +if(type==1){ + R = X+X # X+X -> X*2 +} +else if(type==2){ + R = X*X # X*X -> X² +} +else if(type==3){ + R = (X>0)-(X<0) # (X>0)-(X<0) -> sign(X) +} + +# Write the result matrix R +write(R, $3) From 6d2050df7abf289e9bdccf7b259407425db7f208 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Tue, 25 Feb 2025 17:06:14 +0100 Subject: [PATCH 04/16] added testing for canonicalizeMatrixMultScalarAdd --- ...teCanonicalizeMatrixMultScalarAddTest.java | 117 ++++++++++++++++++ .../RewriteCanonicalizeMatrixMultScalarAdd.R | 46 +++++++ ...RewriteCanonicalizeMatrixMultScalarAdd.dml | 37 ++++++ 3 files changed, 200 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAddTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.R create mode 100644 src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAddTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAddTest.java new file mode 100644 index 00000000000..c9a580d5e30 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAddTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteCanonicalizeMatrixMultScalarAddTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteCanonicalizeMatrixMultScalarAdd"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteCanonicalizeMatrixMultScalarAddTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testCanonicalizeMatrixMultScalarAddPosNoRewrite() { + testRewriteCanonicalizeMatrixMultScalarAdd(1, false); + } + + @Test + public void testCanonicalizeMatrixMultScalarAddPosRewrite() { + testRewriteCanonicalizeMatrixMultScalarAdd(1, true); // (z + U%*%V) -> (U%*%V + z) + } + + @Test + public void testCanonicalizeMatrixMultScalarAddNegNoRewrite() { + testRewriteCanonicalizeMatrixMultScalarAdd(2, false); + } + + @Test + public void testCanonicalizeMatrixMultScalarAddNegRewrite() { + testRewriteCanonicalizeMatrixMultScalarAdd(2, true); // (U%*%V - z) -> (U%*%V + (-z)) + } + + private void testRewriteCanonicalizeMatrixMultScalarAdd(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("U"), input("V"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrices + double[][] U = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + double[][] V = getRandomMatrix(rows, cols, -1, 1, 0.60d, 4); + writeInputMatrixWithMTD("U", U, true); + writeInputMatrixWithMTD("V", V, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(ID == 1) { + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.MULT.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT.toString())); + } + else if(ID == 2) { + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.PLUS.toString())); + else + Assert.assertFalse(heavyHittersContainsString(Opcodes.PLUS.toString())); + } + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + +} diff --git a/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.R b/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.R new file mode 100644 index 00000000000..447f2ef7475 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.R @@ -0,0 +1,46 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrices U, V from Matrix Market format files +U = as.matrix(readMM(paste(args[1], "U.mtx", sep=""))) +V = as.matrix(readMM(paste(args[1], "V.mtx", sep=""))) +type = as.integer(args[2]) +eps = 0.5 + +# Perform the operations +if( type == 1 ) { + R = (eps + U%*%V) +} else if( type == 2 ) { + R = (U%*%V - eps) +} + + +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.dml b/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.dml new file mode 100644 index 00000000000..fd8a1db7dee --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrices U, V, and operation type +U = read($1) +V = read($2) +type = $3 +eps = 0.5 + +# Perform operations +if(type==1){ + R = (eps + U%*%V)*1 # (eps + U%*%V) -> (U%*%V + eps) +} +else if(type==2){ + R = (U%*%V - eps) # (U%*%V - eps) -> (U%*%V + (-eps)) +} + +# Write the result matrix R +write(R, $4) From ef5d7293102542d044643947fa89bf6f4ea0a64f Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Tue, 25 Feb 2025 17:42:26 +0100 Subject: [PATCH 05/16] added testing for simplifyReverseOperation --- .../RewriteSimplifyReverseOperationTest.java | 100 ++++++++++++++++++ .../rewrite/RewriteSimplifyReverseOperation.R | 38 +++++++ .../RewriteSimplifyReverseOperation.dml | 29 +++++ 3 files changed, 167 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseOperationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseOperationTest.java new file mode 100644 index 00000000000..ed2f53bc514 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseOperationTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyReverseOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyReverseOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyReverseOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyReverseOperationNoRewrite() { + testRewriteSimplifyReverseOperation(false); + } + + @Test + public void testSimplifyReverseOperationRewrite() { + testRewriteSimplifyReverseOperation(true); + } + + private void testRewriteSimplifyReverseOperation(boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.REV.toString()) && + !heavyHittersContainsAllString(Opcodes.MMULT.toString(), Opcodes.SEQUENCE.toString(), + Opcodes.CTABLEEXPAND.toString())); + else + Assert.assertTrue(heavyHittersContainsAllString(Opcodes.MMULT.toString(), Opcodes.SEQUENCE.toString(), + Opcodes.CTABLEEXPAND.toString())); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.R new file mode 100644 index 00000000000..1bd054186a6 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.R @@ -0,0 +1,38 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix X from Matrix Market format files +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) + +# Perform the operation +R = table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.dml new file mode 100644 index 00000000000..cb3004ff9de --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrix X +X = read($1) + +# Perform operation +R = table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X # Rewrite -> rev(X) + +# Write the result matrix R +write(R, $2) From 8d75d762793616ab39ba9b779c0f10bf707c0506 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Tue, 25 Feb 2025 19:45:16 +0100 Subject: [PATCH 06/16] added testing for simplifyMultiBinaryToBinaryOperation --- ...plifyMultiBinaryToBinaryOperationTest.java | 100 ++++++++++++++++++ ...riteSimplifyMultiBinaryToBinaryOperation.R | 39 +++++++ ...teSimplifyMultiBinaryToBinaryOperation.dml | 30 ++++++ 3 files changed, 169 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperationTest.java new file mode 100644 index 00000000000..395fcf3d3a9 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperationTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyMultiBinaryToBinaryOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyMultiBinaryToBinaryOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyMultiBinaryToBinaryOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyMultiBinaryToBinaryOperationNoRewrite() { + testRewriteSimplifyMultiBinaryToBinaryOperation(false); + } + + @Test + public void testSimplifyMultiBinaryToBinaryOperationRewrite() { + testRewriteSimplifyMultiBinaryToBinaryOperation(true); + } + + private void testRewriteSimplifyMultiBinaryToBinaryOperation(boolean rewrites) { + boolean oldFlag1 = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + boolean oldFlag2 = OptimizerUtils.ALLOW_OPERATOR_FUSION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), input("Y"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites; + + // create and write matrices + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.60d, 4); + writeInputMatrixWithMTD("X", X, true); + writeInputMatrixWithMTD("Y", Y, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.MINUS1_MULT.toString())); + else + Assert.assertTrue(heavyHittersContainsAllString(Opcodes.MINUS.toString(), Opcodes.MULT.toString())); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag1; + OptimizerUtils.ALLOW_OPERATOR_FUSION = oldFlag2; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.R new file mode 100644 index 00000000000..c366852af71 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.R @@ -0,0 +1,39 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrices X, Y from Matrix Market format files +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) +Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep=""))) + +# Perform the operation +R = 1-(X*Y) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.dml new file mode 100644 index 00000000000..2ec9e0c5bb7 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrices X, Y +X = read($1) +Y = read($2) + +# Perform operation +R = 1-(X*Y) + +# Write the result matrix R +write(R, $3) From bbbe0f561d89db17c92d964a014644fb112e1150 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Tue, 25 Feb 2025 20:27:21 +0100 Subject: [PATCH 07/16] added testing for simplifyUnaryAggReorgOperation --- ...iteSimplifyUnaryAggReorgOperationTest.java | 95 +++++++++++++++++++ .../RewriteSimplifyUnaryAggReorgOperation.R | 38 ++++++++ .../RewriteSimplifyUnaryAggReorgOperation.dml | 29 ++++++ 3 files changed, 162 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryAggReorgOperationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryAggReorgOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryAggReorgOperationTest.java new file mode 100644 index 00000000000..8c58a21a155 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryAggReorgOperationTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyUnaryAggReorgOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyUnaryAggReorgOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyUnaryAggReorgOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyUnaryAggReorgOperationNoRewrite() { + testRewriteSimplifyUnaryAggReorgOperation(false); + } + + @Test + public void testSimplifyUnaryAggReorgOperationRewrite() { + testRewriteSimplifyUnaryAggReorgOperation(true); // sum(t(X)) -> sum(X) + } + + private void testRewriteSimplifyUnaryAggReorgOperation(boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLScalarFromOutputDir("R"); + HashMap rfile = readRScalarFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.TRANSPOSE.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.TRANSPOSE.toString())); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.R new file mode 100644 index 00000000000..0faf8769513 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.R @@ -0,0 +1,38 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix X from Matrix Market format files +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) + +# Perform the operation +R = sum(t(X)) + +write(R, paste(args[2], "R" ,sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.dml new file mode 100644 index 00000000000..abff755d735 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrix X +X = read($1) + +# Perform operation +R = sum(t(X)) + +# Write the result matrix R +write(R, $2) From e95c5b5a7bc7cc01684497a2115d57c8d4b62028 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Wed, 26 Feb 2025 14:46:42 +0100 Subject: [PATCH 08/16] added testing for simplifyBinaryMatrixScalarOperation --- ...mplifyBinaryMatrixScalarOperationTest.java | 125 ++++++++++++++++++ ...writeSimplifyBinaryMatrixScalarOperation.R | 50 +++++++ ...iteSimplifyBinaryMatrixScalarOperation.dml | 42 ++++++ 3 files changed, 217 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperationTest.java new file mode 100644 index 00000000000..8772ce923c0 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperationTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyBinaryMatrixScalarOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyBinaryMatrixScalarOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyBinaryMatrixScalarOperationTest.class.getSimpleName() + "/"; + + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyBinaryMatrixScalarOperationMMNoRewrite() { + testRewriteSimplifyBinaryMatrixScalarOperation(1, false); + } + + @Test + public void testSimplifyBinaryMatrixScalarOperationMMRewrite() { + testRewriteSimplifyBinaryMatrixScalarOperation(1, true); //as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y) + } + + @Test + public void testSimplifyBinaryMatrixScalarOperationRightNoRewrite() { + testRewriteSimplifyBinaryMatrixScalarOperation(2, false); + } + + @Test + public void testSimplifyBinaryMatrixScalarOperationRightRewrite() { + testRewriteSimplifyBinaryMatrixScalarOperation(2, true); // as.scalar(X*s) -> as.scalar(X) * s + } + + @Test + public void testSimplifyBinaryMatrixScalarOperationLeftNoRewrite() { + testRewriteSimplifyBinaryMatrixScalarOperation(3, false); + } + + @Test + public void testSimplifyBinaryMatrixScalarOperationLeftRewrite() { + testRewriteSimplifyBinaryMatrixScalarOperation(3, true); // as.scalar(s*X) -> s * as.scalar(X) + } + + private void testRewriteSimplifyBinaryMatrixScalarOperation(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLScalarFromOutputDir("R"); + HashMap rfile = readRScalarFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + long numCastDts = Statistics.getCPHeavyHitterCount(Opcodes.CAST_AS_SCALAR.toString()); + if(ID == 1) { + if(rewrites) + Assert.assertEquals(2, numCastDts); + else + Assert.assertEquals(1, numCastDts); + } + else if(ID == 2) { + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT2.toString())); + } + else if(ID == 3) { + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.NM.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT.toString())); + } + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.R new file mode 100644 index 00000000000..2aa799dd180 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.R @@ -0,0 +1,50 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read operation type +type = as.integer(args[2]) + +# Create variables +X = matrix(1,1,1) +Y = matrix(2,1,1) +s = 2 + +# Perform the operations +if(type==1){ + R = as.numeric(X*Y) +} else if(type==2){ + R = as.numeric(X*s) +} else if(type==3){ + R = as.numeric(s*X) +} + + +write(R, paste(args[3], "R" ,sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.dml new file mode 100644 index 00000000000..18da3c4f4f3 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.dml @@ -0,0 +1,42 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read operation type +type = $1 + +# Create variables +X = matrix(1,1,1) +Y = matrix(2,1,1) +s = 2 + +# Perform operations +if(type==1){ + R = as.scalar(X*Y) # -> as.scalar(X) * as.scalar(Y) +} +else if(type==2){ + R = as.scalar(X*s) # -> as.scalar(X) * s +} +else if(type==3){ + R = as.scalar(s*X)*1 # -> s * as.scalar(X) +} + +# Write the result +write(R, $2) From 0d4254552b55515b5b61836793a241213138aacb Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Thu, 27 Feb 2025 00:07:00 +0100 Subject: [PATCH 09/16] added testing for simplifyUnaryPPredOperation --- ...ewriteSimplifyUnaryPPredOperationTest.java | 409 ++++++++++++++++++ .../RewriteSimplifyUnaryPPredOperation.R | 101 +++++ .../RewriteSimplifyUnaryPPredOperation.dml | 127 ++++++ 3 files changed, 637 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryPPredOperationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryPPredOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryPPredOperationTest.java new file mode 100644 index 00000000000..940ec77f8c4 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryPPredOperationTest.java @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyUnaryPPredOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyUnaryPPredOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyUnaryPPredOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + /** + * (1) Rewrites for Less + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationLessAbsNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(1, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(1, true); // abs(X (X (X (X (X (X (X<=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualRoundNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(7, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(7, true); // round(X<=Y) -> (X<=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(8, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualCeilRewrite() { + testRewriteSimplifyUnaryPPredOperation(8, true); // ceil(X<=Y) -> (X<=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(9, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(9, true); // floor(X<=Y) -> (X<=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(10, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(10, true); // sign(X<=Y) -> (X<=Y) + } + + /** + * (3) Rewrites for Greater + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationGreaterAbsNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(11, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(11, true); // abs(X>Y) -> (X>Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterRoundNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(12, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(12, true); // round(X>Y) -> (X>Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(13, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterCeilRewrite() { + testRewriteSimplifyUnaryPPredOperation(13, true); // ceil(X>Y) -> (X>Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(14, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(14, true); // floor(X>Y) -> (X>Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(15, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(15, true); // sign(X>Y) -> (X>Y) + } + + /** + * (4) Rewrites for GreaterEqual + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualAbsNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(16, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(16, true); // abs(X>=Y) -> (X>=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualRoundNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(17, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(17, true); // round(X>=Y) -> (X>=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(18, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualCeilRewrite() { + testRewriteSimplifyUnaryPPredOperation(18, true); // ceil(X>=Y) -> (X>=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(19, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(19, true); // floor(X>=Y) -> (X>=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(20, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(20, true); // sign(X>=Y) -> (X>=Y) + } + + /** + * (5) Rewrites for Equal + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationEqualAbsNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(21, false); + } + + @Test + public void testSimplifyUnaryPPredOperationEqualAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(21, true); // abs(X==Y) -> (X==Y) + } + + @Test + public void testSimplifyUnaryPPredOperationEqualRoundNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(22, false); + } + + @Test + public void testSimplifyUnaryPPredOperationEqualRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(22, true); // round(X==Y) -> (X==Y) + } + + @Test + public void testSimplifyUnaryPPredOperationEqualCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(23, false); + } + + @Test + public void testSimplifyUnaryPPredOperationEqualCeilRewrite() { + testRewriteSimplifyUnaryPPredOperation(23, true); // ceil(X==Y) -> (X==Y) + } + + @Test + public void testSimplifyUnaryPPredOperationEqualFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(24, false); + } + + @Test + public void testSimplifyUnaryPPredOperationEqualFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(24, true); // floor(X==Y) -> (X==Y) + } + + @Test + public void testSimplifyUnaryPPredOperationEqualSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(25, false); + } + + @Test + public void testSimplifyUnaryPPredOperationEqualSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(25, true); // sign(X==Y) -> (X==Y) + } + + /** + * (6) Rewrites for NotEqual + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationNotEqualAbsNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(26, false); + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(26, true); // abs(X!=Y) -> (X!=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualRoundNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(27, false); + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(27, true); // round(X!=Y) -> (X!=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(28, false); + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualCeilRewrite() { + testRewriteSimplifyUnaryPPredOperation(28, true); // ceil(X!=Y) -> (X!=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(29, false); + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(29, true); // floor(X!=Y) -> (X!=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(30, false); + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(30, true); // sign(X!=Y) -> (X!=Y) + } + + private void testRewriteSimplifyUnaryPPredOperation(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), input("Y"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrices + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.60d, 3); + writeInputMatrixWithMTD("X", X, true); + writeInputMatrixWithMTD("Y", Y, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.ABS.toString(), Opcodes.ROUND.toString(), + Opcodes.CEIL.toString(), Opcodes.FLOOR.toString(), Opcodes.SIGN.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.ABS.toString(), Opcodes.ROUND.toString(), + Opcodes.CEIL.toString(), Opcodes.FLOOR.toString(), Opcodes.SIGN.toString())); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.R new file mode 100644 index 00000000000..f7b27a5db11 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.R @@ -0,0 +1,101 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrices and operation type +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) +Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep=""))) +type = as.integer(args[2]) + +# Perform operations +# (1) Less +if(type==1){ + R = abs(XY) +} else if(type==12){ + R = round(X>Y) +} else if(type==13){ + R = ceiling(X>Y) +} else if(type==14){ + R = floor(X>Y) +} else if(type==15){ + R = sign(X>Y) +} else if(type==16){ # (4) Greater-Equal + R = abs(X>=Y) +} else if(type==17){ + R = round(X>=Y) +} else if(type==18){ + R = ceiling(X>=Y) +} else if(type==19){ + R = floor(X>=Y) +} else if(type==20){ + R = sign(X>=Y) +} else if(type==21){ # (5) Equal + R = abs(X==Y) +} else if(type==22){ + R = round(X==Y) +} else if(type==23){ + R = ceiling(X==Y) +} else if(type==24){ + R = floor(X==Y) +} else if(type==25){ + R = sign(X==Y) +} else if(type==26){ # (6) Not-Equal + R = abs(X!=Y) +} else if(type==27){ + R = round(X!=Y) +} else if(type==28){ + R = ceiling(X!=Y) +} else if(type==29){ + R = floor(X!=Y) +} else if(type==30){ + R = sign(X!=Y) +} + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.dml new file mode 100644 index 00000000000..e016bc65e3f --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.dml @@ -0,0 +1,127 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrices X, Y, and operation type +X = read($1) +Y = read($2) +type = $3 + +# Perform operations +# (1) Less +if(type==1){ + R = abs(XY) +} +else if(type==12){ + R = round(X>Y) +} +else if(type==13){ + R = ceil(X>Y) +} +else if(type==14){ + R = floor(X>Y) +} +else if(type==15){ + R = sign(X>Y) +} +# (4) Greater-Equal +else if(type==16){ + R = abs(X>=Y) +} +else if(type==17){ + R = round(X>=Y) +} +else if(type==18){ + R = ceil(X>=Y) +} +else if(type==19){ + R = floor(X>=Y) +} +else if(type==20){ + R = sign(X>=Y) +} +# (5) Equal +else if(type==21){ + R = abs(X==Y) +} +else if(type==22){ + R = round(X==Y) +} +else if(type==23){ + R = ceil(X==Y) +} +else if(type==24){ + R = floor(X==Y) +} +else if(type==25){ + R = sign(X==Y) +} +# (6) Not-Equal +else if(type==26){ + R = abs(X!=Y) +} +else if(type==27){ + R = round(X!=Y) +} +else if(type==28){ + R = ceil(X!=Y) +} +else if(type==29){ + R = floor(X!=Y) +} +else if(type==30){ + R = sign(X!=Y) +} + + +# Write the result matrix R +write(R, $4) From 1fc7a67630f17554bdbc5141ddd6e454390974a0 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Thu, 27 Feb 2025 15:39:35 +0100 Subject: [PATCH 10/16] added testing for simplifyTransposedAppend --- .../RewriteSimplifyTransposedAppendTest.java | 108 ++++++++++++++++++ .../rewrite/RewriteSimplifyTransposedAppend.R | 45 ++++++++ .../RewriteSimplifyTransposedAppend.dml | 36 ++++++ 3 files changed, 189 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposedAppendTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposedAppendTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposedAppendTest.java new file mode 100644 index 00000000000..2c6721bcec6 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposedAppendTest.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyTransposedAppendTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyTransposedAppend"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyTransposedAppendTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyTransposedAppendTransposeCBindNoRewrite() { + testRewriteSimplifyTransposedAppend(1, false); + } + + @Test + public void testSimplifyTransposedAppendTransposeCBindRewrite() { + testRewriteSimplifyTransposedAppend(1, true); // t(cbind(t(A),t(B))) --> rbind(A,B) + } + + @Test + public void testSimplifyTransposedAppendTransposeRBindNoRewrite() { + testRewriteSimplifyTransposedAppend(2, false); + } + + @Test + public void testSimplifyTransposedAppendTransposeRBindRewrite() { + testRewriteSimplifyTransposedAppend(2, true); // t(rbind(t(A),t(B))) --> cbind(A,B) + } + + private void testRewriteSimplifyTransposedAppend(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("A"), input("B"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrices + double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + double[][] B = getRandomMatrix(rows, cols, -1, 1, 0.60d, 3); + writeInputMatrixWithMTD("A", A, true); + writeInputMatrixWithMTD("B", B, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.TRANSPOSE.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.TRANSPOSE.toString())); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.R b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.R new file mode 100644 index 00000000000..c0c59dcfc0a --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.R @@ -0,0 +1,45 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrices and operation type +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +B = as.matrix(readMM(paste(args[1], "B.mtx", sep=""))) +type = as.integer(args[2]) + + +# Perform operations +if(type==1){ + R = t(cbind(t(A),t(B))) +} else if(type==2) { + R = t(rbind(t(A),t(B))) +} + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.dml new file mode 100644 index 00000000000..754613071b9 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrices A, B, and operation type +A = read($1) +B = read($2) +type = $3 + +# Perform operations +if(type==1){ + R = t(cbind(t(A),t(B))) # -> rbind(A, B) +} +else if(type==2) { + R = t(rbind(t(A),t(B))) # -> cbind(A, B) +} + +# Write the result matrix R +write(R, $4) From d741badef1896eeedda3a02bf04323104151e8a2 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Thu, 27 Feb 2025 17:06:01 +0100 Subject: [PATCH 11/16] added testing for fuseOrderOperationChain --- .../RewriteFuseOrderOperationChainTest.java | 98 +++++++++++++++++++ .../rewrite/RewriteFuseOrderOperationChain.R | 43 ++++++++ .../RewriteFuseOrderOperationChain.dml | 29 ++++++ 3 files changed, 170 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseOrderOperationChainTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.R create mode 100644 src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseOrderOperationChainTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseOrderOperationChainTest.java new file mode 100644 index 00000000000..281fc75cd30 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseOrderOperationChainTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteFuseOrderOperationChainTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteFuseOrderOperationChain"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteFuseOrderOperationChainTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testFuseOrderOperationChainNoRewrite() { + testRewriteFuseOrderOperationChain(false); + } + + @Test + public void testFuseOrderOperationChainRewrite() { + testRewriteFuseOrderOperationChain(true); + } + + private void testRewriteFuseOrderOperationChain(boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrices + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + long numOrder = Statistics.getCPHeavyHitterCount(Opcodes.SORT.toString()); + if(rewrites) + Assert.assertEquals(numOrder, 1); + else + Assert.assertEquals(numOrder, 2); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.R b/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.R new file mode 100644 index 00000000000..b8c4e429860 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.R @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) + +# Perform operation +# 1) Sort X by column 2 +temp = X[order(X[, 2]), ] + +# 2) Sort the result by column 1 +R = temp[order(temp[, 1]), ] + + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.dml b/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.dml new file mode 100644 index 00000000000..746b3a796a0 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrix +X = read($1) + +# Perform operation +R = order(target=order(target=X, by=2), by=1) + +# Write the result matrix R +write(R, $2) From 0117515cb1a5ab4d6a4bb41edb5fa1d283403064 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Thu, 27 Feb 2025 20:54:13 +0100 Subject: [PATCH 12/16] added testing for removeUnnecessaryReorgOperation --- .../org/apache/sysds/hops/OptimizerUtils.java | 2 +- ...teRemoveUnnecessaryReorgOperationTest.java | 107 ++++++++++++++++++ .../RewriteRemoveUnnecessaryReorgOperation.R | 44 +++++++ ...RewriteRemoveUnnecessaryReorgOperation.dml | 35 ++++++ 4 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.R create mode 100644 src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.dml diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java index a3161c57230..30df23dba2b 100644 --- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java @@ -199,7 +199,7 @@ public enum MemoryManager { public static boolean ALLOW_SUM_PRODUCT_REWRITES2 = true; /** - * Enables additional mmchain optimizations. in the future, this might be merged with + * Enables additional mmchain optimizations. In the future, this might be merged with * ALLOW_SUM_PRODUCT_REWRITES. */ public static boolean ALLOW_ADVANCED_MMCHAIN_REWRITES = false; diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java new file mode 100644 index 00000000000..44a476da2c7 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteRemoveUnnecessaryReorgOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteRemoveUnnecessaryReorgOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteRemoveUnnecessaryReorgOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testRemoveUnnecessaryReorgOperationTransposeNoRewrite() { + testRewriteRemoveUnnecessaryReorgOperation(1, false); + } + + @Test + public void testRemoveUnnecessaryReorgOperationTransposeRewrite() { + testRewriteRemoveUnnecessaryReorgOperation(1, true); + } + + @Test + public void testRemoveUnnecessaryReorgOperationReverseNoRewrite() { + testRewriteRemoveUnnecessaryReorgOperation(2, false); + } + + @Test + public void testRemoveUnnecessaryReorgOperationReverseRewrite() { + testRewriteRemoveUnnecessaryReorgOperation(2, true); + } + + private void testRewriteRemoveUnnecessaryReorgOperation(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.TRANSPOSE.toString(), Opcodes.REV.toString())); + else + Assert.assertTrue((heavyHittersContainsString(Opcodes.MULT.toString(), Opcodes.REV.toString()))); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.R b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.R new file mode 100644 index 00000000000..0fec789adac --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.R @@ -0,0 +1,44 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix and operation type +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) +type = as.integer(args[2]) + + +# Perform operations +if(type==1){ + R = t(t(X)) +} else if(type==2) { + R = X[nrow(X):1, ][nrow(X):1, ] + +} +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.dml b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.dml new file mode 100644 index 00000000000..90774083f26 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.dml @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrix and operation type +X = read($1) +type = $2 + +# Perform operations +if(type==1){ + R = t(t(X))*1 # t(t(X)) -> X +} +else if(type==2) { + R = rev(rev(X)) # rev(rev(X)) -> X +} + +# Write the result matrix R +write(R, $3) From cde4a81d7e7207dbde7ffc5a5aa72531f013f1e6 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Thu, 27 Feb 2025 21:19:30 +0100 Subject: [PATCH 13/16] added testing for removeUnnecessaryMinus --- .../RewriteRemoveUnnecessaryMinusTest.java | 96 +++++++++++++++++++ ...teRemoveUnnecessaryReorgOperationTest.java | 1 - .../rewrite/RewriteRemoveUnnecessaryMinus.R | 39 ++++++++ .../rewrite/RewriteRemoveUnnecessaryMinus.dml | 29 ++++++ 4 files changed, 164 insertions(+), 1 deletion(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryMinusTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.R create mode 100644 src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryMinusTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryMinusTest.java new file mode 100644 index 00000000000..dbd1a31ce42 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryMinusTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteRemoveUnnecessaryMinusTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteRemoveUnnecessaryMinus"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteRemoveUnnecessaryMinusTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testRemoveUnnecessaryMinusNoRewrite() { + testRewriteRemoveUnnecessaryMinus(false); + } + + @Test + public void testRemoveUnnecessaryMinusRewrite() { + testRewriteRemoveUnnecessaryMinus(true); + } + + private void testRewriteRemoveUnnecessaryMinus(boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.MULT.toString(), Opcodes.POW.toString())); + else + Assert.assertTrue((heavyHittersContainsString(Opcodes.MULT.toString(), Opcodes.POW.toString()))); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java index 44a476da2c7..733c9da7e62 100644 --- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java @@ -80,7 +80,6 @@ private void testRewriteRemoveUnnecessaryReorgOperation(int ID, boolean rewrites rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; - OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES = rewrites; // create and write matrix double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.R b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.R new file mode 100644 index 00000000000..5c4c14ab136 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.R @@ -0,0 +1,39 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) + + +# Perform operation +R = -(-X) + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.dml b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.dml new file mode 100644 index 00000000000..fa2262529e5 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrix +X = read($1) + +# Perform operation +R = -(-X) # -(-X) -> X + +# Write the result matrix R +write(R, $2) From 34f21ede022a897c2ab8ba5072c6648d5c43b5d9 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Sat, 1 Mar 2025 23:16:52 +0100 Subject: [PATCH 14/16] added testing for simplifyOuterSeqExpand --- .../RewriteSimplifyOuterSeqExpandTest.java | 104 ++++++++++++++++++ .../rewrite/RewriteSimplifyOuterSeqExpand.R | 45 ++++++++ .../rewrite/RewriteSimplifyOuterSeqExpand.dml | 36 ++++++ 3 files changed, 185 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyOuterSeqExpandTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyOuterSeqExpandTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyOuterSeqExpandTest.java new file mode 100644 index 00000000000..a9c3e4ffe71 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyOuterSeqExpandTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyOuterSeqExpandTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyOuterSeqExpand"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyOuterSeqExpandTest.class.getSimpleName() + "/"; + + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyOuterSeqExpandRightNoRewrite() { + testRewriteSimplifyOuterSeqExpand(1, false); + } + + // outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) + @Test + public void testSimplifyOuterSeqExpandRightRewrite() { + testRewriteSimplifyOuterSeqExpand(1, true); + } + + @Test + public void testSimplifyOuterSeqExpandLeftNoRewrite() { + testRewriteSimplifyOuterSeqExpand(2, false); + } + + // outer(seq(1,m), t(v), "==") -> rexpand(m, max=v, dir=row, ignore=true, cast=false) + @Test + public void testSimplifyOuterSeqExpandLeftRewrite() { + testRewriteSimplifyOuterSeqExpand(2, true); + } + + private void testRewriteSimplifyOuterSeqExpand(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.REXPAND.toString())); + else + Assert.assertTrue( + (heavyHittersContainsString(Opcodes.SEQUENCE.toString(), Opcodes.TRANSPOSE.toString()))); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R b/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R new file mode 100644 index 00000000000..9c99e16ff6c --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R @@ -0,0 +1,45 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read operation type +type = as.integer(args[2]) +m = 5 +v = matrix(1, 6, 1) + +# Perform operations +if(type==1){ + R = outer(v, 1:m, "==") +} else if(type==2){ + R = outer(1:m, as.vector(v), "==") + +} + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.dml new file mode 100644 index 00000000000..1e5d0ef9c67 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read operation type +type = $1 +m = 5 +v = matrix(1, 6, 1) + +# Perform operations +if(type==1){ + R = outer(v, t(seq(1,m)), "==") +} +else if(type==2){ + R = outer(seq(1,m), t(v), "==") +} + +# Write the result matrix R +write(R, $2) From 0bcd28ae1150dbc8e8d69da489501b812497107d Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Sun, 2 Mar 2025 00:45:27 +0100 Subject: [PATCH 15/16] added testing for simplifyCumsumColOrFullAggregates --- ...SimplifyCumsumColOrFullAggregatesTest.java | 96 +++++++++++++++++++ ...RewriteSimplifyCumsumColOrFullAggregates.R | 38 ++++++++ ...writeSimplifyCumsumColOrFullAggregates.dml | 29 ++++++ 3 files changed, 163 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregatesTest.java create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.R create mode 100644 src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.dml diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregatesTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregatesTest.java new file mode 100644 index 00000000000..c77db1c148f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregatesTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyCumsumColOrFullAggregatesTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyCumsumColOrFullAggregates"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyCumsumColOrFullAggregatesTest.class.getSimpleName() + "/"; + + private static final int rows = 10; + private static final int cols = 10; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyCumsumColOrFullAggregatesNoRewrite() { + testRewriteSimplifyCumsumColOrFullAggregates(false); + } + + @Test + public void testSimplifyCumsumColOrFullAggregatesRewrite() { + testRewriteSimplifyCumsumColOrFullAggregates(true); //colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1)) + } + + private void testRewriteSimplifyCumsumColOrFullAggregates(boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertTrue((heavyHittersContainsString(Opcodes.SEQUENCE.toString()))); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.UCUMKP.toString(), Opcodes.UACKP.toString())); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.R b/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.R new file mode 100644 index 00000000000..e01a6f28085 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.R @@ -0,0 +1,38 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) + +# Perform operation +R = t(as.matrix(colSums(apply(X, 2, cumsum)))) + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.dml new file mode 100644 index 00000000000..8af853c84dc --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrix +X = read($1) + +# Perform operation +R = colSums(cumsum(X)) + +# Write the result matrix R +write(R, $2) From 27f09d8077c54d1ef3c48ba69f9d334bdf6b4867 Mon Sep 17 00:00:00 2001 From: ReneEnjilian Date: Sun, 2 Mar 2025 13:19:12 +0100 Subject: [PATCH 16/16] fix test for simplifyOuterSeqExpand in R file --- .../scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R b/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R index 9c99e16ff6c..bec60ac994f 100644 --- a/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R @@ -35,7 +35,7 @@ v = matrix(1, 6, 1) # Perform operations if(type==1){ - R = outer(v, 1:m, "==") + R = outer(as.vector(v), 1:m, "==") } else if(type==2){ R = outer(1:m, as.vector(v), "==")