Skip to content

Commit 63aac6b

Browse files
committed
add test case
1 parent 9c91268 commit 63aac6b

5 files changed

Lines changed: 93 additions & 6 deletions

File tree

src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ public class RowAggTmplTest extends AutomatedTestBase
8989
private static final String TEST_NAME46 = TEST_NAME+"46"; //conv2d(X - mean(X), F1) + conv2d(X - mean(X), F2);
9090
private static final String TEST_NAME47 = TEST_NAME+"47"; //sum(X + rowVars(X))
9191
private static final String TEST_NAME48 = TEST_NAME+"48"; //sum(rowVars(X))
92-
private static final String TEST_NAME49 = TEST_NAME+"49";
92+
private static final String TEST_NAME49 = TEST_NAME+"49"; //X*rowSums(K*v)*X
93+
private static final String TEST_NAME50 = TEST_NAME+"50"; //(abs(A)*B)+(B*v)
9394

9495
private static final String TEST_DIR = "functions/codegen/";
9596
private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/";
@@ -101,7 +102,7 @@ public class RowAggTmplTest extends AutomatedTestBase
101102
@Override
102103
public void setUp() {
103104
TestUtils.clearAssertionInformation();
104-
for(int i=1; i<=49; i++)
105+
for(int i=1; i<=50; i++)
105106
addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) );
106107
}
107108

@@ -831,6 +832,9 @@ public void testCodegenRowAgg48SP() {
831832
@Test
832833
public void testCodegenRowAgg49CP() {testCodegenIntegration( TEST_NAME49, false, ExecType.CP );}
833834

835+
@Test
836+
public void testCodegenRowAgg50CP() {testCodegenIntegration( TEST_NAME50, false, ExecType.CP );}
837+
834838
private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
835839
{
836840
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;

src/test/scripts/functions/codegen/rowAggPattern49.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ K = rbind(Z, Y, Y, Y, Y, Y, Y, Y, Y)
3939

4040
# S = (X < rowSums(X*K))
4141
# S = X*rowMins(K)*X
42-
# S = X*rowSums(K*v)*X
43-
S = (X*v)/rowSums(X*v)
42+
S = X*rowSums(K*v)*X
43+
# S = (X*v)/rowSums(X*v)
4444
# S = abs((X*v)/rowSums(X*v))
4545
# S = (X/v)+rowMeans(X-v)
4646
# S = (X*v)+rowSums(X*v)

src/test/scripts/functions/codegen/rowAggPattern49.dml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ while(FALSE) { }
3838
# S = B * rowSums(v)
3939
# S = (X < rowSums(X*K))
4040
# S = X*(k>1)*X
41-
# S = X*rowSums(K*v)*X
42-
S = (X*v)/rowSums(X*v)
41+
S = X*rowSums(K*v)*X
42+
# S = (X*v)/rowSums(X*v)
4343
# S = abs((X*v)/rowSums(X*v))
4444
# S = (X/v)+rowMeans(X-v)
4545
# S = (X*v)+rowSums(X*v)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
args<-commandArgs(TRUE)
23+
options(digits=22)
24+
library("Matrix")
25+
# library("matrixStats")
26+
27+
W = matrix(seq(28,29), 1, 2)
28+
J = matrix(0, 1, 8)
29+
Z = cbind(J, W, J)
30+
Y = matrix(0, 10, 18)
31+
X = rbind(Z, Y, Y, Y, Y, Y, Y, Y, Y)
32+
v = seq(1,81)
33+
v1 = seq(20, 37)
34+
W = matrix(seq(13,14), 1, 2)
35+
J = matrix(0, 1, 8)
36+
Z= cbind(J, W, J)
37+
Y = matrix(0, 10, 18)
38+
K = rbind(Z, Y, Y, Y, Y, Y, Y, Y, Y)
39+
40+
S = (abs(X)*K)+rowSums(X*v)
41+
42+
43+
writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep=""));
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
W = matrix(seq(28,29), 1, 2)
23+
J = matrix(0, 1, 8)
24+
Z= cbind(J, W, J)
25+
Y = matrix(0, 10, 18)
26+
X = rbind(Z, Y, Y, Y, Y, Y, Y, Y, Y)
27+
v = seq(1,81)
28+
v1 = seq(20, 37)
29+
W = matrix(seq(13,14), 1, 2)
30+
J = matrix(0, 1, 8)
31+
Z= cbind(J, W, J)
32+
Y = matrix(0, 10, 18)
33+
K = rbind(Z, Y, Y, Y, Y, Y, Y, Y, Y)
34+
while(FALSE) { }
35+
36+
37+
S = (abs(X)*K)+rowSums(X*v)
38+
39+
40+
write(S,$1)

0 commit comments

Comments
 (0)