From db8820db5eae319ba6a9b54dd7ce2e9a989a9b53 Mon Sep 17 00:00:00 2001 From: e-strauss Date: Thu, 6 Feb 2025 14:40:52 +0100 Subject: [PATCH] [SYSTEMDS-3541] Exploratory workload-aware compression on intermediates Added a config option for aggressive compression and extended the compression workload analyzer to detect aggregation operations and binary matrix-vector operations when inputs are compressed as a single column group. Updated cost estimation for compression on already compressed inputs and removed scalars from compressible intermediate candidates. Added support for double-compressed binary matrix-matrix operations and implemented both single-threaded and multithreaded compressed binary matrix-vector operations with single column group encoding. Removed the relaxed compression threshold and added a logging statement for potential improvements in compressed binary matrix-vector operations. Made sampling unconditional for binary matrix-vector operations in CLALibBinaryCellOp, expanded test coverage, and introduced a new compression algorithm test case for k-means with intermediate compression enabled. Also extended the CLALibBinaryCellOp binary matrix-vector (sparse and dense) tasks to support both left and right operations. 
--- .../java/org/apache/sysds/conf/DMLConfig.java | 1 + .../rewrite/RewriteCompressedReblock.java | 3 +- .../CompressedMatrixBlockFactory.java | 3 +- .../compress/lib/CLALibBinaryCellOp.java | 460 ++++++++++++++---- .../compress/workload/WorkloadAnalyzer.java | 76 ++- .../frame/data/columns/HashMapToInt.java | 43 ++ .../lib/CLALibBinaryCellOpCustomTest.java | 33 ++ .../compress/lib/CLALibBinaryCellOpTest.java | 103 ++-- .../workload/WorkloadAlgorithmTest.java | 10 + 9 files changed, 605 insertions(+), 127 deletions(-) diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java index fd34fa4439c..e1b7b0bb530 100644 --- a/src/main/java/org/apache/sysds/conf/DMLConfig.java +++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java @@ -79,6 +79,7 @@ public class DMLConfig public static final String PARALLEL_TOKENIZE = "sysds.parallel.tokenize"; public static final String PARALLEL_TOKENIZE_NUM_BLOCKS = "sysds.parallel.tokenize.numBlocks"; public static final String COMPRESSED_LINALG = "sysds.compressed.linalg"; + public static final String COMPRESSED_LINALG_INTERMEDIATE = "sysds.compressed.linalg.intermediate"; public static final String COMPRESSED_LOSSY = "sysds.compressed.lossy"; public static final String COMPRESSED_VALID_COMPRESSIONS = "sysds.compressed.valid.compressions"; public static final String COMPRESSED_OVERLAPPING = "sysds.compressed.overlapping"; diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java index ec917b01458..8fc568ca6c0 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java @@ -171,11 +171,12 @@ public static boolean satisfiesAggressiveCompressionCondition(Hop hop) { satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE) && hop.getInput(0).getDataType().isMatrix() && 
hop.getInput(1).getDataType().isMatrix(); - satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD) && !hop.isScalar(); + satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD); satisfies |= HopRewriteUtils.isUnary(hop, OpOp1.ROUND, OpOp1.FLOOR, OpOp1.NOT, OpOp1.CEIL); satisfies |= HopRewriteUtils.isBinary(hop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS, OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.AND, OpOp2.OR, OpOp2.MODULUS); satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE); + satisfies &= !hop.isScalar(); } if(LOG.isDebugEnabled() && satisfies) LOG.debug("Operation Satisfies: " + hop); diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index 93240644a14..7ea2cc39663 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -344,7 +344,8 @@ private void classifyPhase() { // final int nRows = mb.getNumRows(); final int nCols = mb.getNumColumns(); // Assume the scaling of cocoding is at maximum square root good relative to number of columns. - final double scale = Math.sqrt(nCols); + final double scale = mb instanceof CompressedMatrixBlock && + ((CompressedMatrixBlock) mb).getColGroups().size() == 1 ? 
1 : Math.sqrt(nCols); final double threshold = _stats.estimatedCostCols / scale; if(threshold < _stats.originalCost * diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java index f16c88080a3..bbc1640a305 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java @@ -20,6 +20,7 @@ package org.apache.sysds.runtime.compress.lib; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; @@ -35,13 +36,26 @@ import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; +import org.apache.sysds.runtime.compress.colgroup.ADictBasedColGroup; import org.apache.sysds.runtime.compress.colgroup.ASDCZero; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; +import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; +import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup; +import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; +import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockMCSR; +import 
org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.data.SparseRowScalar; +import org.apache.sysds.runtime.data.SparseRowVector; +import org.apache.sysds.runtime.frame.data.columns.HashMapToInt; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Multiply; @@ -52,6 +66,7 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell.BinaryAccessType; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator; import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; @@ -62,6 +77,7 @@ public final class CLALibBinaryCellOp { private static final Log LOG = LogFactory.getLog(CLALibBinaryCellOp.class.getName()); + public static final int DECOMPRESSION_BLEN = 16384; private CLALibBinaryCellOp() { // empty private constructor. 
@@ -70,7 +86,7 @@ private CLALibBinaryCellOp() { public static MatrixBlock binaryOperationsRight(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that) { try { - op = LibMatrixBincell.replaceOpWithSparseSafeIfApplicable(m1, that, op); + op = LibMatrixBincell.replaceOpWithSparseSafeIfApplicable(m1, that, op); if((that.getNumRows() == 1 && that.getNumColumns() == 1) || that.isEmpty()) { ScalarOperator sop = new RightScalarOperator(op.fn, that.get(0, 0), op.getNumThreads()); @@ -104,16 +120,44 @@ public static MatrixBlock binaryOperationsLeft(BinaryOperator op, CompressedMatr private static MatrixBlock binaryOperationsRightFiltered(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that) throws Exception { BinaryAccessType atype = LibMatrixBincell.getBinaryAccessTypeExtended(m1, that); - if(that instanceof CompressedMatrixBlock && that.getInMemorySize() < m1.getInMemorySize()) { + if(isDoubleCompressedOpApplicable(m1, that)) + return doubleCompressedBinaryOp(op, m1, (CompressedMatrixBlock) that); + if(that instanceof CompressedMatrixBlock && that.getNumColumns() == m1.getNumColumns() + && that.getInMemorySize() < m1.getInMemorySize() ) { MatrixBlock m1uc = CompressedMatrixBlock.getUncompressed(m1, "Decompressing left side in BinaryOps"); return selectProcessingBasedOnAccessType(op, (CompressedMatrixBlock) that, m1uc, atype, true); } else { + // right side has worse compression or is a column vector that = CompressedMatrixBlock.getUncompressed(that, "Decompressing right side in BinaryOps"); return selectProcessingBasedOnAccessType(op, m1, that, atype, false); } } + private static boolean isDoubleCompressedOpApplicable(CompressedMatrixBlock m1, MatrixBlock that) { + return that instanceof CompressedMatrixBlock + && !m1.isOverlapping() + && m1.getColGroups().get(0) instanceof ColGroupDDC + && !((CompressedMatrixBlock) that).isOverlapping() + && ((CompressedMatrixBlock) that).getColGroups().get(0) instanceof ColGroupDDC + && ((IMapToDataGroup) 
m1.getColGroups().get(0)).getMapToData() == + ((IMapToDataGroup) ((CompressedMatrixBlock) that).getColGroups().get(0)).getMapToData(); + } + + private static CompressedMatrixBlock doubleCompressedBinaryOp(BinaryOperator op, CompressedMatrixBlock m1, CompressedMatrixBlock m2) { + LOG.debug("Double Compressed BinaryOp"); + AColGroup left = m1.getColGroups().get(0); + AColGroup right = m2.getColGroups().get(0); + AMapToData lm = ((IMapToDataGroup) left).getMapToData(); + MatrixBlock lmb = ((ADictBasedColGroup) left).getDictionary().getMBDict(m1.getNumColumns()).getMatrixBlock(); + MatrixBlock rmb = ((ADictBasedColGroup) right).getDictionary().getMBDict(m2.getNumColumns()).getMatrixBlock(); + MatrixBlock out = lmb.binaryOperations(op, rmb); + AColGroup rgroup = ColGroupDDC.create(left.getColIndices(), MatrixBlockDictionary.create(out), lm, null); + CompressedMatrixBlock outCompressed = new CompressedMatrixBlock(m1.getNumRows(), m1.getNumColumns()); + outCompressed.allocateColGroup(rgroup); + return outCompressed; + } + private static MatrixBlock selectProcessingBasedOnAccessType(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that, BinaryAccessType atype, boolean left) throws Exception { @@ -367,13 +411,17 @@ private static MatrixBlock mvColCompressed(CompressedMatrixBlock m1, MatrixBlock final int k = op.getNumThreads(); long nnz = 0; - boolean shouldBeSparseOut = false; - if(op.fn.isBinary()) { - // maybe it is good if this is a sparse output. - // evaluate if it is good - double est = evaluateSparsityMVCol(m1, m2, op, left); - shouldBeSparseOut = MatrixBlock.evalSparseFormatInMemory(nRows, nCols, (long) (est * nRows * nCols)); - + // maybe it is good if this is a sparse output. 
+ // evaluate if it is good + Pair tuple = evaluateSparsityMVCol(m1, m2, op, left); + double estSparsity = tuple.getKey(); + double estNnzPerRow = tuple.getValue(); + boolean shouldBeSparseOut = MatrixBlock.evalSparseFormatInMemory(nRows, nCols, (long) (estSparsity * nRows * nCols)); + + // currently also jump into that case if estNnzPerRow == 0 + if(estNnzPerRow <= 2 && nCols <= 31 && op.fn instanceof ValueComparisonFunction){ + return k <= 1 ? binaryMVComparisonColSingleThreadCompressed(m1, m2, op, left) : + binaryMVComparisonColMultiCompressed(m1, m2, op, left); } MatrixBlock ret = new MatrixBlock(nRows, nCols, shouldBeSparseOut, -1).allocateBlock(); @@ -403,14 +451,143 @@ else if(nnz == 0) // all was 0 -> return empty. return ret; } + private static MatrixBlock binaryMVComparisonColSingleThreadCompressed(CompressedMatrixBlock m1, MatrixBlock m2, + BinaryOperator op, boolean left) { + final int nRows = m1.getNumRows(); + final int nCols = m1.getNumColumns(); + + // get indicators (one-hot-encoded comparison results) + BinaryMVColTaskCompressed task = new BinaryMVColTaskCompressed(m1, m2, 0, nRows, op, left); + long nnz = task.call(); + int[] indicators = task._ret; + + // map each unique indicator to an index + HashMapToInt hm = new HashMapToInt<>(nCols*3); + int[] colMap = new int[nRows]; + for(int i = 0; i < m1.getNumRows(); i++){ + int nextId = hm.size(); + int id = hm.putIfAbsentI(indicators[i], nextId); + colMap[i] = id == -1 ? 
nextId : id; + } + + // decode the unique indicator ints to SparseVectors + MatrixBlock outMb = getMCSRMatrixBlock(hm, nCols); + + // create compressed block + return getCompressedMatrixBlock(m1, colMap, hm.size(), outMb, nRows, nCols, nnz); + } + + private static void fillSparseBlockFromIndicatorFromIndicatorInt(int numCol, Integer indicator, Integer rix, SparseBlockMCSR out) { + ArrayList colIndices = new ArrayList<>(8); + for (int c = numCol - 1; c >= 0; c--) { + if(indicator <= 0) + break; + if(indicator % 2 == 1){ + colIndices.add(c); + } + indicator = indicator >> 1; + } + SparseRow row = null; + if(colIndices.size() > 1){ + double[] vals = new double[colIndices.size()]; + Arrays.fill(vals, 1); + int[] indices = new int[colIndices.size()]; + for (int i = 0, j = colIndices.size() - 1; i < colIndices.size(); i++, j--) + indices[i] = colIndices.get(j); + + row = new SparseRowVector(vals, indices); + } else if(colIndices.size() == 1){ + row = new SparseRowScalar(colIndices.get(0), 1.0); + } + out.set(rix, row, false); + } + + private static MatrixBlock binaryMVComparisonColMultiCompressed(CompressedMatrixBlock m1, MatrixBlock m2, + BinaryOperator op, boolean left) throws Exception { + final int nRows = m1.getNumRows(); + final int nCols = m1.getNumColumns(); + final int k = op.getNumThreads(); + final int blkz = nRows / k; + + // get indicators (one-hot-encoded comparison results) + long nnz = 0; + final ExecutorService pool = CommonThreadPool.get(op.getNumThreads()); + try { + final ArrayList tasks = new ArrayList<>(); + for(int i = 0; i < nRows; i += blkz) { + tasks.add(new BinaryMVColTaskCompressed(m1, m2, i, Math.min(nRows, i + blkz), op, left)); + } + List> futures = pool.invokeAll(tasks); + HashMapToInt hm = new HashMapToInt<>(nCols*2); + int[] colMap = new int[nRows]; + + for(Future f : futures) + nnz += f.get(); + + // map each unique indicator to an index + mergeMVColTaskResults(tasks, blkz, hm, colMap); + + // decode the unique indicator ints to 
SparseVectors + MatrixBlock outMb = getMCSRMatrixBlock(hm, nCols); + + // create compressed block + return getCompressedMatrixBlock(m1, colMap, hm.size(), outMb, nRows, nCols, nnz); + } + finally { + pool.shutdown(); + } + + } + + private static void mergeMVColTaskResults(ArrayList tasks, int blkz, HashMapToInt hm, int[] colMap) { + + for(int j = 0; j < tasks.size(); j++) { + int[] indicators = tasks.get(j)._ret; + int offset = j* blkz; + + final int remainders = indicators.length % 8; + final int endVecLen = indicators.length - remainders; + for (int i = 0; i < endVecLen; i+= 8) { + colMap[offset + i] = hm.putIfAbsentReturnVal(indicators[i], hm.size()); + colMap[offset + i + 1] = hm.putIfAbsentReturnVal(indicators[i + 1], hm.size()); + colMap[offset + i + 2] = hm.putIfAbsentReturnVal(indicators[i + 2], hm.size()); + colMap[offset + i + 3] = hm.putIfAbsentReturnVal(indicators[i + 3], hm.size()); + colMap[offset + i + 4] = hm.putIfAbsentReturnVal(indicators[i + 4], hm.size()); + colMap[offset + i + 5] = hm.putIfAbsentReturnVal(indicators[i + 5], hm.size()); + colMap[offset + i + 6] = hm.putIfAbsentReturnVal(indicators[i + 6], hm.size()); + colMap[offset + i + 7] = hm.putIfAbsentReturnVal(indicators[i + 7], hm.size()); + + } + for (int i = 0; i < remainders; i++) { + colMap[offset + endVecLen + i] = hm.putIfAbsentReturnVal(indicators[endVecLen + i], hm.size()); + } + } + } + + + private static CompressedMatrixBlock getCompressedMatrixBlock(CompressedMatrixBlock m1, int[] colMap, + int mapSize, MatrixBlock outMb, int nRows, int nCols, long nnz) { + final IColIndex i = ColIndexFactory.create(0, m1.getNumColumns()); + final AMapToData map = MapToFactory.create(m1.getNumRows(), colMap, mapSize); + final AColGroup rgroup = ColGroupDDC.create(i, MatrixBlockDictionary.create(outMb), map, null); + final ArrayList groups = new ArrayList<>(1); + groups.add(rgroup); + return new CompressedMatrixBlock(nRows, nCols, nnz, false, groups); + } + + private static MatrixBlock 
getMCSRMatrixBlock(HashMapToInt hm, int nCols) { + // decode the unique indicator ints to SparseVectors + SparseBlockMCSR out = new SparseBlockMCSR(hm.size()); + hm.forEach((indicator, rix) -> + fillSparseBlockFromIndicatorFromIndicatorInt(nCols, indicator, rix, out)); + return new MatrixBlock(hm.size(), nCols, -1, out); + } + private static long binaryMVColSingleThreadDense(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left, MatrixBlock ret) { final int nRows = m1.getNumRows(); long nnz = 0; - if(left) - nnz += new BinaryMVColLeftTaskDense(m1, m2, ret, 0, nRows, op).call(); - else - nnz += new BinaryMVColTaskDense(m1, m2, ret, 0, nRows, op).call(); + nnz += new BinaryMVColTaskDense(m1, m2, ret, 0, nRows, op, left).call(); return nnz; } @@ -418,10 +595,7 @@ private static long binaryMVColSingleThreadSparse(CompressedMatrixBlock m1, Matr boolean left, MatrixBlock ret) { final int nRows = m1.getNumRows(); long nnz = 0; - if(left) - throw new NotImplementedException(); - else - nnz += new BinaryMVColTaskSparse(m1, m2, ret, 0, nRows, op).call(); + nnz += new BinaryMVColTaskSparse(m1, m2, ret, 0, nRows, op, left).call(); return nnz; } @@ -435,10 +609,7 @@ private static long binaryMVColMultiThreadDense(CompressedMatrixBlock m1, Matrix try { final ArrayList> tasks = new ArrayList<>(); for(int i = 0; i < nRows; i += blkz) { - if(left) - tasks.add(new BinaryMVColLeftTaskDense(m1, m2, ret, i, Math.min(nRows, i + blkz), op)); - else - tasks.add(new BinaryMVColTaskDense(m1, m2, ret, i, Math.min(nRows, i + blkz), op)); + tasks.add(new BinaryMVColTaskDense(m1, m2, ret, i, Math.min(nRows, i + blkz), op, left)); } for(Future f : pool.invokeAll(tasks)) nnz += f.get(); @@ -459,10 +630,7 @@ private static long binaryMVColMultiThreadSparse(CompressedMatrixBlock m1, Matri try { final ArrayList> tasks = new ArrayList<>(); for(int i = 0; i < nRows; i += blkz) { - if(left) - throw new NotImplementedException(); - else - tasks.add(new BinaryMVColTaskSparse(m1, m2, 
ret, i, Math.min(nRows, i + blkz), op)); + tasks.add(new BinaryMVColTaskSparse(m1, m2, ret, i, Math.min(nRows, i + blkz), op, left)); } for(Future f : pool.invokeAll(tasks)) nnz += f.get(); @@ -543,6 +711,97 @@ private static CompressedMatrixBlock morph(CompressedMatrixBlock m) { return m; } + private static class BinaryMVColTaskCompressed implements Callable { + private final int _rl; + private final int _ru; + private final CompressedMatrixBlock _m1; + private final MatrixBlock _m2; + private final int[] _ret; + private final BinaryOperator _op; + private final ValueComparisonFunction _compFn; + private final boolean _left; + + private MatrixBlock tmp; + + protected BinaryMVColTaskCompressed(CompressedMatrixBlock m1, MatrixBlock m2, int rl, int ru, + BinaryOperator op, boolean left) { + _m1 = m1; + _m2 = m2; + _op = op; + _rl = rl; + _ru = ru; + _ret = new int[ru - rl]; + _compFn = (ValueComparisonFunction) op.fn; + _left = left; + } + + @Override + public Long call() { + tmp = allocateTempUncompressedBlock(_m1.getNumColumns()); + final int _blklen = tmp.getNumRows(); + final List groups = _m1.getColGroups(); + final AIterator[] its = getIterators(groups, _rl); + long nnz = 0; + + if(!_left) + for (int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen){ + int ru = Math.min(rl + _blklen, _ru); + decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); + nnz += processDense(rl, ru, retIxOff); + tmp.reset(); + } + else + for (int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen){ + int ru = Math.min(rl + _blklen, _ru); + decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); + nnz += processDenseLeft(rl, ru, retIxOff); + tmp.reset(); + } + + return nnz; + } + + private final long processDense(final int rl, final int ru, final int retIxOffset) { + final int nCol = _m1.getNumColumns(); + final double[] _tmpDense = tmp.getDenseBlockValues(); + final double[] _m2Dense = _m2.getDenseBlockValues(); + long nnz = 0; + 
for(int row = rl, retIx = retIxOffset; row < ru; row++, retIx++) { + final double vr = _m2Dense[row]; + final int tmpOff = (row - rl) * nCol; + int indicatorVector = 0; + for(int col = 0; col < nCol; col++) { + indicatorVector = indicatorVector << 1; + int indicator = _compFn.compare(_tmpDense[tmpOff + col], vr) ? 1 : 0; + indicatorVector += indicator; + nnz += indicator; + } + _ret[retIx] = indicatorVector; + } + return nnz; + } + + private final long processDenseLeft(final int rl, final int ru, final int retIxOffset) { + final int nCol = _m1.getNumColumns(); + final double[] _tmpDense = tmp.getDenseBlockValues(); + final double[] _m2Dense = _m2.getDenseBlockValues(); + long nnz = 0; + for(int row = rl, retIx = retIxOffset; row < ru; row++, retIx++) { + final double vr = _m2Dense[row]; + final int tmpOff = (row - rl) * nCol; + int indicatorVector = 0; + for(int col = 0; col < nCol; col++) { + indicatorVector = indicatorVector << 1; + int indicator = _compFn.compare(vr, _tmpDense[tmpOff + col]) ? 
1 : 0; + indicatorVector += indicator; + nnz += indicator; + } + _ret[retIx] = indicatorVector; + } + return nnz; + } + } + private static class BinaryMVColTaskDense implements Callable { private final int _rl; private final int _ru; @@ -550,15 +809,17 @@ private static class BinaryMVColTaskDense implements Callable { private final MatrixBlock _m2; private final MatrixBlock _ret; private final BinaryOperator _op; + private boolean _left; protected BinaryMVColTaskDense(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, - BinaryOperator op) { + BinaryOperator op, boolean left) { _m1 = m1; _m2 = m2; _ret = ret; _op = op; _rl = rl; _ru = ru; + _left = left; } @Override @@ -568,8 +829,12 @@ public Long call() { final AIterator[] its = getIterators(groups, _rl); - for(int r = _rl; r < _ru; r += _blklen) - processBlock(r, Math.min(r + _blklen, _ru), groups, its); + if(!_left) + for(int r = _rl; r < _ru; r += _blklen) + processBlock(r, Math.min(r + _blklen, _ru), groups, its); + else + for(int r = _rl; r < _ru; r += _blklen) + processBlockLeft(r, Math.min(r + _blklen, _ru), groups, its); return _ret.recomputeNonZeros(_rl, _ru - 1); } @@ -581,6 +846,13 @@ private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + // unsafe decompress, since we count nonzeros afterwards. + final DenseBlock db = _ret.getDenseBlock(); + decompressToSubBlock(rl, ru, db, groups, its); + processGenericDenseLeft(rl, ru); + } + private final void processGenericDense(final int rl, final int ru) { final int ncol = _m1.getNumColumns(); final DenseBlock rd = _ret.getDenseBlock(); @@ -594,11 +866,29 @@ private final void processGenericDense(final int rl, final int ru) { } } + private final void processGenericDenseLeft(final int rl, final int ru) { + final int ncol = _m1.getNumColumns(); + final DenseBlock rd = _ret.getDenseBlock(); + // m2 is a vector therefore guaranteed continuous. 
+ final double[] _m2Dense = _m2.getDenseBlockValues(); + for(int row = rl; row < ru; row++) { + final double[] retDense = rd.values(row); + final int posR = rd.pos(row); + final double vr = _m2Dense[row]; + processRowLeft(ncol, retDense, posR, vr); + } + } + private void processRow(final int ncol, final double[] ret, final int posR, final double vr) { for(int col = 0; col < ncol; col++) ret[posR + col] = _op.fn.execute(ret[posR + col], vr); } + private void processRowLeft(final int ncol, final double[] ret, final int posR, final double vr) { + for(int col = 0; col < ncol; col++) + ret[posR + col] = _op.fn.execute(vr,ret[posR + col]); + } + } private static class BinaryMVColTaskSparse implements Callable { @@ -611,27 +901,31 @@ private static class BinaryMVColTaskSparse implements Callable { private MatrixBlock tmp; + private boolean _left; + protected BinaryMVColTaskSparse(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, - BinaryOperator op) { + BinaryOperator op, boolean left) { _m1 = m1; _m2 = m2; _ret = ret; _op = op; _rl = rl; _ru = ru; + _left = left; } @Override public Long call() { - final int _blklen = Math.max(16384 / _ret.getNumColumns(), 64); + tmp = allocateTempUncompressedBlock(_m1.getNumColumns()); + final int _blklen = tmp.getNumRows(); final List groups = _m1.getColGroups(); final AIterator[] its = getIterators(groups, _rl); - tmp = new MatrixBlock(_blklen, _m1.getNumColumns(), false); - tmp.allocateBlock(); - - for(int r = _rl; r < _ru; r += _blklen) - processBlock(r, Math.min(r + _blklen, _ru), groups, its); - + if(!_left) + for(int r = _rl; r < _ru; r += _blklen) + processBlock(r, Math.min(r + _blklen, _ru), groups, its); + else + for(int r = _rl; r < _ru; r += _blklen) + processBlockLeft(r, Math.min(r + _blklen, _ru), groups, its); return _ret.recomputeNonZeros(_rl, _ru - 1); } @@ -641,6 +935,12 @@ private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + 
decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); + processDenseLeft(rl, ru); + tmp.reset(); + } + private final void processDense(final int rl, final int ru) { final int nCol = _m1.getNumColumns(); final SparseBlock sb = _ret.getSparseBlock(); @@ -654,6 +954,26 @@ private final void processDense(final int rl, final int ru) { } } + + private final void processDenseLeft(final int rl, final int ru) { + final int nCol = _m1.getNumColumns(); + final SparseBlock sb = _ret.getSparseBlock(); + final double[] _tmpDense = tmp.getDenseBlockValues(); + final double[] _m2Dense = _m2.getDenseBlockValues(); + for(int row = rl; row < ru; row++) { + final double vr = _m2Dense[row]; + final int tmpOff = (row - rl) * nCol; + for(int col = 0; col < nCol; col++) + sb.append(row, col, _op.fn.execute(vr, _tmpDense[tmpOff + col])); + + } + } + } + + private static MatrixBlock allocateTempUncompressedBlock(int cols) { + MatrixBlock out = new MatrixBlock(Math.max(DECOMPRESSION_BLEN / cols, 64), cols, false); + out.allocateBlock(); + return out; } private static class BinaryMMTask implements Callable { @@ -804,48 +1124,6 @@ private final void processRightDense(final int rl, final int ru) { } } - private static class BinaryMVColLeftTaskDense implements Callable { - private final int _rl; - private final int _ru; - private final CompressedMatrixBlock _m1; - private final MatrixBlock _m2; - private final MatrixBlock _ret; - private final BinaryOperator _op; - - protected BinaryMVColLeftTaskDense(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, - BinaryOperator op) { - _m1 = m1; - _m2 = m2; - _ret = ret; - _op = op; - _rl = rl; - _ru = ru; - } - - @Override - public Long call() { - for(AColGroup g : _m1.getColGroups()) - g.decompressToDenseBlock(_ret.getDenseBlock(), _rl, _ru); - - // m2 is never sparse or empty. always dense here. 
- final int ncol = _m1.getNumColumns(); - int offset = _rl * ncol; - double[] _retDense = _ret.getDenseBlockValues(); - double[] _m2Dense = _m2.getDenseBlockValues(); - for(int row = _rl; row < _ru; row++) { - double vr = _m2Dense[row]; - for(int col = 0; col < ncol; col++) { - double v = _op.fn.execute(vr, _retDense[offset]); - _retDense[offset] = v; - offset++; - } - } - - return _ret.recomputeNonZeros(_rl, _ru - 1); - - } - } - private static abstract class BinaryMVRowTask implements Callable { protected final AColGroup _group; protected final double[] _v; @@ -931,8 +1209,8 @@ protected static AIterator[] getIterators(final List groups, final in return its; } - private static double evaluateSparsityMVCol(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, - boolean left) { + private static Pair evaluateSparsityMVCol(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left) { final List groups = m1.getColGroups(); final int nCol = m1.getNumColumns(); final int nRow = m1.getNumRows(); @@ -950,26 +1228,36 @@ private static double evaluateSparsityMVCol(CompressedMatrixBlock m1, MatrixBloc decompressToDense(groups, sampleRow, sampleCol, dv); int nnz = 0; + int[] nnzPerRow = new int[sampleRow]; // m2v guaranteed to be dense and not empty. // if empty then we defaulted to scalar operations. if(left) { for(int r = 0; r < sampleRow; r++) { final double m = m2v[r]; final int off = r * sampleCol; - for(int c = 0; c < sampleCol; c++) - nnz += op.fn.execute(m, dv[off + c]) != 0 ? 1 : 0; + for(int c = 0; c < sampleCol; c++) { + int outVal = op.fn.execute(m, dv[off + c]) != 0 ? 1 : 0; + nnz += outVal; + nnzPerRow[r] += outVal; + } } } else { for(int r = 0; r < sampleRow; r++) { final double m = m2v[r]; final int off = r * sampleCol; - for(int c = 0; c < sampleCol; c++) - nnz += op.fn.execute(dv[off + c], m) != 0 ? 1 : 0; + for(int c = 0; c < sampleCol; c++){ + int outVal = op.fn.execute(dv[off + c], m) != 0 ? 
1 : 0; + nnz += outVal; + nnzPerRow[r] += outVal; + } } } - - return (double) nnz / (sampleNCells); + double sum = 0; + for(int i = 0; i < sampleRow; i++) { + sum += nnzPerRow[i]; + } + return new Pair<>((double) nnz / (sampleNCells), sum / sampleRow); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java index 2092440ed18..a7b8591e647 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java @@ -26,6 +26,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.Stack; @@ -38,6 +39,8 @@ import org.apache.sysds.common.Types.OpOpData; import org.apache.sysds.common.Types.ParamBuiltinOp; import org.apache.sysds.common.Types.ReOrgOp; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.hops.AggBinaryOp; import org.apache.sysds.hops.AggUnaryOp; import org.apache.sysds.hops.BinaryOp; @@ -81,9 +84,15 @@ public class WorkloadAnalyzer { private final DMLProgram prog; private final Map treeLookup; private final Stack stack; + private Stack lineage = new Stack<>(); public static Map getAllCandidateWorkloads(DMLProgram prog) { // extract all compression candidates from program (in program order) + String configValue = ConfigurationManager.getDMLConfig() + .getTextValue(DMLConfig.COMPRESSED_LINALG_INTERMEDIATE); + // if set update it, otherwise keep it as set before + ALLOW_INTERMEDIATE_CANDIDATES = configValue != null && Objects.equals(configValue.toUpperCase(), "TRUE") || + configValue == null && ALLOW_INTERMEDIATE_CANDIDATES; List candidates = getCandidates(prog); // for each candidate, create pruned workload tree @@ -115,6 +124,7 @@ private WorkloadAnalyzer(DMLProgram prog) { 
this.overlapping = new HashSet<>(); this.treeLookup = new HashMap<>(); this.stack = new Stack<>(); + this.lineage = new Stack<>(); } private WorkloadAnalyzer(DMLProgram prog, Set compressed, HashMap transientCompressed, @@ -235,6 +245,7 @@ private static void getCandidates(Hop hop, DMLProgram prog, List cands, Set private void createWorkloadTreeNodes(AWTreeNode n, StatementBlock sb, DMLProgram prog, Set fStack) { WTreeNode node; + lineage.add(sb); if(sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock) sb; FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); @@ -291,7 +302,7 @@ else if(sb instanceof ForStatementBlock) { // incl parfor if(hop instanceof FunctionOp) { FunctionOp fop = (FunctionOp) hop; if(HopRewriteUtils.isTransformEncode(fop)) - return; + break; else if(!fStack.contains(fop.getFunctionKey())) { fStack.add(fop.getFunctionKey()); FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionKey()); @@ -323,9 +334,11 @@ else if(!fStack.contains(fop.getFunctionKey())) { } } } + lineage.pop(); return; } n.addChild(node); + lineage.pop(); } private void createStack(Hop hop) { @@ -396,7 +409,22 @@ else if(hop instanceof AggUnaryOp) { return; } else { - o = new OpNormal(hop, false); + boolean compressedOut = false; + Hop parentHop = hop.getInput(0); + if(HopRewriteUtils.isBinary(parentHop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS, + OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL)){ + Hop leftIn = parentHop.getInput(0); + Hop rightIn = parentHop.getInput(1); + // input ops might be not in the current statement block -> check for transient reads + if(HopRewriteUtils.isAggUnaryOp(leftIn, AggOp.MIN, AggOp.MAX) || + HopRewriteUtils.isAggUnaryOp(rightIn, AggOp.MIN, AggOp.MAX) || + checkTransientRead(hop, leftIn) || + checkTransientRead(hop, rightIn) + ) + compressedOut = true; + + } + o = new OpNormal(hop, compressedOut); } } else if(hop instanceof UnaryOp) { @@ -477,9 +505,17 @@ else if(ol) { 
if(!HopRewriteUtils.isBinarySparseSafe(hop)) o.setDensifying(); + } else if(HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ) { + Hop leftIn = hop.getInput(0); + Hop rightIn = hop.getInput(1); + if(HopRewriteUtils.isBinary(hop, OpOp2.DIV) && rightIn instanceof AggUnaryOp && leftIn == rightIn.getInput(0)){ + o = new OpNormal(hop, true); + } else { + setDecompressionOnAllInputs(hop, parent); + return; + } } else if(HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) || - HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) || HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) { setDecompressionOnAllInputs(hop, parent); return; @@ -623,6 +659,40 @@ else if(hop instanceof AggUnaryOp) { } } + private boolean checkTransientRead(Hop hop, Hop input) { + // op is not in current statement block + if(HopRewriteUtils.isData(input, OpOpData.TRANSIENTREAD)){ + String varName = input.getName(); + StatementBlock csb = lineage.peek(); + StatementBlock parentStatement = lineage.get(lineage.size() -2); + + if(parentStatement instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) parentStatement; + WhileStatement wstmt = (WhileStatement) wsb.getStatement(0); + ArrayList stmts = wstmt.getBody(); + boolean foundCurrent = false; + StatementBlock sb; + + // traverse statement blocks in reverse to find the statement block, which came before the current + // if we iterate in default order, we might find an earlier updated version of the current variable + for (int i = stmts.size()-1; i >= 0; i--) { + sb = stmts.get(i); + if(foundCurrent && sb.variablesUpdated().containsVariable(varName)) { + for(Hop cand : sb.getHops()){ + if(HopRewriteUtils.isData(cand, OpOpData.TRANSIENTWRITE) && cand.getName().equals(varName) + && HopRewriteUtils.isAggUnaryOp( cand.getInput(0), AggOp.MIN, AggOp.MAX)){ + return true; + } + } + } else if(sb == csb){ + foundCurrent = true; + } + } + } + } + return false; + } + private boolean isCompressed(Hop hop) { return 
compressed.contains(hop.getHopID()); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index fa6a86fce49..b26695e5797 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -152,6 +152,28 @@ public int putIfAbsentI(K key, int value) { } + public int putIfAbsentReturnVal(K key, int value) { + + if(key == null) { + if(nullV == -1) { + size++; + nullV = value; + return -1; + } + else + return nullV; + } + else { + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucketReturnVal(ix, key, value); + else + return putIfAbsentBucketReturnval(ix, key, value); + } + + } + private int putIfAbsentBucket(int ix, K key, int value) { Node b = buckets[ix]; while(true) { @@ -167,6 +189,21 @@ private int putIfAbsentBucket(int ix, K key, int value) { } } + private int putIfAbsentBucketReturnval(int ix, K key, int value) { + Node b = buckets[ix]; + while(true) { + if(b.key.equals(key)) + return b.value; + if(b.next == null) { + b.setNext(new Node<>(key, value, null)); + size++; + resize(); + return value; + } + b = b.next; + } + } + public int putI(K key, int value) { if(key == null) { int tmp = nullV; @@ -191,6 +228,12 @@ private int createBucket(int ix, K key, int value) { return -1; } + private int createBucketReturnVal(int ix, K key, int value) { + buckets[ix] = new Node(key, value, null); + size++; + return value; + } + private int addToBucket(int ix, K key, int value) { Node b = buckets[ix]; while(true) { diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpCustomTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpCustomTest.java index c008a7d1254..1ce05fab616 100644 --- 
a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpCustomTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpCustomTest.java @@ -22,9 +22,13 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; +import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; +import org.apache.sysds.runtime.compress.CompressionStatistics; import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp; +import org.apache.sysds.runtime.functionobjects.GreaterThanEquals; +import org.apache.sysds.runtime.functionobjects.LessThanEquals; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; @@ -47,6 +51,35 @@ public void notColVector() { TestUtils.compareMatricesBitAvgDistance(new MatrixBlock(10, 10, 1.3), cRet2, 0, 0, op.toString()); } + @Test + public void twoHotEncodedOutput() { + BinaryOperator op = new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject(), 2); + BinaryOperator op2 = new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject()); + BinaryOperator opLeft = new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject(), 2); + BinaryOperator opLeft2 = new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject()); + + MatrixBlock cDense = new MatrixBlock(30, 30, 2.0); + for (int i = 0; i < 30; i++) { + cDense.set(i,0, 1); + } + cDense.set(0,1, 1); + Pair pair = CompressedMatrixBlockFactory.compress(cDense, 1); + CompressedMatrixBlock c = (CompressedMatrixBlock) pair.getKey(); + MatrixBlock c2 = new MatrixBlock(30, 1, 1.0); + CompressedMatrixBlock spy = spy(c); + when(spy.getCachedDecompressed()).thenReturn(null); + + MatrixBlock cRet = CLALibBinaryCellOp.binaryOperationsRight(op, spy, c2); + MatrixBlock cRet2 = 
CLALibBinaryCellOp.binaryOperationsRight(op2, spy, c2); + TestUtils.compareMatricesBitAvgDistance(cRet, cRet2, 0, 0, op.toString()); + + MatrixBlock cRetleft = CLALibBinaryCellOp.binaryOperationsLeft(opLeft, spy, c2); + MatrixBlock cRetleft2 = CLALibBinaryCellOp.binaryOperationsLeft(opLeft2, spy, c2); + TestUtils.compareMatricesBitAvgDistance(cRetleft, cRetleft2, 0, 0, op.toString()); + + TestUtils.compareMatricesBitAvgDistance(cRet, cRetleft, 0, 0, op.toString()); + } + @Test public void notColVectorEmptyReturn() { BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject(), 2); diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpTest.java index 69305cf5b24..11df5b104ed 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibBinaryCellOpTest.java @@ -35,11 +35,33 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp; +import org.apache.sysds.runtime.functionobjects.And; +import org.apache.sysds.runtime.functionobjects.BitwAnd; +import org.apache.sysds.runtime.functionobjects.BitwOr; +import org.apache.sysds.runtime.functionobjects.BitwShiftL; +import org.apache.sysds.runtime.functionobjects.BitwShiftR; +import org.apache.sysds.runtime.functionobjects.BitwXor; +import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Divide; +import org.apache.sysds.runtime.functionobjects.Equals; +import org.apache.sysds.runtime.functionobjects.GreaterThan; +import org.apache.sysds.runtime.functionobjects.GreaterThanEquals; +import org.apache.sysds.runtime.functionobjects.IntegerDivide; +import 
org.apache.sysds.runtime.functionobjects.LessThan; +import org.apache.sysds.runtime.functionobjects.LessThanEquals; +import org.apache.sysds.runtime.functionobjects.Minus; +import org.apache.sysds.runtime.functionobjects.Minus1Multiply; +import org.apache.sysds.runtime.functionobjects.MinusMultiply; +import org.apache.sysds.runtime.functionobjects.MinusNz; +import org.apache.sysds.runtime.functionobjects.Modulus; import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.NotEquals; +import org.apache.sysds.runtime.functionobjects.Or; import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.functionobjects.PlusMultiply; import org.apache.sysds.runtime.functionobjects.Power; import org.apache.sysds.runtime.functionobjects.ValueFunction; +import org.apache.sysds.runtime.functionobjects.Xor; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; @@ -54,43 +76,45 @@ public class CLALibBinaryCellOpTest { protected static final Log LOG = LogFactory.getLog(CombineGroupsTest.class.getName()); public final static ValueFunction[] vf = {// - // (Plus.getPlusFnObject()), // - // (Minus.getMinusFnObject()), // + (Plus.getPlusFnObject()), // + (Minus.getMinusFnObject()), // Divide.getDivideFnObject(), // - // (Or.getOrFnObject()), // - // (LessThan.getLessThanFnObject()), // - // (LessThanEquals.getLessThanEqualsFnObject()), // - // (GreaterThan.getGreaterThanFnObject()), // - // (GreaterThanEquals.getGreaterThanEqualsFnObject()), // - // (Multiply.getMultiplyFnObject()), // - // (Modulus.getFnObject()), // - // (IntegerDivide.getFnObject()), // - // (Equals.getEqualsFnObject()), // - // (NotEquals.getNotEqualsFnObject()), // - // (And.getAndFnObject()), // - // (Xor.getXorFnObject()), // - // (BitwAnd.getBitwAndFnObject()), // - // (BitwOr.getBitwOrFnObject()), // - // 
(BitwXor.getBitwXorFnObject()), // - // (BitwShiftL.getBitwShiftLFnObject()), // - // (BitwShiftR.getBitwShiftRFnObject()), // - // (Power.getPowerFnObject()), // - // (MinusNz.getMinusNzFnObject()), // - // (new PlusMultiply(32)), // - // (new PlusMultiply(2)), // - // (new PlusMultiply(0)), // - // (new MinusMultiply(32)), // - // Minus1Multiply.getMinus1MultiplyFnObject(), - // // // Builtin - // (Builtin.getBuiltinFnObject(BuiltinCode.MIN)), // - // (Builtin.getBuiltinFnObject(BuiltinCode.MAX)), // - // (Builtin.getBuiltinFnObject(BuiltinCode.LOG)), // - // (Builtin.getBuiltinFnObject(BuiltinCode.LOG_NZ)), // - // (Builtin.getBuiltinFnObject(BuiltinCode.MAXINDEX)), // - // (Builtin.getBuiltinFnObject(BuiltinCode.MININDEX)), // - // (Builtin.getBuiltinFnObject(BuiltinCode.CUMMAX)), // - // (Builtin.getBuiltinFnObject(BuiltinCode.CUMMIN)),// - }; + (Or.getOrFnObject()), // + (LessThan.getLessThanFnObject()), // + (LessThanEquals.getLessThanEqualsFnObject()), // + (GreaterThan.getGreaterThanFnObject()), // + (GreaterThanEquals.getGreaterThanEqualsFnObject()), // + (Multiply.getMultiplyFnObject()), // + (Modulus.getFnObject()), // + (IntegerDivide.getFnObject()), // + (Equals.getEqualsFnObject()), // + (NotEquals.getNotEqualsFnObject()), // + (And.getAndFnObject()), // + (Xor.getXorFnObject()), // + (BitwAnd.getBitwAndFnObject()), // + (BitwOr.getBitwOrFnObject()), // + (BitwXor.getBitwXorFnObject()), // + (BitwShiftL.getBitwShiftLFnObject()), // + (BitwShiftR.getBitwShiftRFnObject()), // + // TODO: power fails currently in some cases + //(Power.getPowerFnObject()), // + (MinusNz.getMinusNzFnObject()), // + (new PlusMultiply(32)), // + (new PlusMultiply(2)), // + (new PlusMultiply(0)), // + (new MinusMultiply(32)), // + Minus1Multiply.getMinus1MultiplyFnObject(), + + // // Builtin + (Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MIN)), // + (Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MAX)), // + (Builtin.getBuiltinFnObject(Builtin.BuiltinCode.LOG)), // + 
(Builtin.getBuiltinFnObject(Builtin.BuiltinCode.LOG_NZ)), // + (Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MAXINDEX)), // + (Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MININDEX)), // + (Builtin.getBuiltinFnObject(Builtin.BuiltinCode.CUMMAX)), // + (Builtin.getBuiltinFnObject(Builtin.BuiltinCode.CUMMIN)),// + }; private final MatrixBlock mb; private final CompressedMatrixBlock cmb; @@ -465,6 +489,13 @@ public void binLeftMcV() { execL(op, mb, cmb, mcv2); } + @Test(expected = Exception.class) + public void binLeftMcV_noCache() { + CompressedMatrixBlock spy = spy(cmb); + when(spy.getCachedDecompressed()).thenReturn(null); + execL(op, mb, spy, mcv2); + } + @Test(expected = Exception.class) public void binLeftMS() { if(mScalar2 == null) diff --git a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java index f09fcf456e7..0ee5b7d4616 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java @@ -156,6 +156,16 @@ public void testKmeansUnsuccessfulCP() { runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 1, false, 10); } + @Test + public void testKmeansSuccessfulIntermediateCP() { + runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 42, true, 25); + } + + @Test + public void testKmeansUnsuccessfulIntermediateCP() { + runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 27, true, 10); + } + private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates){ runWorkloadAnalysisTest(testname, mode, compressionCount, intermediates, -1); }