diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateCell.java index 8d0dc273eb3..c909d86cd18 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateCell.java @@ -66,7 +66,7 @@ public class TemplateCell extends TemplateBase { private static final AggOp[] SUPPORTED_AGG = - new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX}; + new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.PROD}; public TemplateCell() { super(TemplateType.CELL); diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java index 4ac38b1487c..c42ea6c8580 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java @@ -67,7 +67,7 @@ public class TemplateRow extends TemplateBase { - private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN}; + private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.PROD}; private static final OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{ OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN, OpOp1.SIN, OpOp1.COS, OpOp1.TAN, OpOp1.ASIN, OpOp1.ACOS, OpOp1.ATAN, OpOp1.SINH, OpOp1.COSH, OpOp1.TANH, diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java index 6c3986bd93c..d59f50a0e85 100644 --- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java +++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java @@ -84,6 +84,7 @@ public enum AggOp { SUM_SQ, MIN, MAX, + PROD } protected final CellType _type; @@ -332,12 +333,16 @@ private long executeDense(DenseBlock a, SideInput[] b, double[] scalars, else if( _type == CellType.ROW_AGG ) { if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ ) return executeDenseRowAggSum(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix); + else if(_aggOp == AggOp.PROD) + return executeDenseRowProd(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix); else return executeDenseRowAggMxx(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix); } else if( _type == CellType.COL_AGG ) { if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ ) return executeDenseColAggSum(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix); + else if(_aggOp == AggOp.PROD) + return executeDenseColProd(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix); else return executeDenseColAggMxx(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix); } @@ -372,12 +377,16 @@ private long executeSparse(SparseBlock sblock, SideInput[] b, double[] scalars, else if( _type == CellType.ROW_AGG ) { if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ ) return executeSparseRowAggSum(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix); + else if( _aggOp == AggOp.PROD) + return executeSparseRowProd(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix); else return executeSparseRowAggMxx(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix); } else if( _type == CellType.COL_AGG ) { if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ ) return executeSparseColAggSum(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix); + else if( _aggOp == AggOp.PROD) + return executeSparseColProd(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix); else return executeSparseColAggMxx(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix); } @@ -930,7 +939,215 @@ private double executeSparseAggMxx(SparseBlock sblock, SideInput[] b, double[] s } return ret; } - + + private long executeDenseRowProd(DenseBlock a, SideInput[] b, double[] scalars, + DenseBlock c, int m, int n, boolean sparseSafe, int rl, int ru, long rix) + { + // single block output + double[] lc = c.valuesAt(0); + long lnnz = 0; + if(a == null && !sparseSafe) { + for(int i = rl; i < ru; i++) { + for(int j = 0; j < n; j++) { + if(j == 0) { + lc[i] = genexec(0, b, scalars, m, n, rix+i, i, j); + } else if(lc[i] != 0) { + lc[i] *= genexec(0, b, scalars, m, n, rix+i, i, j); + } else { + break; + } + } + lnnz += (lc[i]!=0) ? 1 : 0; + } + } + else if( a != null ) { + for(int i = rl; i < ru; i++) { + double[] avals = a.values(i); + int aix = a.pos(i); + for(int j = 0; j < n; j++) { + double aval = avals[aix + j]; + if(aval != 0 || !sparseSafe) { + if(j == 0) { + lc[i] = genexec(aval, b, scalars, m, n, rix+i, i, j); + } else if(lc[i] != 0) { + lc[i] *= genexec(aval, b, scalars, m, n, rix+i, i, j); + } else { + break; + } + } else { + break; + } + } + lnnz += (lc[i] != 0) ? 1 : 0; + } + } + return lnnz; + } + + private long executeDenseColProd(DenseBlock a, SideInput[] b, double[] scalars, + DenseBlock c, int m, int n, boolean sparseSafe, int rl, int ru, long rix) + { + double[] lc = c.valuesAt(0); + //track the cols that have a zero + boolean[] zeroFlag = new boolean[n]; + if(a == null && !sparseSafe) { + for(int i = rl; i < ru; i++) { + for(int j = 0; j < n; j++) { + if(!zeroFlag[j]) { + if(i == 0) { + lc[j] = genexec(0, b, scalars, m, n, rix+i, i, j); + } else if(lc[j] != 0) { + lc[j] *= genexec(0, b, scalars, m, n, rix+i, i, j); + } else { + zeroFlag[j] = true; + } + } + } + } + } + else if(a != null) { + for(int i = rl; i < ru; i++) { + double[] avals = a.values(i); + int aix = a.pos(i); + for(int j = 0; j < n; j++) { + if(!zeroFlag[j]) { + double aval = avals[aix + j]; + if(aval != 0 || !sparseSafe) { + if(i == 0) { + lc[j] = genexec(aval, b, scalars, m, n, rix + i, i, j); + } else if(lc[j] != 0) { + lc[j] *= genexec(aval, b, scalars, m, n, rix + i, i, j); + } else { + zeroFlag[j] = true; + } + } + } else { + zeroFlag[j] = true; + } + } + } + } + return -1; + } + + private long executeSparseRowProd(SparseBlock sblock, SideInput[] b, double[] scalars, + MatrixBlock out, int m, int n, boolean sparseSafe, int rl, int ru, long rix) + { + double[] c = out.getDenseBlockValues(); + long lnnz = 0; + for(int i = rl; i < ru; i++) { + int lastj = -1; + if(sblock != null && !sblock.isEmpty(i)) { + int apos = sblock.pos(i); + int alen = sblock.size(i); + int[] aix = sblock.indexes(i); + double[] avals = sblock.values(i); + for(int k = apos; k < apos+alen; k++) { + if(!sparseSafe) { + for(int j=lastj+1; j