Skip to content

Commit e3b3eba

Browse files
committed
correct Optimizer functionality
1 parent b3d4d6f commit e3b3eba

1 file changed

Lines changed: 30 additions & 6 deletions

File tree

src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,24 +268,48 @@ private boolean getTemplateType(double sparsityEst, double scalarVal) {
268268
case VECT_DIV_SCALAR:
269269
case VECT_XOR_SCALAR:
270270
case VECT_BITWAND_SCALAR: return sparsityEst < 0.3;
271-
case VECT_GREATER_SCALAR:
272-
case VECT_GREATEREQUAL_SCALAR:
273-
case VECT_MIN_SCALAR: {
271+
case VECT_GREATER_SCALAR: {
274272
if(scalarVal != Double.NaN) {
275273
return _inputs.get(1).getDataType().isScalar() ? scalarVal >= 0 && sparsityEst < 0.2
276274
: _inputs.get(0).getDataType().isScalar() && scalarVal < 0 && sparsityEst < 0.2;
277275
} else
278276
return false;
279277
}
280-
case VECT_LESS_SCALAR:
281-
case VECT_LESSEQUAL_SCALAR:
282-
case VECT_MAX_SCALAR: {
278+
case VECT_GREATEREQUAL_SCALAR: {
279+
if(scalarVal != Double.NaN) {
280+
return _inputs.get(1).getDataType().isScalar() ? scalarVal > 0 && sparsityEst < 0.2
281+
: _inputs.get(0).getDataType().isScalar() && scalarVal <= 0 && sparsityEst < 0.2;
282+
} else
283+
return false;
284+
}
285+
case VECT_MIN_SCALAR: {
286+
if(scalarVal != Double.NaN) {
287+
return _inputs.get(1).getDataType().isScalar() ? scalarVal >= 0 && sparsityEst < 0.2
288+
: _inputs.get(0).getDataType().isScalar() && scalarVal >= 0 && sparsityEst < 0.2;
289+
} else
290+
return false;
291+
}
292+
case VECT_LESS_SCALAR: {
283293
if(scalarVal != Double.NaN) {
284294
return _inputs.get(1).getDataType().isScalar() ? scalarVal <= 0 && sparsityEst < 0.2
285295
: _inputs.get(0).getDataType().isScalar() && scalarVal > 0 && sparsityEst < 0.2;
286296
} else
287297
return false;
288298
}
299+
case VECT_LESSEQUAL_SCALAR: {
300+
if(scalarVal != Double.NaN) {
301+
return _inputs.get(1).getDataType().isScalar() ? scalarVal < 0 && sparsityEst < 0.2
302+
: _inputs.get(0).getDataType().isScalar() && scalarVal >= 0 && sparsityEst < 0.2;
303+
} else
304+
return false;
305+
}
306+
case VECT_MAX_SCALAR: {
307+
if(scalarVal != Double.NaN) {
308+
return _inputs.get(1).getDataType().isScalar() ? scalarVal <= 0 && sparsityEst < 0.2
309+
: _inputs.get(0).getDataType().isScalar() && scalarVal <= 0 && sparsityEst < 0.2;
310+
} else
311+
return false;
312+
}
289313
case VECT_POW_SCALAR:
290314
case VECT_EQUAL_SCALAR:{
291315
if(scalarVal != Double.NaN) {

0 commit comments

Comments
 (0)