Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/conf/DMLConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public class DMLConfig
public static final String PARALLEL_TOKENIZE = "sysds.parallel.tokenize";
public static final String PARALLEL_TOKENIZE_NUM_BLOCKS = "sysds.parallel.tokenize.numBlocks";
public static final String COMPRESSED_LINALG = "sysds.compressed.linalg";
public static final String COMPRESSED_LINALG_INTERMEDIATE = "sysds.compressed.linalg.intermediate";
public static final String COMPRESSED_LOSSY = "sysds.compressed.lossy";
public static final String COMPRESSED_VALID_COMPRESSIONS = "sysds.compressed.valid.compressions";
public static final String COMPRESSED_OVERLAPPING = "sysds.compressed.overlapping";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,12 @@ public static boolean satisfiesAggressiveCompressionCondition(Hop hop) {
satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE)
&& hop.getInput(0).getDataType().isMatrix()
&& hop.getInput(1).getDataType().isMatrix();
satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD) && !hop.isScalar();
satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD);
satisfies |= HopRewriteUtils.isUnary(hop, OpOp1.ROUND, OpOp1.FLOOR, OpOp1.NOT, OpOp1.CEIL);
satisfies |= HopRewriteUtils.isBinary(hop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS,
OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.AND, OpOp2.OR, OpOp2.MODULUS);
satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE);
satisfies &= !hop.isScalar();
}
if(LOG.isDebugEnabled() && satisfies)
LOG.debug("Operation Satisfies: " + hop);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,8 @@ private void classifyPhase() {
// final int nRows = mb.getNumRows();
final int nCols = mb.getNumColumns();
// Assume co-coding improves the cost by at most a factor of the square root of the number of columns.
final double scale = Math.sqrt(nCols);
final double scale = mb instanceof CompressedMatrixBlock &&
((CompressedMatrixBlock) mb).getColGroups().size() == 1 ? 1 : Math.sqrt(nCols);
final double threshold = _stats.estimatedCostCols / scale;

if(threshold < _stats.originalCost *
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.Stack;

Expand All @@ -38,6 +39,8 @@
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
Expand Down Expand Up @@ -81,9 +84,15 @@ public class WorkloadAnalyzer {
private final DMLProgram prog;
private final Map<Long, Op> treeLookup;
private final Stack<Hop> stack;
private Stack<StatementBlock> lineage = new Stack<>();

public static Map<Long, WTreeRoot> getAllCandidateWorkloads(DMLProgram prog) {
// extract all compression candidates from program (in program order)
String configValue = ConfigurationManager.getDMLConfig()
.getTextValue(DMLConfig.COMPRESSED_LINALG_INTERMEDIATE);
// if set update it, otherwise keep it as set before
ALLOW_INTERMEDIATE_CANDIDATES = configValue != null && Objects.equals(configValue.toUpperCase(), "TRUE") ||
configValue == null && ALLOW_INTERMEDIATE_CANDIDATES;
List<Hop> candidates = getCandidates(prog);

// for each candidate, create pruned workload tree
Expand Down Expand Up @@ -115,6 +124,7 @@ private WorkloadAnalyzer(DMLProgram prog) {
this.overlapping = new HashSet<>();
this.treeLookup = new HashMap<>();
this.stack = new Stack<>();
this.lineage = new Stack<>();
}

private WorkloadAnalyzer(DMLProgram prog, Set<Long> compressed, HashMap<String, Long> transientCompressed,
Expand Down Expand Up @@ -235,6 +245,7 @@ private static void getCandidates(Hop hop, DMLProgram prog, List<Hop> cands, Set

private void createWorkloadTreeNodes(AWTreeNode n, StatementBlock sb, DMLProgram prog, Set<String> fStack) {
WTreeNode node;
lineage.add(sb);
if(sb instanceof FunctionStatementBlock) {
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
Expand Down Expand Up @@ -291,7 +302,7 @@ else if(sb instanceof ForStatementBlock) { // incl parfor
if(hop instanceof FunctionOp) {
FunctionOp fop = (FunctionOp) hop;
if(HopRewriteUtils.isTransformEncode(fop))
return;
break;
else if(!fStack.contains(fop.getFunctionKey())) {
fStack.add(fop.getFunctionKey());
FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionKey());
Expand Down Expand Up @@ -323,9 +334,11 @@ else if(!fStack.contains(fop.getFunctionKey())) {
}
}
}
lineage.pop();
return;
}
n.addChild(node);
lineage.pop();
}

private void createStack(Hop hop) {
Expand Down Expand Up @@ -396,7 +409,22 @@ else if(hop instanceof AggUnaryOp) {
return;
}
else {
o = new OpNormal(hop, false);
boolean compressedOut = false;
Hop parentHop = hop.getInput(0);
if(HopRewriteUtils.isBinary(parentHop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS,
OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL)){
Hop leftIn = parentHop.getInput(0);
Hop rightIn = parentHop.getInput(1);
// the input ops might not be in the current statement block -> check for transient reads
if(HopRewriteUtils.isAggUnaryOp(leftIn, AggOp.MIN, AggOp.MAX) ||
HopRewriteUtils.isAggUnaryOp(rightIn, AggOp.MIN, AggOp.MAX) ||
checkTransientRead(hop, leftIn) ||
checkTransientRead(hop, rightIn)
)
compressedOut = true;

}
o = new OpNormal(hop, compressedOut);
}
}
else if(hop instanceof UnaryOp) {
Expand Down Expand Up @@ -477,9 +505,17 @@ else if(ol) {
if(!HopRewriteUtils.isBinarySparseSafe(hop))
o.setDensifying();

} else if(HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ) {
Hop leftIn = hop.getInput(0);
Hop rightIn = hop.getInput(1);
if(HopRewriteUtils.isBinary(hop, OpOp2.DIV) && rightIn instanceof AggUnaryOp && leftIn == rightIn.getInput(0)){
o = new OpNormal(hop, true);
} else {
setDecompressionOnAllInputs(hop, parent);
return;
}
}
else if(HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) ||
HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ||
HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) {
setDecompressionOnAllInputs(hop, parent);
return;
Expand Down Expand Up @@ -623,6 +659,40 @@ else if(hop instanceof AggUnaryOp) {
}
}

/**
 * Check whether a transient-read input refers to a variable that is written as a MIN/MAX aggregate by a
 * statement block that precedes the current one inside the body of the enclosing while loop.
 *
 * The current statement block and its enclosing block are taken from the {@code lineage} stack. If the
 * input is not a transient read, if there is no enclosing block on the stack, or if the enclosing block is
 * not a while block, the method returns false.
 *
 * NOTE(review): the search does not stop at the closest preceding block that updates the variable; if that
 * block does not write a MIN/MAX aggregate, earlier blocks are still inspected. Confirm this is intended.
 *
 * @param hop   the hop currently being analyzed (not used by this check)
 * @param input the input hop to inspect
 * @return true if a preceding block of the enclosing while body writes the variable from a MIN/MAX aggregate
 */
private boolean checkTransientRead(Hop hop, Hop input) {
	// only transient reads refer to values produced outside the current statement block
	if(!HopRewriteUtils.isData(input, OpOpData.TRANSIENTREAD))
		return false;

	// both the current block and its enclosing block are required; without this guard
	// lineage.get(size - 2) is called with a negative index for blocks that have no parent on the stack
	if(lineage.size() < 2)
		return false;

	final String varName = input.getName();
	final StatementBlock csb = lineage.peek();
	final StatementBlock parentStatement = lineage.get(lineage.size() - 2);
	if(!(parentStatement instanceof WhileStatementBlock))
		return false;

	final WhileStatementBlock wsb = (WhileStatementBlock) parentStatement;
	final WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
	final ArrayList<StatementBlock> stmts = wstmt.getBody();

	// Traverse the loop body in reverse: first locate the current block, then inspect only the blocks
	// that come before it. Iterating in program order could pick up blocks located after the current one.
	boolean foundCurrent = false;
	for(int i = stmts.size() - 1; i >= 0; i--) {
		final StatementBlock sb = stmts.get(i);
		if(!foundCurrent) {
			foundCurrent = sb == csb;
			continue;
		}
		if(!sb.variablesUpdated().containsVariable(varName))
			continue;
		for(Hop cand : sb.getHops()) {
			// the variable is produced by a transient write whose input is a MIN/MAX aggregate
			if(HopRewriteUtils.isData(cand, OpOpData.TRANSIENTWRITE) && cand.getName().equals(varName)
				&& HopRewriteUtils.isAggUnaryOp(cand.getInput(0), AggOp.MIN, AggOp.MAX))
				return true;
		}
	}
	return false;
}

/**
 * Check whether the ID of the given hop is contained in the {@code compressed} set.
 *
 * @param hop the hop to look up
 * @return true if the hop ID is registered in {@code compressed}
 */
private boolean isCompressed(Hop hop) {
	final long hopId = hop.getHopID();
	return compressed.contains(hopId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,28 @@ public int putIfAbsentI(K key, int value) {

}

/**
 * Insert the given value for the key if the key has no mapping yet, and return the value that is
 * associated with the key after the call.
 *
 * For a key that was absent this is the given {@code value}; for a key that was already present it is the
 * previously stored value, which is left unchanged. The null key is stored in the dedicated {@code nullV}
 * slot, where -1 marks "no value stored".
 *
 * @param key   the key to insert, may be null
 * @param value the value to associate with the key if it is absent
 * @return the value associated with the key after the call
 */
public int putIfAbsentReturnVal(K key, int value) {
	if(key == null) {
		if(nullV == -1) {
			size++;
			nullV = value;
		}
		// return the stored value in both cases, consistent with the non-null paths below
		// (previously a fresh insert of the null key returned -1 instead of the inserted value)
		return nullV;
	}

	final int ix = hash(key);
	if(buckets[ix] == null)
		return createBucketReturnVal(ix, key, value);
	return putIfAbsentBucketReturnval(ix, key, value);
}

private int putIfAbsentBucket(int ix, K key, int value) {
Node<K> b = buckets[ix];
while(true) {
Expand All @@ -167,6 +189,21 @@ private int putIfAbsentBucket(int ix, K key, int value) {
}
}

/**
 * Walk the chain of the non-empty bucket {@code ix}: if a node with an equal key exists, return its value;
 * otherwise append a new node at the end of the chain, increment the size, call {@code resize()} and
 * return the given value.
 *
 * @param ix    index of the bucket, which must already contain at least one node
 * @param key   the non-null key to look up or insert
 * @param value the value to insert if the key is absent
 * @return the existing value for the key, or {@code value} if a new node was appended
 */
private int putIfAbsentBucketReturnval(int ix, K key, int value) {
	Node<K> cur = buckets[ix];
	while(!cur.key.equals(key)) {
		if(cur.next == null) {
			// reached the tail without a match: append the new entry here
			cur.setNext(new Node<>(key, value, null));
			size++;
			resize();
			return value;
		}
		cur = cur.next;
	}
	return cur.value;
}

public int putI(K key, int value) {
if(key == null) {
int tmp = nullV;
Expand All @@ -191,6 +228,12 @@ private int createBucket(int ix, K key, int value) {
return -1;
}

/**
 * Place a new single-node chain holding the given key and value into bucket {@code ix}, increment the
 * size, and return the inserted value.
 *
 * @param ix    index of the bucket to fill
 * @param key   the key of the new node
 * @param value the value of the new node
 * @return the given {@code value}
 */
private int createBucketReturnVal(int ix, K key, int value) {
	final Node<K> head = new Node<>(key, value, null);
	buckets[ix] = head;
	size++;
	return value;
}

private int addToBucket(int ix, K key, int value) {
Node<K> b = buckets[ix];
while(true) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.compress.CompressionStatistics;
import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
import org.apache.sysds.runtime.functionobjects.GreaterThanEquals;
import org.apache.sysds.runtime.functionobjects.LessThanEquals;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
Expand All @@ -47,6 +51,35 @@ public void notColVector() {
TestUtils.compareMatricesBitAvgDistance(new MatrixBlock(10, 10, 1.3), cRet2, 0, 0, op.toString());
}

/**
 * Compare compressed binary comparisons against a column vector of ones for operators constructed with
 * and without the extra integer argument, applied from the right ({@code <=}) and from the left
 * ({@code >=}); all results are expected to be identical.
 *
 * The input is a 30x30 matrix filled with 2.0 whose first column, plus the cell (0,1), is set to 1.
 */
@Test
public void twoHotEncodedOutput() {
	// NOTE(review): the second constructor argument is presumably the degree of parallelism - confirm
	BinaryOperator op = new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject(), 2);
	BinaryOperator op2 = new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject());
	BinaryOperator opLeft = new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject(), 2);
	BinaryOperator opLeft2 = new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject());

	MatrixBlock cDense = new MatrixBlock(30, 30, 2.0);
	for(int i = 0; i < 30; i++) {
		cDense.set(i, 0, 1);
	}
	cDense.set(0, 1, 1);
	Pair<MatrixBlock, CompressionStatistics> pair = CompressedMatrixBlockFactory.compress(cDense, 1);
	CompressedMatrixBlock c = (CompressedMatrixBlock) pair.getKey();
	MatrixBlock c2 = new MatrixBlock(30, 1, 1.0);
	// named cSpy so the local does not shadow the statically imported spy(...) method
	CompressedMatrixBlock cSpy = spy(c);
	// stub the cached decompressed block to null for the operations below
	when(cSpy.getCachedDecompressed()).thenReturn(null);

	MatrixBlock cRet = CLALibBinaryCellOp.binaryOperationsRight(op, cSpy, c2);
	MatrixBlock cRet2 = CLALibBinaryCellOp.binaryOperationsRight(op2, cSpy, c2);
	TestUtils.compareMatricesBitAvgDistance(cRet, cRet2, 0, 0, op.toString());

	MatrixBlock cRetleft = CLALibBinaryCellOp.binaryOperationsLeft(opLeft, cSpy, c2);
	MatrixBlock cRetleft2 = CLALibBinaryCellOp.binaryOperationsLeft(opLeft2, cSpy, c2);
	// label this comparison with the operator that actually produced the compared results
	TestUtils.compareMatricesBitAvgDistance(cRetleft, cRetleft2, 0, 0, opLeft.toString());

	// (X <= c2) from the right and (c2 >= X) from the left must produce the same output
	TestUtils.compareMatricesBitAvgDistance(cRet, cRetleft, 0, 0, op.toString());
}

@Test
public void notColVectorEmptyReturn() {
BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject(), 2);
Expand Down
Loading
Loading